diff --git a/.vscode/settings.json b/.vscode/settings.json index be38d2c6..fda9355b 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,5 +1,7 @@ { "cSpell.words": [ + "INTERNLM", + "QWEN", "vidur" ] } \ No newline at end of file diff --git a/data/device_configs/a100.yml b/data/device_configs/a100.yml deleted file mode 100644 index b2ba1700..00000000 --- a/data/device_configs/a100.yml +++ /dev/null @@ -1,4 +0,0 @@ -fp16_tflops: 312 -total_memory_gb: 80 -num_devices_per_node: 4 - diff --git a/data/device_configs/a40.yml b/data/device_configs/a40.yml deleted file mode 100644 index 17e0446a..00000000 --- a/data/device_configs/a40.yml +++ /dev/null @@ -1,3 +0,0 @@ -fp16_tflops: 150 -total_memory_gb: 45 -num_devices_per_node: 8 diff --git a/data/device_configs/h100.yml b/data/device_configs/h100.yml deleted file mode 100644 index cd174752..00000000 --- a/data/device_configs/h100.yml +++ /dev/null @@ -1,3 +0,0 @@ -fp16_tflops: 1000 -total_memory_gb: 80 -num_devices_per_node: 4 diff --git a/data/model_configs/Qwen/Qwen-72B.yml b/data/model_configs/Qwen/Qwen-72B.yml deleted file mode 100644 index 0acd50d2..00000000 --- a/data/model_configs/Qwen/Qwen-72B.yml +++ /dev/null @@ -1,15 +0,0 @@ -num_layers: 80 -num_q_heads: 64 -num_kv_heads: 64 -embedding_dim: 8192 -mlp_hidden_dim: 24576 -max_position_embeddings: 32768 -use_gated_mlp: true -use_bias: false -use_qkv_bias: true -activation: silu -norm: rms_norm -post_attn_norm: true -rope_theta: 1000000 -vocab_size: 152064 -is_neox_style: true \ No newline at end of file diff --git a/data/model_configs/codellama/CodeLlama-34b-Instruct-hf.yml b/data/model_configs/codellama/CodeLlama-34b-Instruct-hf.yml deleted file mode 100644 index e07dde4a..00000000 --- a/data/model_configs/codellama/CodeLlama-34b-Instruct-hf.yml +++ /dev/null @@ -1,16 +0,0 @@ -num_layers: 48 -num_q_heads: 64 -num_kv_heads: 8 -embedding_dim: 8192 -mlp_hidden_dim: 22016 -max_position_embeddings: 16384 -use_gated_mlp: true -use_bias: false -use_qkv_bias: false -activation: silu -norm: rms_norm -post_attn_norm: true -rope_scaling: null -rope_theta: 1000000 -vocab_size: 32768 -is_neox_style: true \ No newline at end of file diff --git a/data/model_configs/internlm/internlm-20b.yml b/data/model_configs/internlm/internlm-20b.yml deleted file mode 100644 index d93b467b..00000000 --- a/data/model_configs/internlm/internlm-20b.yml +++ /dev/null @@ -1,15 +0,0 @@ -num_layers: 60 -num_q_heads: 40 -num_kv_heads: 40 -embedding_dim: 5120 -mlp_hidden_dim: 13824 -max_position_embeddings: 4096 -use_gated_mlp: true -use_bias: false -use_qkv_bias: false -activation: silu -norm: rms_norm -rope_scaling: null -rope_theta: 10000 -post_attn_norm: true -vocab_size: 103168 diff --git a/data/model_configs/internlm/internlm2-20b.yml b/data/model_configs/internlm/internlm2-20b.yml deleted file mode 100644 index f8a0337f..00000000 --- a/data/model_configs/internlm/internlm2-20b.yml +++ /dev/null @@ -1,15 +0,0 @@ -num_layers: 48 -num_q_heads: 48 -num_kv_heads: 8 -embedding_dim: 6144 -mlp_hidden_dim: 16384 -max_position_embeddings: 32768 -use_gated_mlp: true -use_bias: false -use_qkv_bias: false -act: silu -norm: rms_norm -post_attn_norm: true -rope_scaling: null -rope_theta: 1000000 -vocab_size: 92544 diff --git a/data/model_configs/meta-llama/Llama-2-70b-hf.yml b/data/model_configs/meta-llama/Llama-2-70b-hf.yml deleted file mode 100644 index f74d8445..00000000 --- a/data/model_configs/meta-llama/Llama-2-70b-hf.yml +++ /dev/null @@ -1,16 +0,0 @@ -num_layers: 80 -num_q_heads: 64 -num_kv_heads: 8 -embedding_dim: 8192 -mlp_hidden_dim: 28672 -max_position_embeddings: 4096 -use_gated_mlp: true -use_bias: false -use_qkv_bias: false -activation: silu -norm: rms_norm -post_attn_norm: true -rope_theta: 10000.0 -rope_scaling: null -vocab_size: 32768 -is_neox_style: true \ No newline at end of file diff --git a/data/model_configs/meta-llama/Llama-2-7b-hf.yml b/data/model_configs/meta-llama/Llama-2-7b-hf.yml deleted file mode 100644 index a49b2bbc..00000000 --- a/data/model_configs/meta-llama/Llama-2-7b-hf.yml +++ /dev/null @@ -1,16 +0,0 @@ -num_layers: 32 -num_q_heads: 32 -num_kv_heads: 32 -embedding_dim: 4096 -mlp_hidden_dim: 11008 -max_position_embeddings: 4096 -use_gated_mlp: true -use_bias: false -use_qkv_bias: false -activation: silu -norm: rms_norm -post_attn_norm: true -rope_theta: 10000.0 -rope_scaling: null -vocab_size: 32768 -is_neox_style: true diff --git a/data/model_configs/meta-llama/Meta-Llama-3-70B.yml b/data/model_configs/meta-llama/Meta-Llama-3-70B.yml deleted file mode 100644 index eff626a6..00000000 --- a/data/model_configs/meta-llama/Meta-Llama-3-70B.yml +++ /dev/null @@ -1,16 +0,0 @@ -num_layers: 80 -num_q_heads: 64 -num_kv_heads: 8 -embedding_dim: 8192 -mlp_hidden_dim: 28672 -max_position_embeddings: 8192 -use_gated_mlp: true -use_bias: false -use_qkv_bias: false -activation: silu -norm: rms_norm -post_attn_norm: true -rope_theta: 500000.0 -rope_scaling: null -vocab_size: 128256 -is_neox_style: true diff --git a/data/model_configs/meta-llama/Meta-Llama-3-8B.yml b/data/model_configs/meta-llama/Meta-Llama-3-8B.yml deleted file mode 100644 index e4bba4cc..00000000 --- a/data/model_configs/meta-llama/Meta-Llama-3-8B.yml +++ /dev/null @@ -1,16 +0,0 @@ -num_layers: 32 -num_q_heads: 32 -num_kv_heads: 8 -embedding_dim: 4096 -mlp_hidden_dim: 14336 -max_position_embeddings: 4096 -use_gated_mlp: true -use_bias: false -use_qkv_bias: false -activation: silu -norm: rms_norm -post_attn_norm: true -rope_theta: 500000.0 -rope_scaling: null -vocab_size: 128256 -is_neox_style: true diff --git a/data/model_configs/microsoft/phi-2.yml b/data/model_configs/microsoft/phi-2.yml deleted file mode 100644 index 77c7776d..00000000 --- a/data/model_configs/microsoft/phi-2.yml +++ /dev/null @@ -1,17 +0,0 @@ -num_layers: 32 -num_q_heads: 32 -num_kv_heads: 32 -embedding_dim: 2560 -mlp_hidden_dim: 10240 -max_position_embeddings: 2048 -use_gated_mlp: false -use_bias: true -use_qkv_bias: true -activation: gelu -norm: layer_norm -post_attn_norm: false -vocab_size: 51200 -rope_scaling: null -rope_theta: 10000.0 -partial_rotary_factor: 0.4 -no_tensor_parallel: true diff --git a/data/model_configs/openai/gpt3.yml b/data/model_configs/openai/gpt3.yml deleted file mode 100644 index a34b2b3e..00000000 --- a/data/model_configs/openai/gpt3.yml +++ /dev/null @@ -1,7 +0,0 @@ -num_layers: 96 -num_q_heads: 96 -num_kv_heads: 96 -embedding_dim: 12288 -mlp_hidden_dim: 49152 -use_gated_mlp: false -vocab_size: 50257 diff --git a/data/model_configs/tiiuae/falcon-180B.yml b/data/model_configs/tiiuae/falcon-180B.yml deleted file mode 100644 index 4b8cdb4d..00000000 --- a/data/model_configs/tiiuae/falcon-180B.yml +++ /dev/null @@ -1,8 +0,0 @@ -num_layers: 80 -num_q_heads: 232 -num_kv_heads: 8 -embedding_dim: 14848 -mlp_hidden_dim: 59392 -use_gated_mlp: false -vocab_size: 65024 -is_neox_style: true \ No newline at end of file diff --git a/data/profiling/network/a100_pair_nvlink/all_reduce.csv b/data/profiling/network/a100_pairwise_nvlink/all_reduce.csv similarity index 100% rename from data/profiling/network/a100_pair_nvlink/all_reduce.csv rename to data/profiling/network/a100_pairwise_nvlink/all_reduce.csv diff --git a/data/profiling/network/a100_pair_nvlink/send_recv.csv b/data/profiling/network/a100_pairwise_nvlink/send_recv.csv similarity index 100% rename from data/profiling/network/a100_pair_nvlink/send_recv.csv rename to data/profiling/network/a100_pairwise_nvlink/send_recv.csv diff --git a/data/profiling/network/a40_pair_nvlink/attention.csv b/data/profiling/network/a40_pairwise_nvlink/attention.csv similarity index 100% rename from data/profiling/network/a40_pair_nvlink/attention.csv rename to data/profiling/network/a40_pairwise_nvlink/attention.csv diff --git a/data/profiling/network/a40_pair_nvlink/send_recv.csv b/data/profiling/network/a40_pairwise_nvlink/send_recv.csv similarity index 100% rename from data/profiling/network/a40_pair_nvlink/send_recv.csv rename to data/profiling/network/a40_pairwise_nvlink/send_recv.csv diff --git a/data/profiling/network/h100_pair_nvlink/all_reduce.csv b/data/profiling/network/h100_pairwise_nvlink/all_reduce.csv similarity index 100% rename from data/profiling/network/h100_pair_nvlink/all_reduce.csv rename to data/profiling/network/h100_pairwise_nvlink/all_reduce.csv diff --git a/data/profiling/network/h100_pair_nvlink/send_recv.csv b/data/profiling/network/h100_pairwise_nvlink/send_recv.csv similarity index 100% rename from data/profiling/network/h100_pair_nvlink/send_recv.csv rename to data/profiling/network/h100_pairwise_nvlink/send_recv.csv diff --git a/vidur/config/__init__.py b/vidur/config/__init__.py index 83c25e22..27c9ec62 100644 --- a/vidur/config/__init__.py +++ b/vidur/config/__init__.py @@ -1,3 +1 @@ -from vidur.config.config import Config - -__all__ = [Config] +from .config import * diff --git a/vidur/config/base_fixed_config.py b/vidur/config/base_fixed_config.py new file mode 100644 index 00000000..2d469b39 --- /dev/null +++ b/vidur/config/base_fixed_config.py @@ -0,0 +1,30 @@ +from abc import ABC +from dataclasses import dataclass +from typing import Any + +from vidur.config.utils import get_all_subclasses + + +@dataclass +class BaseFixedConfig(ABC): + + @classmethod + def create_from_type(cls, type_: Any) -> Any: + for subclass in get_all_subclasses(cls): + if subclass.get_type() == type_: + return subclass() + raise ValueError(f"[{cls.__name__}] Invalid type: {type_}") + + @classmethod + def create_from_name(cls, name: str) -> Any: + for subclass in get_all_subclasses(cls): + if subclass.get_name() == name: + return subclass() + raise ValueError(f"[{cls.__name__}] Invalid name: {name}") + + @classmethod + def create_from_type_string(cls, type_str: str) -> Any: + for subclass in get_all_subclasses(cls): + if str(subclass.get_type()) == type_str: + return subclass() + raise ValueError(f"[{cls.__name__}] Invalid type string: {type_str}") diff --git a/vidur/config/base_poly_config.py b/vidur/config/base_poly_config.py new file mode 100644 index 00000000..fbdd0e1f --- /dev/null +++ b/vidur/config/base_poly_config.py @@ -0,0 +1,16 @@ +from abc import ABC +from dataclasses import dataclass +from typing import Any + +from vidur.config.utils import get_all_subclasses + + +@dataclass +class BasePolyConfig(ABC): + + @classmethod + def create_from_type(cls, type_: Any) -> Any: + for subclass in get_all_subclasses(cls): + if subclass.get_type() == type_: + return subclass() + raise ValueError(f"Invalid type: {type_}") diff --git a/vidur/config/config.py b/vidur/config/config.py index 5a162030..70ac5dab 100644 --- a/vidur/config/config.py +++ b/vidur/config/config.py @@ -1,109 +1,713 @@ -import argparse -import datetime +import json import os +from abc import ABC +from dataclasses import dataclass, field +from datetime import datetime +from typing import List, Optional -import yaml - -from vidur.constants import DEFAULT_CONFIG_FILE, DEVICE_CONFIG_DIR, MODEL_CONFIG_DIR +from vidur.config.base_poly_config import BasePolyConfig +from vidur.config.device_sku_config import BaseDeviceSKUConfig +from vidur.config.flat_dataclass import create_flat_dataclass +from vidur.config.model_config import BaseModelConfig +from vidur.config.node_sku_config import BaseNodeSKUConfig +from vidur.config.utils import dataclass_to_dict from vidur.logger import init_logger +from vidur.types import ( + ExecutionTimePredictorType, + GlobalSchedulerType, + ReplicaSchedulerType, + RequestGeneratorType, + RequestIntervalGeneratorType, + RequestLengthGeneratorType, +) logger = init_logger(__name__) -class Config: - def __init__(self, config_file=DEFAULT_CONFIG_FILE): - self._parser = argparse.ArgumentParser() - self._args = None - self._load_yaml(config_file) - self._parse_args() - self._add_derived_args() - self._write_yaml_to_file() - logger.info(f"Config: {self.get_yaml()}") - - def _load_yaml(self, filename): - with open(filename, "r") as file: - yaml_config = yaml.safe_load(file) - self._update_namespace(yaml_config) - - def _parse_args(self): - self._args = self._parser.parse_args() - - def _add_derived_args(self): - self._args.output_dir = f"{self._args.output_dir}/{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S-%f')}" - os.makedirs(self._args.output_dir, exist_ok=True) - self._load_model_config() - self._load_device_config() - self._substitute_variables_in_args() - - def _update_namespace(self, config_dict, parent_key=""): - for key, value in config_dict.items(): - if isinstance(value, dict): - new_key = f"{parent_key}{key}_" if parent_key else f"{key}_" - self._update_namespace(value, new_key) - else: - arg_name = f"{parent_key}{key}" - - if type(value) == bool: - self._parser.add_argument( - f"--{arg_name}", - default=value, - action=argparse.BooleanOptionalAction, - ) - elif arg_name in [ - "simulator_time_limit", - "metrics_store_subsamples", - "replica_scheduler_num_blocks", - ]: - self._parser.add_argument(f"--{arg_name}", default=value, type=int) - else: - self._parser.add_argument( - f"--{arg_name}", default=value, type=type(value) - ) - - def __getattr__(self, name): - return getattr(self._args, name, None) - - def get_yaml(self): - return yaml.dump(self._args.__dict__, default_flow_style=False) - - def _write_yaml_to_file(self): - with open(f"{self._args.output_dir}/config.yml", "w") as file: - file.write(self.get_yaml()) +@dataclass +class BaseRequestIntervalGeneratorConfig(BasePolyConfig): + seed: int = field( + default=42, + metadata={"help": "Seed for the random number generator."}, + ) + + +@dataclass +class BaseRequestLengthGeneratorConfig(BasePolyConfig): + seed: int = field( + default=42, + metadata={"help": "Seed for the random number generator."}, + ) + max_tokens: int = field( + default=4096, + metadata={"help": "Maximum tokens."}, + ) + + +@dataclass +class TraceRequestIntervalGeneratorConfig(BaseRequestIntervalGeneratorConfig): + trace_file: str = field( + default="data/processed_traces/AzureFunctionsInvocationTraceForTwoWeeksJan2021Processed.csv", + metadata={"help": "Path to the trace request interval generator file."}, + ) + start_time: str = field( + default="1970-01-04 12:00:00", + metadata={"help": "Start time of the trace request interval generator."}, + ) + end_time: str = field( + default="1970-01-04 15:00:00", + metadata={"help": "End time of the trace request interval generator."}, + ) + time_scale_factor: float = field( + default=0.3, + metadata={ + "help": "Time scale factor for the trace request interval generator." + }, + ) + + @staticmethod + def get_type(): + return RequestIntervalGeneratorType.TRACE + + +@dataclass +class PoissonRequestIntervalGeneratorConfig(BaseRequestIntervalGeneratorConfig): + qps: float = field( + default=0.5, + metadata={"help": "Queries per second for Poisson Request Interval Generator."}, + ) + + @staticmethod + def get_type(): + return RequestIntervalGeneratorType.POISSON + + +@dataclass +class GammaRequestIntervalGeneratorConfig(BaseRequestIntervalGeneratorConfig): + qps: float = field( + default=0.2, + metadata={"help": "Queries per second for Gamma Request Interval Generator."}, + ) + cv: float = field( + default=0.5, + metadata={ + "help": "Coefficient of variation for Gamma Request Interval Generator." + }, + ) + + @staticmethod + def get_type(): + return RequestIntervalGeneratorType.GAMMA + + +@dataclass +class StaticRequestIntervalGeneratorConfig(BaseRequestIntervalGeneratorConfig): + @staticmethod + def get_type(): + return RequestIntervalGeneratorType.STATIC + + +@dataclass +class TraceRequestLengthGeneratorConfig(BaseRequestLengthGeneratorConfig): + trace_file: str = field( + default="data/processed_traces/sharegpt_8k_filtered_stats_llama2_tokenizer.csv", + metadata={"help": "Path to the trace request length generator file."}, + ) + prefill_scale_factor: float = field( + default=1, + metadata={ + "help": "Prefill scale factor for the trace request length generator." + }, + ) + decode_scale_factor: float = field( + default=1, + metadata={ + "help": "Decode scale factor for the trace request length generator." + }, + ) + max_tokens: int = field( + default=4096, + metadata={"help": "Maximum tokens for the trace request length generator."}, + ) + + @staticmethod + def get_type(): + return RequestLengthGeneratorType.TRACE + + +@dataclass +class ZipfRequestLengthGeneratorConfig(BaseRequestLengthGeneratorConfig): + theta: float = field( + default=0.6, + metadata={"help": "Theta for Zipf Request Length Generator."}, + ) + scramble: bool = field( + default=False, + metadata={"help": "Scramble for Zipf Request Length Generator."}, + ) + min_tokens: int = field( + default=1024, + metadata={"help": "Minimum tokens for Zipf Request Length Generator."}, + ) + max_tokens: int = field( + default=4096, + metadata={"help": "Maximum tokens for Zipf Request Length Generator."}, + ) + prefill_to_decode_ratio: float = field( + default=20.0, + metadata={"help": "Prefill to decode ratio for Zipf Request Length Generator."}, + ) + + @staticmethod + def get_type(): + return RequestLengthGeneratorType.ZIPF + + +@dataclass +class UniformRequestLengthGeneratorConfig(BaseRequestLengthGeneratorConfig): + min_tokens: int = field( + default=1024, + metadata={"help": "Minimum tokens for Uniform Request Length Generator."}, + ) + max_tokens: int = field( + default=4096, + metadata={"help": "Maximum tokens for Uniform Request Length Generator."}, + ) + prefill_to_decode_ratio: float = field( + default=20.0, + metadata={ + "help": "Prefill to decode ratio for Uniform Request Length Generator." + }, + ) + + @staticmethod + def get_type(): + return RequestLengthGeneratorType.UNIFORM + + +@dataclass +class FixedRequestLengthGeneratorConfig(BaseRequestLengthGeneratorConfig): + prefill_tokens: int = field( + default=2048, + metadata={"help": "Prefill tokens for Fixed Request Length Generator."}, + ) + decode_tokens: int = field( + default=512, + metadata={"help": "Decode tokens for Fixed Request Length Generator."}, + ) + + @staticmethod + def get_type(): + return RequestLengthGeneratorType.FIXED + + +@dataclass +class BaseRequestGeneratorConfig(BasePolyConfig): + seed: int = field( + default=42, + metadata={"help": "Seed for the random number generator."}, + ) + max_tokens: int = field( + default=4096, + metadata={"help": "Maximum tokens."}, + ) + + +@dataclass +class SyntheticRequestGeneratorConfig(BaseRequestGeneratorConfig): + length_generator_config: BaseRequestLengthGeneratorConfig = field( + default_factory=FixedRequestLengthGeneratorConfig, + metadata={"help": "Length generator config for Synthetic Request Generator."}, + ) + interval_generator_config: BaseRequestIntervalGeneratorConfig = field( + default_factory=PoissonRequestIntervalGeneratorConfig, + metadata={"help": "Interval generator config for Synthetic Request Generator."}, + ) + num_requests: int = field( + default=128, + metadata={"help": "Number of requests for Synthetic Request Generator."}, + ) + duration: float = field( + default=None, + metadata={"help": "Duration of the synthetic request generator."}, + ) + + def __post_init__(self): + self.max_tokens = self.length_generator_config.max_tokens + + @staticmethod + def get_type(): + return RequestGeneratorType.SYNTHETIC + + +@dataclass +class TraceRequestGeneratorConfig(BaseRequestGeneratorConfig): + trace_file: str = field( + default="data/processed_traces/sydney_enterprise.csv", + metadata={"help": "Path to the trace request generator file."}, + ) + date: str = field( + default="2023-08-21", + metadata={"help": "Date for the trace request generator."}, + ) + prefill_scale_factor: float = field( + default=0.3, + metadata={"help": "Prefill scale factor for the trace request generator."}, + ) + decode_scale_factor: float = field( + default=1, + metadata={"help": "Decode scale factor for the trace request generator."}, + ) + time_scale_factor: float = field( + default=0.04, + metadata={"help": "Time scale factor for the trace request generator."}, + ) + max_tokens: int = field( + default=4096, + metadata={"help": "Maximum tokens for the trace request generator."}, + ) + + @staticmethod + def get_type(): + return RequestGeneratorType.TRACE_REPLAY + + +@dataclass +class BaseReplicaSchedulerConfig(BasePolyConfig): + max_num_seqs: int = field( + default=128, + metadata={"help": "Maximum number of sequences."}, + ) + watermark_blocks_fraction: float = field( + default=0.01, + metadata={"help": "Watermark blocks fraction."}, + ) + block_size: int = field( + default=16, + metadata={"help": "Block size."}, + ) + num_blocks: Optional[int] = field( + default=None, + metadata={"help": "Number of blocks."}, + ) + batch_size_cap: int = field( + default=128, + metadata={"help": "Maximum batch size cap."}, + ) + + +@dataclass +class VllmSchedulerConfig(BaseReplicaSchedulerConfig): + max_batched_tokens: int = field( + default=None, + metadata={"help": "Maximum batched tokens for vLLM."}, + ) + max_tokens_in_batch: int = field( + default=4096, + metadata={"help": "Maximum tokens in batch for vLLM."}, + ) + + @staticmethod + def get_type(): + return ReplicaSchedulerType.VLLM + + +@dataclass +class LightllmSchedulerConfig(BaseReplicaSchedulerConfig): + max_batched_tokens: int = field( + default=None, + metadata={"help": "Maximum batched tokens for LightLLM."}, + ) + max_tokens_in_batch: int = field( + default=4096, + metadata={"help": "Maximum tokens in batch for LightLLM."}, + ) + max_waiting_iters: int = field( + default=10, + metadata={"help": "Maximum waiting iterations for LightLLM."}, + ) + + @staticmethod + def get_type(): + return ReplicaSchedulerType.LIGHTLLM + + +@dataclass +class OrcaSchedulerConfig(BaseReplicaSchedulerConfig): + + @staticmethod + def get_type(): + return ReplicaSchedulerType.ORCA + + +@dataclass +class FasterTransformerSchedulerConfig(BaseReplicaSchedulerConfig): + + @staticmethod + def get_type(): + return ReplicaSchedulerType.FASTER_TRANSFORMER + + +@dataclass +class SarathiSchedulerConfig(BaseReplicaSchedulerConfig): + chunk_size: int = field( + default=512, + metadata={"help": "Chunk size for Sarathi."}, + ) + + @staticmethod + def get_type(): + return ReplicaSchedulerType.SARATHI + + +@dataclass +class MetricsConfig: + """Metric configuration.""" + + write_metrics: bool = field( + default=True, + metadata={"help": "Whether to write metrics."}, + ) + write_json_trace: bool = field( + default=False, + metadata={"help": "Whether to write json trace."}, + ) + wandb_project: Optional[str] = field( + default=None, + metadata={"help": "Weights & Biases project name."}, + ) + wandb_group: Optional[str] = field( + default=None, + metadata={"help": "Weights & Biases group name."}, + ) + wandb_run_name: Optional[str] = field( + default=None, + metadata={"help": "Weights & Biases run name."}, + ) + wandb_sweep_id: Optional[str] = field( + default=None, + metadata={"help": "Weights & Biases sweep id."}, + ) + wandb_run_id: Optional[str] = field( + default=None, + metadata={"help": "Weights & Biases run id."}, + ) + enable_chrome_trace: bool = field( + default=True, + metadata={"help": "Enable Chrome tracing."}, + ) + save_table_to_wandb: bool = field( + default=False, + metadata={"help": "Whether to save table to wandb."}, + ) + store_plots: bool = field( + default=True, + metadata={"help": "Whether to store plots."}, + ) + store_operation_metrics: bool = field( + default=False, + metadata={"help": "Whether to store operation metrics."}, + ) + store_token_completion_metrics: bool = field( + default=False, + metadata={"help": "Whether to store token completion metrics."}, + ) + store_request_metrics: bool = field( + default=True, + metadata={"help": "Whether to store request metrics."}, + ) + store_batch_metrics: bool = field( + default=True, + metadata={"help": "Whether to store batch metrics."}, + ) + store_utilization_metrics: bool = field( + default=True, + metadata={"help": "Whether to store utilization metrics."}, + ) + keep_individual_batch_metrics: bool = field( + default=False, + metadata={"help": "Whether to keep individual batch metrics."}, + ) + subsamples: Optional[int] = field( + default=None, + metadata={"help": "Subsamples."}, + ) + min_batch_index: Optional[int] = field( + default=None, + metadata={"help": "Minimum batch index."}, + ) + max_batch_index: Optional[int] = field( + default=None, + metadata={"help": "Maximum batch index."}, + ) + output_dir: str = field( + default="simulator_output", + metadata={"help": "Output directory."}, + ) + cache_dir: str = field( + default="cache", + metadata={"help": "Cache directory."}, + ) + + def __post_init__(self): + self.output_dir = ( + f"{self.output_dir}/{datetime.now().strftime('%Y-%m-%d_%H-%M-%S-%f')}" + ) + os.makedirs(self.output_dir, exist_ok=True) + + +@dataclass +class ReplicaConfig: + model_name: str = field( + default="meta-llama/Llama-2-7b-hf", + metadata={"help": "Model name."}, + ) + gpu_memory_utilization: float = field( + default=0.8, + metadata={"help": "GPU memory utilization."}, + ) + memory_margin_fraction: float = field( + default=0.1, + metadata={"help": "Memory margin fraction."}, + ) + num_pipeline_stages: int = field( + default=4, + metadata={"help": "Number of pipeline stages."}, + ) + tensor_parallel_size: int = field( + default=1, + metadata={"help": "Tensor parallel size."}, + ) + device: str = field( + default="a100", + metadata={"help": "Device."}, + ) + network_device: str = field( + default="a100_pairwise_nvlink", + metadata={"help": "Network device."}, + ) + + def __post_init__(self): + self.world_size = self.num_pipeline_stages * self.tensor_parallel_size + self.model_config: BaseModelConfig = BaseModelConfig.create_from_name( + self.model_name + ) + self.device_config: BaseDeviceSKUConfig = ( + BaseDeviceSKUConfig.create_from_type_string(self.device) + ) + self.node_config: BaseNodeSKUConfig = BaseNodeSKUConfig.create_from_type_string( + self.network_device + ) + + +@dataclass +class BaseGlobalSchedulerConfig(BasePolyConfig): + pass + + +@dataclass +class RandomGlobalSchedulerConfig(BaseGlobalSchedulerConfig): + @staticmethod + def get_type(): + return GlobalSchedulerType.RANDOM + + +@dataclass +class RoundRobinGlobalSchedulerConfig(BaseGlobalSchedulerConfig): + @staticmethod + def get_type(): + return GlobalSchedulerType.ROUND_ROBIN + + +@dataclass +class LORGlobalSchedulerConfig(BaseGlobalSchedulerConfig): + @staticmethod + def get_type(): + return GlobalSchedulerType.LOR + + +@dataclass +class BaseExecutionTimePredictorConfig(BasePolyConfig): + compute_input_file: str = field( + default="./data/profiling/compute/{DEVICE}/{MODEL}/mlp.csv", + metadata={"help": "Path to the compute input file."}, + ) + attention_input_file: str = field( + default="./data/profiling/compute/{DEVICE}/{MODEL}/attention.csv", + metadata={"help": "Path to the attention input file."}, + ) + all_reduce_input_file: str = field( + default="./data/profiling/network/{NETWORK_DEVICE}/all_reduce.csv", + metadata={"help": "Path to the all reduce input file."}, + ) + send_recv_input_file: str = field( + default="./data/profiling/network/{NETWORK_DEVICE}/send_recv.csv", + metadata={"help": "Path to the send recv input file."}, + ) + cpu_overhead_input_file: str = field( + default="./data/profiling/cpu_overhead/{NETWORK_DEVICE}/{MODEL}/cpu_overheads.csv", + metadata={"help": "Path to the cpu overhead input file."}, + ) + k_fold_cv_splits: int = field( + default=10, + metadata={"help": "Number of k fold cross validation splits."}, + ) + no_cache: bool = field( + default=False, + metadata={"help": "Whether to cache prediction models."}, + ) + kv_cache_prediction_granularity: int = field( + default=64, + metadata={"help": "KV cache prediction granularity."}, + ) + prediction_max_prefill_chunk_size: int = field( + default=4096, + metadata={"help": "Max prefill chunk size for prediction."}, + ) + prediction_max_batch_size: int = field( + default=128, + metadata={"help": "Max batch size for prediction."}, + ) + prediction_max_tokens_per_request: int = field( + default=4096, + metadata={"help": "Max tokens per request for prediction."}, + ) + attention_decode_batching_overhead_fraction: float = field( + default=0.1, + metadata={"help": "Attention decode batching overhead fraction."}, + ) + attention_prefill_batching_overhead_fraction: float = field( + default=0.1, + metadata={"help": "Attention prefill batching overhead fraction."}, + ) + nccl_cpu_launch_overhead_ms: float = field( + default=0.02, + metadata={"help": "NCCL CPU launch overhead in ms."}, + ) + nccl_cpu_skew_overhead_per_device_ms: float = field( + default=0.0, + metadata={"help": "NCCL CPU skew overhead per device in ms."}, + ) + num_training_job_threads: int = field( + default=-1, + metadata={"help": "Number of training job threads."}, + ) + skip_cpu_overhead_modeling: bool = field( + default=True, + metadata={"help": "Whether to skip CPU overhead modeling."}, + ) + + +@dataclass +class LinearRegressionExecutionTimePredictorConfig(BaseExecutionTimePredictorConfig): + polynomial_degree: List[int] = field( + default_factory=lambda: list(range(1, 6)), + metadata={"help": "Polynomial degree for linear regression."}, + ) + polynomial_include_bias: List[bool] = field( + default_factory=lambda: [True, False], + metadata={"help": "Polynomial include bias for linear regression."}, + ) + polynomial_interaction_only: List[bool] = field( + default_factory=lambda: [True, False], + metadata={"help": "Polynomial interaction only for linear regression."}, + ) + fit_intercept: List[bool] = field( + default_factory=lambda: [True, False], + metadata={"help": "Fit intercept for linear regression."}, + ) + + @staticmethod + def get_type(): + return ExecutionTimePredictorType.LINEAR_REGRESSION + + +@dataclass +class RandomForrestExecutionTimePredictorConfig(BaseExecutionTimePredictorConfig): + num_estimators: List[int] = field( + default_factory=lambda: [250, 500, 750], + metadata={"help": "Number of estimators for random forest."}, + ) + max_depth: List[int] = field( + default_factory=lambda: [8, 16, 32], + metadata={"help": "Maximum depth for random forest."}, + ) + min_samples_split: List[int] = field( + default_factory=lambda: [2, 5, 10], + metadata={"help": "Minimum samples split for random forest."}, + ) + + @staticmethod + def get_type(): + return ExecutionTimePredictorType.RANDOM_FORREST + + +@dataclass +class ClusterConfig: + num_replicas: int = field( + default=1, + metadata={"help": "Number of replicas."}, + ) + replica_config: ReplicaConfig = field(default_factory=ReplicaConfig) + global_scheduler_config: BaseGlobalSchedulerConfig = field( + default_factory=RoundRobinGlobalSchedulerConfig, + metadata={"help": "Global scheduler config."}, + ) + replica_scheduler_config: BaseReplicaSchedulerConfig = field( + default_factory=SarathiSchedulerConfig, + metadata={"help": "Replica scheduler config."}, + ) + + +@dataclass +class SimulationConfig(ABC): + seed: int = field( + default=42, + metadata={"help": "Seed for the random number generator."}, + ) + log_level: str = field( + default="info", + metadata={"help": "Logging level."}, + ) + time_limit: int = field( + default=0, # in seconds, 0 is no limit + metadata={"help": "Time limit for simulation in seconds. 0 means no limit."}, + ) + cluster_config: ClusterConfig = field( + default_factory=ClusterConfig, + metadata={"help": "Cluster config."}, + ) + request_generator_config: BaseRequestGeneratorConfig = field( + default_factory=SyntheticRequestGeneratorConfig, + metadata={"help": "Request generator config."}, + ) + execution_time_predictor_config: BaseExecutionTimePredictorConfig = field( + default_factory=RandomForrestExecutionTimePredictorConfig, + metadata={"help": "Execution time predictor config."}, + ) + metrics_config: MetricsConfig = field( + default_factory=MetricsConfig, + metadata={"help": "Metrics config."}, + ) + + def __post_init__(self): + self.write_config_to_file() + + @classmethod + def create_from_cli_args(cls): + flat_config = create_flat_dataclass(cls).create_from_cli_args() + instance = flat_config.reconstruct_original_dataclass() + instance.__flat_config__ = flat_config + return instance def to_dict(self): - return self._args.__dict__ - - def _add_to_args(self, new_args_dict, parent_key=""): - for key, value in new_args_dict.items(): - arg_name = f"{parent_key}{key}" - setattr(self._args, arg_name, value) - - def _load_model_config(self): - assert self.replica_model_name is not None - - config_file = f"{MODEL_CONFIG_DIR}/{self.replica_model_name}.yml" - with open(config_file, "r") as file: - yaml_config = yaml.safe_load(file) - self._add_to_args(yaml_config, "replica_") - - def _load_device_config(self): - assert self.replica_device is not None - - config_file = f"{DEVICE_CONFIG_DIR}/{self.replica_device}.yml" - with open(config_file, "r") as file: - yaml_config = yaml.safe_load(file) - self._add_to_args(yaml_config, "replica_") - - def _substitute_variables_in_args(self): - assert self.replica_model_name is not None - assert self.replica_device is not None - assert self.replica_network_device is not None - - # update names of sklearn config files - for key, value in self._args.__dict__.items(): - if isinstance(value, str): - self._args.__dict__[key] = ( - value.replace("{MODEL}", self.replica_model_name) - .replace("{DEVICE}", self.replica_device) - .replace("{NETWORK_DEVICE}", self.replica_network_device) - ) + if not hasattr(self, "__flat_config__"): + logger.warning("Flat config not found. Returning the original config.") + return self.__dict__ + + return self.__flat_config__.__dict__ + + def write_config_to_file(self): + config_dict = dataclass_to_dict(self) + with open(f"{self.metrics_config.output_dir}/config.json", "w") as f: + json.dump(config_dict, f, indent=4) diff --git a/vidur/config/default.yml b/vidur/config/default.yml deleted file mode 100644 index 1913499f..00000000 --- a/vidur/config/default.yml +++ /dev/null @@ -1,162 +0,0 @@ -seed: 42 -log_level: info -output_dir: ./simulator_output/ -cache_dir: ./cache -write_json_trace: false -write_chrome_trace: true -write_metrics: true - -cluster: - num_replicas: 1 - -replica: - block_size: 16 - memory_margin_fraction: 0.1 - num_pipeline_stages: 4 - num_tensor_parallel_workers: 1 - model_name: meta-llama/Llama-2-7b-hf - device: a100 - network_device: a100_pair_nvlink - -request_generator: - provider: synthetic - max_tokens: 4096 - -synthetic_request_generator: - length_provider: trace - interval_provider: static - min_tokens: 1024 - prefill_to_decode_ratio: 10 - num_requests: 128 - -trace_request_generator: - trace_file: ./data/processed_traces/sydney_enterprise.csv - date: '2023-08-21' - prefill_scale_factor: 0.3 - decode_scale_factor: 1 - time_scale_factor: 0.04 - -# Config for synthetic trace generator -trace_request_length_generator: - trace_file: ./data/processed_traces/arxiv_summarization_stats_llama2_tokenizer_filtered_v2.csv - prefill_scale_factor: 1 - decode_scale_factor: 1 - -trace_request_interval_generator: - trace_file: ./data/processed_traces/AzureFunctionsInvocationTraceForTwoWeeksJan2021Processed.csv - start_time: "1970-01-04 12:00:00" - end_time: "1970-01-04 15:00:00" - time_scale_factor: 0.3 - -poisson_request_interval_generator: - qps: 0.5 - -gamma_request_interval_generator: - cv: 0.5 - qps: 0.2 - -zipf_request_length_generator: - theta: 0.4 - scramble: false - -fixed_request_generator: - prefill_tokens: 2048 - decode_tokens: 512 - -execution_time_predictor: - provider: random_forrest - # provider: linear_regression - -sklearn_execution_time_predictor: - compute_input_file: ./data/profiling/compute/{DEVICE}/{MODEL}/mlp.csv - attention_input_file: ./data/profiling/compute/{DEVICE}/{MODEL}/attention.csv - all_reduce_input_file: ./data/profiling/network/{NETWORK_DEVICE}/all_reduce.csv - send_recv_input_file: ./data/profiling/network/{NETWORK_DEVICE}/send_recv.csv - cpu_overhead_input_file: ./data/profiling/cpu_overhead/{NETWORK_DEVICE}/{MODEL}/cpu_overheads.csv - k_fold_cv_splits: 10 - no_cache: false - kv_cache_prediction_granularity: 64 - prediction_max_prefill_chunk_size: 4096 - prediction_max_batch_size: 128 - prediction_max_tokens_per_request: 4096 - attention_decode_batching_overhead_fraction: 0.1 - attention_prefill_batching_overhead_fraction: 0.1 - nccl_cpu_launch_overhead_ms: 0.020 - nccl_cpu_skew_overhead_per_device_ms: 0.0 - num_training_job_threads: -1 - skip_cpu_overhead_modeling: true - -random_forrest_execution_time_predictor: - num_estimators: - # - 250 - - 500 - - 750 - max_depth: - # - 8 - # - 16 - - 32 - min_samples_split: - - 2 - - 5 - - 10 - -linear_regression_execution_time_predictor: - polynomial_degree: - - 1 - - 2 - - 3 - - 4 - - 5 - polynomial_include_bias: - - true - - false - polynomial_interaction_only: - - true - - false - fit_intercept: - - true - - false - -simulator: - time_limit: 0 # in seconds, 0 is no limit - -global_scheduler: - provider: round_robin - -replica_scheduler: - provider: sarathi - batch_size_cap: 128 - num_blocks: null - -orca_scheduler: - use_single_prefill_per_batch: false - -vllm_scheduler: - watermark_blocks_fraction: 0.01 - max_tokens_in_batch: 4096 - -sarathi_scheduler: - chunk_size: 512 - enable_rolling_prefills: true - prefill_fitting_tolerance: 0.0 - watermark_blocks_fraction: 0.01 - -lightllm_scheduler: - max_tokens_in_batch: 4096 - max_waiting_iters: 10 - -metrics_store: - wandb_project: "llm-simulator-v2" - wandb_group: "" - wandb_run_name: "" - subsamples: null - save_table_to_wandb: false - store_plots: true - store_operation_metrics: false - store_token_completion_metrics: false - store_request_metrics: true - store_batch_metrics: true - store_utilization_metrics: true - keep_individual_batch_metrics: false - # min_batch_idx: 2000 - # max_batch_idx: 5000 diff --git a/vidur/config/device_sku_config.py b/vidur/config/device_sku_config.py new file mode 100644 index 00000000..a92646fc --- /dev/null +++ b/vidur/config/device_sku_config.py @@ -0,0 +1,43 @@ +from dataclasses import dataclass, field + +from vidur.config.base_fixed_config import BaseFixedConfig +from vidur.logger import init_logger +from vidur.types import DeviceSKUType + +logger = init_logger(__name__) + + +@dataclass +class BaseDeviceSKUConfig(BaseFixedConfig): + fp16_tflops: int + total_memory_gb: int + + +@dataclass +class A100DeviceSKUConfig(BaseDeviceSKUConfig): + fp16_tflops: int = 312 + total_memory_gb: int = 80 + + @staticmethod + def get_type(): + return DeviceSKUType.A40 + + +@dataclass +class A40DeviceSKUConfig(BaseDeviceSKUConfig): + fp16_tflops: int = 150 + total_memory_gb: int = 45 + + @staticmethod + def get_type(): + return DeviceSKUType.A100 + + +@dataclass +class H100DeviceSKUConfig(BaseDeviceSKUConfig): + fp16_tflops: int = 1000 + total_memory_gb: int = 80 + + @staticmethod + def get_type(): + return DeviceSKUType.H100 diff --git a/vidur/config/flat_dataclass.py b/vidur/config/flat_dataclass.py new file mode 100644 index 00000000..fe23dac5 --- /dev/null +++ b/vidur/config/flat_dataclass.py @@ -0,0 +1,219 @@ +import json +from argparse import ( + ArgumentDefaultsHelpFormatter, + ArgumentParser, + BooleanOptionalAction, +) +from collections import defaultdict, deque +from dataclasses import MISSING, fields, make_dataclass +from typing import Any, get_args + +from vidur.config.base_poly_config import BasePolyConfig +from vidur.config.utils import ( + get_all_subclasses, + get_inner_type, + is_bool, + is_composed_of_primitives, + is_dict, + is_list, + is_optional, + is_primitive_type, + is_subclass, + to_snake_case, +) + + +def topological_sort(dataclass_dependencies: dict) -> list: + in_degree = defaultdict(int) + for cls, dependencies in dataclass_dependencies.items(): + for dep in dependencies: + in_degree[dep] += 1 + + zero_in_degree_classes = deque( + [cls for cls in dataclass_dependencies if in_degree[cls] == 0] + ) + sorted_classes = [] + + while zero_in_degree_classes: + cls = zero_in_degree_classes.popleft() + sorted_classes.append(cls) + for dep in dataclass_dependencies[cls]: + in_degree[dep] -= 1 + if in_degree[dep] == 0: + zero_in_degree_classes.append(dep) + + return sorted_classes + + +def reconstruct_original_dataclass(self) -> Any: + """ + This function is dynamically mapped to FlatClass as an instance method. + """ + sorted_classes = topological_sort(self.dataclass_dependencies) + instances = {} + + for _cls in reversed(sorted_classes): + args = {} + + for prefixed_field_name, original_field_name, field_type in self.dataclass_args[ + _cls + ]: + if is_subclass(field_type, BasePolyConfig): + config_type = getattr(self, f"{original_field_name}_type") + # find all subclasses of field_type and check which one matches the config_type + for subclass in get_all_subclasses(field_type): + if str(subclass.get_type()) == config_type: + args[original_field_name] = instances[subclass] + break + elif hasattr(field_type, "__dataclass_fields__"): + args[original_field_name] = instances[field_type] + else: + value = getattr(self, prefixed_field_name) + if callable(value): + # to handle default factory values + value = value() + args[original_field_name] = value + + instances[_cls] = _cls(**args) + + return instances[sorted_classes[0]] + + +@classmethod +def create_from_cli_args(cls) -> Any: + """ + This function is dynamically mapped to FlatClass as a class method. + """ + parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) + + for field in fields(cls): + nargs = None + action = None + field_type = field.type + help_text = field.metadata.get("help", None) + + if is_list(field.type): + assert is_composed_of_primitives(field.type) + field_type = get_args(field.type)[0] + if is_primitive_type(field_type): + nargs = "+" + else: + field_type = json.loads + elif is_dict(field.type): + assert is_composed_of_primitives(field.type) + field_type = json.loads + elif is_bool(field.type): + action = BooleanOptionalAction + + arg_params = { + "type": field_type, + "action": action, + "help": help_text, + } + + # handle cases with default and default factory args + if field.default is not MISSING: + arg_params["default"] = field.default + elif field.default_factory is not MISSING: + arg_params["default"] = field.default_factory() + else: + arg_params["required"] = True + + if nargs: + arg_params["nargs"] = nargs + + parser.add_argument(f"--{field.name}", **arg_params) + + args = parser.parse_args() + + return cls(**vars(args)) + + +def create_flat_dataclass(input_dataclass: Any) -> Any: + """ + Creates a new FlatClass type by recursively flattening the input dataclass. + This allows for easy parsing of command line arguments along with storing/loading the configuration to/from a file. + """ + meta_fields_with_defaults = [] + meta_fields_without_defaults = [] + processed_classes = set() + dataclass_args = defaultdict(list) + dataclass_dependencies = defaultdict(set) + + def process_dataclass(_input_dataclass, prefix=""): + if _input_dataclass in processed_classes: + return + + processed_classes.add(_input_dataclass) + + for field in fields(_input_dataclass): + prefixed_name = f"{prefix}{field.name}" + + if is_optional(field.type): + field_type = get_inner_type(field.type) + else: + field_type = field.type + + # # if field is a BasePolyConfig, add a type argument and process it as a dataclass + if is_subclass(field_type, BasePolyConfig): + dataclass_args[_input_dataclass].append( + (field.name, field.name, field_type) + ) + + type_field_name = f"{field.name}_type" + default_value = str(field.default_factory().get_type()) + meta_fields_with_defaults.append( + (type_field_name, type(default_value), default_value) + ) + + assert hasattr(field_type, "__dataclass_fields__") + for subclass in get_all_subclasses(field_type): + dataclass_dependencies[_input_dataclass].add(subclass) + process_dataclass(subclass, f"{to_snake_case(subclass.__name__)}_") + continue + + # if field is a dataclass, recursively process it + if hasattr(field_type, "__dataclass_fields__"): + dataclass_dependencies[_input_dataclass].add(field_type) + dataclass_args[_input_dataclass].append( + (field.name, field.name, field_type) + ) + process_dataclass(field_type, f"{to_snake_case(field_type.__name__)}_") + continue + + field_default = field.default if field.default is not MISSING else MISSING + field_default_factory = ( + field.default_factory + if field.default_factory is not MISSING + else MISSING + ) + + if field_default is not MISSING: + meta_fields_with_defaults.append( + (prefixed_name, field_type, field_default) + ) + elif field_default_factory is not MISSING: + meta_fields_with_defaults.append( + (prefixed_name, field_type, field_default_factory) + ) + else: + meta_fields_without_defaults.append((prefixed_name, field_type)) + + dataclass_args[_input_dataclass].append( + (prefixed_name, field.name, field_type) + ) + + process_dataclass(input_dataclass) + + meta_fields = meta_fields_without_defaults + meta_fields_with_defaults + FlatClass = make_dataclass("FlatClass", meta_fields) + + # Metadata fields + FlatClass.dataclass_args = dataclass_args + FlatClass.dataclass_dependencies = dataclass_dependencies + + # Helper methods + FlatClass.reconstruct_original_dataclass = reconstruct_original_dataclass + FlatClass.create_from_cli_args = create_from_cli_args + + return FlatClass diff --git a/vidur/config/model_config.py b/vidur/config/model_config.py new file mode 100644 index 00000000..0057d782 --- /dev/null +++ b/vidur/config/model_config.py @@ -0,0 +1,191 @@ +from dataclasses import dataclass, field +from typing import Any, Dict, Optional + +from vidur.config.base_fixed_config import BaseFixedConfig +from vidur.logger import init_logger +from vidur.types import ActivationType, NormType + +logger = init_logger(__name__) + + +@dataclass +class BaseModelConfig(BaseFixedConfig): + num_layers: int + num_q_heads: int + num_kv_heads: int + embedding_dim: int + mlp_hidden_dim: int + max_position_embeddings: int + use_gated_mlp: bool + use_bias: bool + use_qkv_bias: bool + activation: ActivationType + norm: NormType + post_attn_norm: bool + vocab_size: int + is_neox_style: Optional[bool] = True + rope_theta: Optional[int] = None + rope_scaling: Optional[Dict[str, Any]] = None + partial_rotary_factor: float = 1.0 + no_tensor_parallel: bool = False + + +@dataclass +class Llama2ModelConfig(BaseModelConfig): + max_position_embeddings: int = 16384 + use_gated_mlp: bool = True + use_bias: bool = False + use_qkv_bias: bool = False + activation: ActivationType = ActivationType.SILU + norm: NormType = NormType.RMS_NORM + post_attn_norm: bool = True + vocab_size: int = 32768 + is_neox_style: Optional[bool] = True + rope_theta: Optional[int] = 10000.0 + rope_scaling: Optional[Dict[str, Any]] = None + partial_rotary_factor: float = 1.0 + no_tensor_parallel: bool = False + + @staticmethod + def get_name(): + return "meta-llama/Llama-2-Config" + + +@dataclass +class CodeLlama34BModelConfig(Llama2ModelConfig): + num_layers: int = 48 + num_q_heads: int = 64 + num_kv_heads: int = 8 + embedding_dim: int = 8192 + mlp_hidden_dim: int = 22016 + + @staticmethod + def get_name(): + return "codellama/CodeLlama-34b-Instruct-hf" + + +@dataclass +class Llama2_7BModelConfig(Llama2ModelConfig): + num_layers: int = 32 + num_q_heads: int = 32 + num_kv_heads: int = 32 + embedding_dim: int = 4096 + mlp_hidden_dim: int = 11008 + + @staticmethod + def get_name(): + return "meta-llama/Llama-2-7b-hf" + + +@dataclass +class Llama2_70BModelConfig(Llama2ModelConfig): + num_layers: int = 80 + num_q_heads: int = 64 + num_kv_heads: int = 8 + embedding_dim: int = 8192 + mlp_hidden_dim: int = 28672 + + @staticmethod + def get_name(): + return "meta-llama/Llama-2-70b-hf" + + +@dataclass +class Llama3_8BModelConfig(Llama2ModelConfig): + num_layers: int = 32 + num_q_heads: int = 32 + num_kv_heads: int = 8 + embedding_dim: int = 4096 + mlp_hidden_dim: int = 14336 + max_position_embeddings: int = 4096 + rope_theta: Optional[int] = 500000.0 + vocab_size: int = 128256 + + @staticmethod + def get_name(): + return "meta-llama/Meta-Llama-3-8B" + + +@dataclass +class Llama3_70BModelConfig(Llama2ModelConfig): + num_layers: int = 80 + num_q_heads: int = 64 + num_kv_heads: int = 8 + embedding_dim: int = 8192 + mlp_hidden_dim: int = 28672 + max_position_embeddings: int = 8192 + rope_theta: Optional[int] = 500000.0 + vocab_size: int = 128256 + + @staticmethod + def get_name(): + return "meta-llama/Meta-Llama-3-70B" + + +@dataclass +class InternLM2ModelConfig(Llama2ModelConfig): + max_position_embeddings: int = 32768 + vocab_size: int = 92544 + + +@dataclass +class InternLM2_20BModelConfig(InternLM2ModelConfig): + num_layers: int = 48 + num_q_heads: int = 48 + num_kv_heads: int = 8 + embedding_dim: int = 6144 + mlp_hidden_dim: int = 16384 + + @staticmethod + def get_name(): + return "internlm/internlm2-20b" + + +@dataclass +class Phi2ModelConfig(Llama2ModelConfig): + num_layers: int = 32 + num_q_heads: int = 32 + num_kv_heads: int = 32 + embedding_dim: int = 2560 + mlp_hidden_dim: int = 10240 + max_position_embeddings: int = 2048 + use_gated_mlp: bool = False + use_bias: bool = True + use_qkv_bias: bool = True + activation: ActivationType = ActivationType.GELU + norm: NormType = NormType.LAYER_NORM + post_attn_norm: bool = False + vocab_size: int = 51200 + rope_scaling: Optional[Dict[str, Any]] = None + rope_theta: Optional[int] = 10000.0 + partial_rotary_factor: float = 0.4 + no_tensor_parallel: bool = True + is_neox_style: bool = True + + @staticmethod + def get_name(): + return "microsoft/phi-2" + + +@dataclass +class QwenModelConfig(Llama2ModelConfig): + use_qkv_bias: bool = True + max_position_embeddings: int = 32768 + vocab_size: int = 152064 + + @staticmethod + def get_name(): + return "Qwen/Qwen-Config" + + +@dataclass +class Qwen72BModelConfig(QwenModelConfig): + num_layers: int = 80 + num_q_heads: int = 64 + num_kv_heads: int = 64 + embedding_dim: int = 8192 + mlp_hidden_dim: int = 24576 + + @staticmethod + def get_name(): + return "Qwen/Qwen-72B" diff --git a/vidur/config/node_sku_config.py b/vidur/config/node_sku_config.py new file mode 100644 index 00000000..34eb8050 --- /dev/null +++ b/vidur/config/node_sku_config.py @@ -0,0 +1,62 @@ +from dataclasses import dataclass, field + +from vidur.config.base_fixed_config import BaseFixedConfig +from vidur.logger import init_logger +from vidur.types import DeviceSKUType, NodeSKUType + +logger = init_logger(__name__) + + +@dataclass +class BaseNodeSKUConfig(BaseFixedConfig): + num_devices_per_node: int + + +@dataclass +class A40PairwiseNvlinkNodeSKUConfig(BaseNodeSKUConfig): + device_sku_type: DeviceSKUType = DeviceSKUType.A40 + num_devices_per_node: int = 8 + + @staticmethod + def get_type(): + return NodeSKUType.A40_PAIRWISE_NVLINK + + +@dataclass +class A100PairwiseNvlinkNodeSKUConfig(BaseNodeSKUConfig): + device_sku_type: DeviceSKUType = DeviceSKUType.A100 + num_devices_per_node: int = 8 + + @staticmethod + def get_type(): + return NodeSKUType.A100_PAIRWISE_NVLINK + + +@dataclass +class H100PairwiseNvlinkNodeSKUConfig(BaseNodeSKUConfig): + device_sku_type: DeviceSKUType = DeviceSKUType.H100 + num_devices_per_node: int = 8 + + @staticmethod + def get_type(): + return NodeSKUType.H100_PAIRWISE_NVLINK + + +@dataclass +class A100DgxNodeSKUConfig(BaseNodeSKUConfig): + device_sku_type: DeviceSKUType = DeviceSKUType.A100 + num_devices_per_node: int = 8 + + @staticmethod + def get_type(): + return NodeSKUType.A100_DGX + + +@dataclass +class H100DgxNodeSKUConfig(BaseNodeSKUConfig): + device_sku_type: DeviceSKUType = DeviceSKUType.H100 + num_devices_per_node: int = 8 + + @staticmethod + def get_type(): + return NodeSKUType.H100_DGX diff --git a/vidur/config/utils.py b/vidur/config/utils.py new file mode 100644 index 00000000..8627dd34 --- /dev/null +++ b/vidur/config/utils.py @@ -0,0 +1,87 @@ +from dataclasses import fields, is_dataclass +from typing import Union, get_args, get_origin + +primitive_types = {int, str, float, bool, type(None)} + + +def get_all_subclasses(cls): + subclasses = cls.__subclasses__() + return subclasses + [g for s in subclasses for g in get_all_subclasses(s)] + + +def is_primitive_type(field_type: type) -> bool: + # Check if the type is a primitive type + return field_type in primitive_types + + +def is_generic_composed_of_primitives(field_type: type) -> bool: + origin = get_origin(field_type) + if origin in {list, dict, tuple, Union}: + # Check all arguments of the generic type + args = get_args(field_type) + return all(is_composed_of_primitives(arg) for arg in args) + return False + + +def is_composed_of_primitives(field_type: type) -> bool: + # Check if the type is a primitive type + if is_primitive_type(field_type): + return True + + # Check if the type is a generic type composed of primitives + if is_generic_composed_of_primitives(field_type): + return True + + return False + + +def to_snake_case(name: str) -> str: + return "".join(["_" + i.lower() if i.isupper() else i for i in name]).lstrip("_") + + +def is_optional(field_type: type) -> bool: + return get_origin(field_type) is Union and type(None) in get_args(field_type) + + +def is_list(field_type: type) -> bool: + # Check if the field type is a List + return get_origin(field_type) is list + + +def is_dict(field_type: type) -> bool: + # Check if the field type is a Dict + return get_origin(field_type) is dict + + +def is_bool(field_type: type) -> bool: + return field_type is bool + + +def get_inner_type(field_type: type) -> type: + return next(t for t in get_args(field_type) if t is not type(None)) + + +def is_subclass(cls, parent: type) -> bool: + return hasattr(cls, "__bases__") and parent in cls.__bases__ + + +def dataclass_to_dict(obj): + if isinstance(obj, list): + return [dataclass_to_dict(item) for item in obj] + elif is_dataclass(obj): + data = {} + for field in fields(obj): + value = getattr(obj, field.name) + data[field.name] = dataclass_to_dict(value) + # Include members created in __post_init__ + for key, value in obj.__dict__.items(): + if key not in data: + data[key] = dataclass_to_dict(value) + # Include the name of the class + if hasattr(obj, "get_type") and callable(getattr(obj, "get_type")): + data["name"] = str(obj.get_type()) + elif hasattr(obj, "get_name") and callable(getattr(obj, "get_name")): + data["name"] = obj.get_name() + return data + else: + return obj diff --git a/vidur/config_optimizer/config_explorer/capacity_search.py b/vidur/config_optimizer/config_explorer/capacity_search.py index bd970f79..74c7fd51 100644 --- a/vidur/config_optimizer/config_explorer/capacity_search.py +++ b/vidur/config_optimizer/config_explorer/capacity_search.py @@ -1,6 +1,7 @@ import argparse import glob import os +import platform import shlex from subprocess import Popen @@ -48,12 +49,11 @@ def _generate_run_command( scheduler_config: SimulationConfig, ): cpu_affinity_command = "" - if self.cpu_core_id is not None: - self.cpu_core_id = self.cpu_core_id + if self.cpu_core_id is not None and platform.system() != "Darwin": cpu_affinity_command = f"taskset --cpu-list {self.cpu_core_id}" command = f"nice -n 1 {cpu_affinity_command} python -m vidur.main {scheduler_config.to_args()}" - logger.debug(f"Running command: {command}", flush=True) + logger.debug(f"Running command: {command}") return command @@ -81,7 +81,6 @@ def _is_under_sla( logger.info( f"{simulator_config.to_human_readable_name()} - Scheduling delay (P{self.args.scheduling_delay_slo_quantile}): {scheduling_delay}", - flush=True, ) return is_under_scheduling_delay_sla, scheduling_delay @@ -120,7 +119,6 @@ def is_under_sla(self, qps: float) -> tuple[bool, float]: except Exception as e: logger.error( f"Error running: {self.job_config.get_human_readable_name()}, failed with error: {e}", - flush=True, ) return False, None @@ -130,7 +128,6 @@ def search(self): """ logger.info( f"Starting search for {self.job_config.get_human_readable_name()}", - flush=True, ) left = 0 @@ -175,7 +172,6 @@ def search(self): logger.info( f"Max QPS under SLO for {self.job_config.get_human_readable_name()}: {max_qps_under_sla}", - flush=True, ) self.release_cpu_core_id() diff --git a/vidur/config_optimizer/config_explorer/config/config.py b/vidur/config_optimizer/config_explorer/config/config.py index db4da92f..4a073fed 100644 --- a/vidur/config_optimizer/config_explorer/config/config.py +++ b/vidur/config_optimizer/config_explorer/config/config.py @@ -15,7 +15,7 @@ def get_key(self): def to_config_dict(self): return { - "replica_model_name": self.identifier, + "replica_config_model_name": self.identifier, } def is_tensor_parallel_degree_valid(self, tp_degree: int): @@ -35,15 +35,20 @@ def get_key(self): def to_config_dict(self): return { - "request_generator_provider": "synthetic", - "synthetic_request_generator_length_provider": "trace", - "synthetic_request_generator_interval_provider": "poisson", - "request_generator_max_tokens": self.max_seq_len, - "trace_request_length_generator_trace_file": self.trace_file, - "trace_request_length_generator_prefill_scale_factor": 1, - "trace_request_length_generator_decode_scale_factor": 1, - "synthetic_request_generator_num_requests": self.num_requests, - "vllm_scheduler_max_tokens_in_batch": self.max_seq_len, + "request_generator_config_type": "synthetic", + "length_generator_config_type": "trace", + "interval_generator_config_type": "poisson", + "synthetic_request_generator_config_max_tokens": self.max_seq_len, + "trace_request_length_generator_config_max_tokens": self.max_seq_len, + "zipf_request_length_generator_config_max_tokens": self.max_seq_len, + "uniform_request_length_generator_config_max_tokens": self.max_seq_len, + "fixed_request_length_generator_config_max_tokens": self.max_seq_len, + "trace_request_generator_config_max_tokens": self.max_seq_len, + "trace_request_length_generator_config_trace_file": self.trace_file, + "trace_request_length_generator_config_prefill_scale_factor": 1, + "trace_request_length_generator_config_decode_scale_factor": 1, + "synthetic_request_generator_config_num_requests": self.num_requests, + "vllm_scheduler_config_max_tokens_in_batch": self.max_seq_len, } @@ -58,7 +63,7 @@ def get_key(self): def to_config_dict(self): return { - "replica_device": self.device, + "replica_config_device": self.device, } @@ -78,16 +83,14 @@ def get_key(self): def to_config_dict(self): if self.scheduler == "vllm": return { - "replica_scheduler_provider": "vllm", + "replica_scheduler_config_type": "vllm", } assert self.scheduler == "sarathi" assert self.chunk_size is not None return { - "replica_scheduler_provider": "sarathi", - "sarathi_scheduler_chunk_size": self.chunk_size, - "sarathi_scheduler_enable_rolling_prefills": None, - "sarathi_scheduler_prefill_fitting_tolerance": 0.0, + "replica_scheduler_config_type": "sarathi", + "sarathi_scheduler_config_chunk_size": self.chunk_size, } @@ -145,10 +148,14 @@ def to_config_dict(self): **self.trace_config.to_config_dict(), **self.cluster_config.to_config_dict(), **self.scheduler_config.to_config_dict(), - "replica_num_tensor_parallel_workers": self.num_tensor_parallel_workers, - "replica_num_pipeline_stages": self.num_pipeline_stages, - "replica_scheduler_batch_size_cap": self.batch_size, - "cluster_num_replicas": self.num_replicas, + "replica_config_tensor_parallel_size": self.num_tensor_parallel_workers, + "replica_config_num_pipeline_stages": self.num_pipeline_stages, + "vllm_scheduler_config_batch_size_cap": self.batch_size, + "lightllm_scheduler_config_batch_size_cap": self.batch_size, + "orca_scheduler_config_batch_size_cap": self.batch_size, + "faster_transformer_scheduler_config_batch_size_cap": self.batch_size, + "sarathi_scheduler_config_batch_size_cap": self.batch_size, + "cluster_config_num_replicas": self.num_replicas, } @classmethod @@ -232,16 +239,18 @@ class SimulationConfig: def to_config_dict(self): return { **self.job_config.to_config_dict(), - "output_dir": self.get_run_dir(), - "cache_dir": self.cache_dir, - "poisson_request_interval_generator_qps": self.qps, - "simulator_time_limit": self.time_limit * 60, # to seconds - "no-metrics_store_save_table_to_wandb": None, - "no-metrics_store_store_plots": None, - "no-metrics_store_store_operation_metrics": None, - "no-metrics_store_store_token_completion_metrics": None, - "no-write_chrome_trace": None, - "sklearn_execution_time_predictor_skip_cpu_overhead_modeling": None, + "metrics_config_output_dir": self.get_run_dir(), + "metrics_config_cache_dir": self.cache_dir, + "poisson_request_interval_generator_config_qps": self.qps, + "gamma_request_interval_generator_config_qps": self.qps, + "time_limit": self.time_limit * 60, # to seconds + "no-metrics_config_save_table_to_wandb": None, + "no-metrics_config_store_plots": None, + "no-metrics_config_store_operation_metrics": None, + "no-metrics_config_store_token_completion_metrics": None, + "no-metrics_config_enable_chrome_trace": None, + "linear_regression_execution_time_predictor_config_skip_cpu_overhead_modeling": None, + "random_forrest_execution_time_predictor_config_skip_cpu_overhead_modeling": None, } def to_args(self): diff --git a/vidur/config_optimizer/config_explorer/main.py b/vidur/config_optimizer/config_explorer/main.py index 26ce5af9..7821c9a5 100644 --- a/vidur/config_optimizer/config_explorer/main.py +++ b/vidur/config_optimizer/config_explorer/main.py @@ -64,9 +64,9 @@ def get_args(): os.makedirs(args.output_dir, exist_ok=True) - logger.info("Starting config optimizer", flush=True) - logger.info(f"Args: {args}", flush=True) - logger.info(f"Config: {config}", flush=True) + logger.info("Starting config optimizer") + logger.info(f"Args: {args}") + logger.info(f"Config: {config}") # store the config and args json.dump(vars(args), open(f"{args.output_dir}/args.json", "w")) @@ -80,4 +80,4 @@ def get_args(): end_time = time.time() - logger.info(f"Simulation took time: {end_time - start_time}", flush=True) + logger.info(f"Simulation took time: {end_time - start_time}") diff --git a/vidur/config_optimizer/config_explorer/ray_utils.py b/vidur/config_optimizer/config_explorer/ray_utils.py index 1b4bbd4c..35a5eb73 100644 --- a/vidur/config_optimizer/config_explorer/ray_utils.py +++ b/vidur/config_optimizer/config_explorer/ray_utils.py @@ -1,4 +1,5 @@ import os +import platform import socket import time from typing import Optional @@ -7,7 +8,20 @@ def get_ip() -> str: - return socket.gethostbyname(socket.gethostname()) + # special handling for macos + if platform.system() == "Darwin": + return "127.0.0.1" + + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.settimeout(0) + try: + s.connect(("10.254.254.254", 1)) + ip = s.getsockname()[0] + except Exception: + ip = "127.0.0.1" + finally: + s.close() + return ip def get_nodes() -> list[str]: @@ -17,6 +31,11 @@ def get_nodes() -> list[str]: for x in cluster_resources_keys if x.startswith("node:") and x != "node:__internal_head__" ] + + # special handling for macos, ensure that we only have one node + if platform.system() == "Darwin": + assert len(ip_addresses) == 1 + return ip_addresses diff --git a/vidur/constants.py b/vidur/constants.py deleted file mode 100644 index e1dfca74..00000000 --- a/vidur/constants.py +++ /dev/null @@ -1,8 +0,0 @@ -import os - -PY_ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) - -DEFAULT_CONFIG_FILE = f"{PY_ROOT_DIR}/config/default.yml" -MODEL_CONFIG_DIR = f"{PY_ROOT_DIR}/../data/model_configs" -DEVICE_CONFIG_DIR = f"{PY_ROOT_DIR}/../data/device_configs" -CACHE_DIR = f"{PY_ROOT_DIR}/../.simulator_cache" diff --git a/vidur/entities/cluster.py b/vidur/entities/cluster.py index 34a8f48e..013b86d8 100644 --- a/vidur/entities/cluster.py +++ b/vidur/entities/cluster.py @@ -1,6 +1,6 @@ import json -from vidur.config import Config +from vidur.config import BaseRequestGeneratorConfig, ClusterConfig, MetricsConfig from vidur.entities.base_entity import BaseEntity from vidur.entities.replica import Replica from vidur.logger import init_logger @@ -9,18 +9,26 @@ class Cluster(BaseEntity): - def __init__(self, config: Config): + def __init__( + self, + cluster_config: ClusterConfig, + metrics_config: MetricsConfig, + generator_config: BaseRequestGeneratorConfig, + ) -> None: self._id = Cluster.generate_id() - self._config = config + self._config = cluster_config + + # get metrics config + self._output_dir = metrics_config.output_dir # Init replica object handles self._replicas = {} - for _ in range(self._config.cluster_num_replicas): - replica = Replica(config) + for _ in range(self._config.num_replicas): + replica = Replica(self._config.replica_config, generator_config) self._replicas[replica.id] = replica - if self._config.write_json_trace: + if metrics_config.write_json_trace: self._write_cluster_info_to_file() @property @@ -37,6 +45,6 @@ def _write_cluster_info_to_file(self) -> None: replica_dicts = [replica.to_dict() for replica in self._replicas.values()] cluster_info = {"replicas": replica_dicts} - cluster_file = f"{self._config.output_dir}/cluster.json" + cluster_file = f"{self._output_dir}/cluster.json" with open(cluster_file, "w") as f: json.dump(cluster_info, f) diff --git a/vidur/entities/replica.py b/vidur/entities/replica.py index d0ec553b..bda2293e 100644 --- a/vidur/entities/replica.py +++ b/vidur/entities/replica.py @@ -1,6 +1,6 @@ from math import ceil -from vidur.config import Config +from vidur.config import BaseRequestGeneratorConfig, ReplicaConfig from vidur.entities.base_entity import BaseEntity from vidur.logger import init_logger @@ -8,107 +8,113 @@ class Replica(BaseEntity): - def __init__(self, config: Config) -> None: - assert config.replica_num_layers % config.replica_num_pipeline_stages == 0 + def __init__( + self, + replica_config: ReplicaConfig, + generator_config: BaseRequestGeneratorConfig, + ) -> None: + self._id = Replica.generate_id() + + self._replica_config = replica_config + self._model_config = replica_config.model_config + self._device_config = replica_config.device_config + self._generator_config = generator_config + assert ( - config.replica_embedding_dim % config.replica_num_tensor_parallel_workers + self._model_config.num_layers % self._replica_config.num_pipeline_stages + == 0 + ) + assert ( + self._model_config.embedding_dim % self._replica_config.tensor_parallel_size == 0 ) - self._id = Replica.generate_id() - - self._num_pipeline_stages = config.replica_num_pipeline_stages - self._num_tensor_parallel_workers = config.replica_num_tensor_parallel_workers - self._num_layers = config.replica_num_layers - self._num_q_heads = config.replica_num_q_heads - self._num_kv_heads = config.replica_num_kv_heads - self._embedding_dim = config.replica_embedding_dim - self._mlp_hidden_dim = config.replica_mlp_hidden_dim - self._use_gated_mlp = config.replica_use_gated_mlp - self._vocab_size = config.replica_vocab_size - self._total_memory_gb = config.replica_total_memory_gb - self._memory_margin_fraction = config.replica_memory_margin_fraction - self._max_request_tokens = config.request_generator_max_tokens - self._per_device_flops = config.replica_fp16_tflops * 2**40 + @property + def id(self) -> int: + return self._id @property def num_layers(self) -> int: - return self._num_layers + return self._model_config.num_layers @property def num_q_heads(self) -> int: - return self._num_q_heads + return self._model_config.num_q_heads @property def num_kv_heads(self) -> int: - return self._num_kv_heads + return self._model_config.num_kv_heads @property def embedding_dim(self) -> int: - return self._embedding_dim + return self._model_config.embedding_dim @property def mlp_hidden_dim(self) -> int: - return self._mlp_hidden_dim + return self._model_config.mlp_hidden_dim @property def use_gated_mlp(self) -> int: - return self._use_gated_mlp + return self._model_config.use_gated_mlp @property def vocab_size(self) -> int: - return self._vocab_size + return self._model_config.vocab_size @property def num_pipeline_stages(self) -> int: - return self._num_pipeline_stages + return self._replica_config.num_pipeline_stages @property def num_layers_per_pipeline_stage(self) -> int: - return self._num_layers // self._num_pipeline_stages + return self._model_config.num_layers // self._replica_config.num_pipeline_stages @property def attention_head_dim(self) -> int: - return self._embedding_dim // self._num_q_heads + return self._model_config.embedding_dim // self._model_config.num_q_heads @property def q_heads_per_tensor_parallel_worker(self) -> int: - return self._num_q_heads // self._num_tensor_parallel_workers + return ( + self._model_config.num_q_heads // self._replica_config.tensor_parallel_size + ) @property def kv_heads_per_tensor_parallel_worker(self) -> int: - return ceil(self._num_kv_heads / self._num_tensor_parallel_workers) + return ceil( + self._model_config.num_kv_heads / self._replica_config.tensor_parallel_size + ) @property def num_tensor_parallel_workers(self) -> int: - return self._num_tensor_parallel_workers + return self._replica_config.tensor_parallel_size @property def total_memory_gb(self) -> int: - return self._total_memory_gb + return self._device_config.total_memory_gb @property def memory_margin_fraction(self) -> float: - return self._memory_margin_fraction + return self._replica_config.memory_margin_fraction @property def max_request_tokens(self) -> int: - return self._max_request_tokens + return self._generator_config.max_tokens @property def per_device_flops(self) -> float: - return self._per_device_flops + return self._device_config.fp16_tflops * 2**40 def to_dict(self) -> dict: return { - "id": self._id, - "num_layers": self._num_layers, - "num_q_heads": self._num_q_heads, - "num_kv_heads": self._num_kv_heads, - "embedding_dim": self._embedding_dim, - "mlp_hidden_dim": self._mlp_hidden_dim, - "use_gated_mlp": self._use_gated_mlp, - "vocab_size": self._vocab_size, - "num_pipeline_stages": self._num_pipeline_stages, - "num_tensor_parallel_workers": self._num_tensor_parallel_workers, + "id": self.id, + "num_layers": self.num_layers, + "num_q_heads": self.num_q_heads, + "num_kv_heads": self.num_kv_heads, + "embedding_dim": self.embedding_dim, + "mlp_hidden_dim": self.mlp_hidden_dim, + "use_gated_mlp": self.use_gated_mlp, + "vocab_size": self.vocab_size, + "num_pipeline_stages": self.num_pipeline_stages, + "num_tensor_parallel_workers": self.num_tensor_parallel_workers, } diff --git a/vidur/execution_time_predictor/base_execution_time_predictor.py b/vidur/execution_time_predictor/base_execution_time_predictor.py index 84a21e29..f399c8ea 100644 --- a/vidur/execution_time_predictor/base_execution_time_predictor.py +++ b/vidur/execution_time_predictor/base_execution_time_predictor.py @@ -1,28 +1,43 @@ from abc import ABC, abstractmethod -from vidur.config import Config +from vidur.config import ( + BaseExecutionTimePredictorConfig, + BaseReplicaSchedulerConfig, + MetricsConfig, + ReplicaConfig, +) from vidur.entities import Batch, ExecutionTime class BaseExecutionTimePredictor(ABC): - def __init__(self, config: Config) -> None: - self._num_tensor_parallel_workers = config.replica_num_tensor_parallel_workers - self._num_pipeline_stages = config.replica_num_pipeline_stages - self._num_layers = config.replica_num_layers + def __init__( + self, + predictor_config: BaseExecutionTimePredictorConfig, + replica_config: ReplicaConfig, + replica_scheduler_config: BaseReplicaSchedulerConfig, + metrics_config: MetricsConfig, + ) -> None: + self._config = predictor_config + self._replica_config = replica_config + self._model_config = replica_config.model_config + + # get configs + self._replica_scheduler_provider = str(replica_scheduler_config.get_type()) + self._block_size = replica_scheduler_config.block_size + self._cache_dir = metrics_config.cache_dir self._num_layers_per_pipeline_stage = ( - config.replica_num_layers // config.replica_num_pipeline_stages + self._model_config.num_layers // self._replica_config.num_pipeline_stages ) - self._replica_scheduler_provider = config.replica_scheduler_provider def get_execution_time(self, batch: Batch, pipeline_stage: int) -> ExecutionTime: - if pipeline_stage == self._num_pipeline_stages - 1: + if pipeline_stage == self._replica_config.num_pipeline_stages - 1: pipeline_parallel_communication_time = 0 else: pipeline_parallel_communication_time = ( self._get_pipeline_parallel_communication_time(batch) ) - if self._num_tensor_parallel_workers == 1: + if self._replica_config.tensor_parallel_size == 1: tensor_parallel_communication_time = 0 else: tensor_parallel_communication_time = ( diff --git a/vidur/execution_time_predictor/dummy_execution_time_predictor.py b/vidur/execution_time_predictor/dummy_execution_time_predictor.py deleted file mode 100644 index 66d1590c..00000000 --- a/vidur/execution_time_predictor/dummy_execution_time_predictor.py +++ /dev/null @@ -1,33 +0,0 @@ -import random -from typing import List - -from vidur.entities import Request -from vidur.execution_time_predictor.base_execution_time_predictor import ( - BaseExecutionTimePredictor, -) - - -class DummyExecutionTimePredictor(BaseExecutionTimePredictor): - def _get_attention_layer_pre_proj_execution_time( - self, batch: List[Request] - ) -> float: - return random.uniform(0.1, 0.2) - - def _get_attention_layer_post_proj_execution_time( - self, batch: List[Request] - ) -> float: - return random.uniform(0.1, 0.2) - - def _get_attention_layer_flash_attention_execution_time( - self, batch: List[Request] - ) -> float: - return random.uniform(0.1, 0.2) - - def _get_mlp_layer_mlp_execution_time(self, batch: List[Request]) -> float: - return random.uniform(0.1, 0.2) - - def _get_tensor_parallel_communication_time(self, batch: List[Request]) -> float: - return random.uniform(0.1, 0.2) - - def _get_pipeline_parallel_communication_time(self, batch: List[Request]) -> float: - return random.uniform(0.1, 0.2) diff --git a/vidur/execution_time_predictor/execution_time_predictor_registry.py b/vidur/execution_time_predictor/execution_time_predictor_registry.py index ed03828b..a48c1105 100644 --- a/vidur/execution_time_predictor/execution_time_predictor_registry.py +++ b/vidur/execution_time_predictor/execution_time_predictor_registry.py @@ -1,6 +1,3 @@ -from vidur.execution_time_predictor.dummy_execution_time_predictor import ( - DummyExecutionTimePredictor, -) from vidur.execution_time_predictor.linear_regression_execution_time_predictor import ( LinearRegressionExecutionTimePredictor, ) @@ -17,9 +14,6 @@ def get_key_from_str(cls, key_str: str) -> ExecutionTimePredictorType: return ExecutionTimePredictorType.from_str(key_str) -ExecutionTimePredictorRegistry.register( - ExecutionTimePredictorType.DUMMY, DummyExecutionTimePredictor -) ExecutionTimePredictorRegistry.register( ExecutionTimePredictorType.RANDOM_FORREST, RandomForrestExecutionTimePredictor ) diff --git a/vidur/execution_time_predictor/linear_regression_execution_time_predictor.py b/vidur/execution_time_predictor/linear_regression_execution_time_predictor.py index 6506e1ed..8dd32b76 100644 --- a/vidur/execution_time_predictor/linear_regression_execution_time_predictor.py +++ b/vidur/execution_time_predictor/linear_regression_execution_time_predictor.py @@ -2,35 +2,39 @@ from sklearn.pipeline import make_pipeline from sklearn.preprocessing import PolynomialFeatures +from vidur.config import ( + BaseReplicaSchedulerConfig, + LinearRegressionExecutionTimePredictorConfig, + MetricsConfig, + ReplicaConfig, +) from vidur.execution_time_predictor.sklearn_execution_time_predictor import ( SklearnExecutionTimePredictor, ) class LinearRegressionExecutionTimePredictor(SklearnExecutionTimePredictor): - def __init__(self, config): - self._polynomial_degree = ( - config.linear_regression_execution_time_predictor_polynomial_degree - ) - self._polynomial_include_bias = ( - config.linear_regression_execution_time_predictor_polynomial_include_bias - ) - self._polynomial_interaction_only = ( - config.linear_regression_execution_time_predictor_polynomial_interaction_only - ) - self._fit_intercept = ( - config.linear_regression_execution_time_predictor_fit_intercept - ) - + def __init__( + self, + predictor_config: LinearRegressionExecutionTimePredictorConfig, + replica_config: ReplicaConfig, + replica_scheduler_config: BaseReplicaSchedulerConfig, + metrics_config: MetricsConfig, + ) -> None: # will trigger model training - super().__init__(config) + super().__init__( + predictor_config=predictor_config, + replica_config=replica_config, + replica_scheduler_config=replica_scheduler_config, + metrics_config=metrics_config, + ) def _get_grid_search_params(self): return { - "polynomialfeatures__degree": self._polynomial_degree, - "polynomialfeatures__include_bias": self._polynomial_include_bias, - "polynomialfeatures__interaction_only": self._polynomial_interaction_only, - "linearregression__fit_intercept": self._fit_intercept, + "polynomialfeatures__degree": self._config.polynomial_degree, + "polynomialfeatures__include_bias": self._config.polynomial_include_bias, + "polynomialfeatures__interaction_only": self._config.polynomial_interaction_only, + "linearregression__fit_intercept": self._config.fit_intercept, } def _get_estimator(self): diff --git a/vidur/execution_time_predictor/random_forrest_execution_time_predictor.py b/vidur/execution_time_predictor/random_forrest_execution_time_predictor.py index 0898a0c0..27fd7487 100644 --- a/vidur/execution_time_predictor/random_forrest_execution_time_predictor.py +++ b/vidur/execution_time_predictor/random_forrest_execution_time_predictor.py @@ -1,28 +1,37 @@ from sklearn.ensemble import RandomForestRegressor +from vidur.config import ( + BaseReplicaSchedulerConfig, + MetricsConfig, + RandomForrestExecutionTimePredictorConfig, + ReplicaConfig, +) from vidur.execution_time_predictor.sklearn_execution_time_predictor import ( SklearnExecutionTimePredictor, ) class RandomForrestExecutionTimePredictor(SklearnExecutionTimePredictor): - def __init__(self, config): - self._num_estimators = ( - config.random_forrest_execution_time_predictor_num_estimators - ) - self._max_depth = config.random_forrest_execution_time_predictor_max_depth - self._min_samples_split = ( - config.random_forrest_execution_time_predictor_min_samples_split - ) - + def __init__( + self, + predictor_config: RandomForrestExecutionTimePredictorConfig, + replica_config: ReplicaConfig, + replica_scheduler_config: BaseReplicaSchedulerConfig, + metrics_config: MetricsConfig, + ) -> None: # will trigger model training - super().__init__(config) + super().__init__( + predictor_config=predictor_config, + replica_config=replica_config, + replica_scheduler_config=replica_scheduler_config, + metrics_config=metrics_config, + ) def _get_grid_search_params(self): return { - "n_estimators": self._num_estimators, - "max_depth": self._max_depth, - "min_samples_split": self._min_samples_split, + "n_estimators": self._config.num_estimators, + "max_depth": self._config.max_depth, + "min_samples_split": self._config.min_samples_split, } def _get_estimator(self): diff --git a/vidur/execution_time_predictor/sklearn_execution_time_predictor.py b/vidur/execution_time_predictor/sklearn_execution_time_predictor.py index 852720b5..a5a96466 100644 --- a/vidur/execution_time_predictor/sklearn_execution_time_predictor.py +++ b/vidur/execution_time_predictor/sklearn_execution_time_predictor.py @@ -12,7 +12,12 @@ from sklearn.metrics import make_scorer from sklearn.model_selection import GridSearchCV -from vidur.config import Config +from vidur.config import ( + BaseExecutionTimePredictorConfig, + BaseReplicaSchedulerConfig, + MetricsConfig, + ReplicaConfig, +) from vidur.entities import Batch from vidur.execution_time_predictor.base_execution_time_predictor import ( BaseExecutionTimePredictor, @@ -23,131 +28,105 @@ class SklearnExecutionTimePredictor(BaseExecutionTimePredictor): - def __init__(self, config: Config) -> None: - super().__init__(config) - - self._cache_dir = f"{config.cache_dir}/execution_time_predictor" - os.makedirs(self._cache_dir, exist_ok=True) - - self._no_cache = config.sklearn_execution_time_predictor_no_cache - - self._k_fold_cv_splits = ( - config.sklearn_execution_time_predictor_k_fold_cv_splits + def __init__( + self, + predictor_config: BaseExecutionTimePredictorConfig, + replica_config: ReplicaConfig, + replica_scheduler_config: BaseReplicaSchedulerConfig, + metrics_config: MetricsConfig, + ) -> None: + super().__init__( + predictor_config=predictor_config, + replica_config=replica_config, + replica_scheduler_config=replica_scheduler_config, + metrics_config=metrics_config, ) - self._model_name = config.replica_model_name - self._num_q_heads = config.replica_num_q_heads - self._num_kv_heads = config.replica_num_kv_heads - self._embedding_dim = config.replica_embedding_dim - self._mlp_hidden_dim = config.replica_mlp_hidden_dim - self._use_gated_mlp = config.replica_use_gated_mlp - self._vocab_size = config.replica_vocab_size - self._block_size = config.replica_block_size - self._norm = config.replica_norm - self._post_attn_norm = config.replica_post_attn_norm - - self._model_provider = config.execution_time_predictor_provider + os.makedirs(self._cache_dir, exist_ok=True) # These overheads are only for GQA models self._attention_prefill_batching_overhead_fraction = ( - ( - config.sklearn_execution_time_predictor_attention_prefill_batching_overhead_fraction - ) - if self._num_q_heads > self._num_kv_heads + (self._config.attention_prefill_batching_overhead_fraction) + if self._model_config.num_q_heads > self._model_config.num_kv_heads else 0 ) self._attention_decode_batching_overhead_fraction = ( - ( - config.sklearn_execution_time_predictor_attention_decode_batching_overhead_fraction - ) - if self._num_q_heads > self._num_kv_heads + (self._config.attention_decode_batching_overhead_fraction) + if self._model_config.num_q_heads > self._model_config.num_kv_heads else 0 ) - self._nccl_cpu_launch_overhead_ms = ( - config.sklearn_execution_time_predictor_nccl_cpu_launch_overhead_ms - ) - self._nccl_cpu_skew_overhead_per_device_ms = ( - config.sklearn_execution_time_predictor_nccl_cpu_skew_overhead_per_device_ms - ) - - self._max_batch_size = ( - config.sklearn_execution_time_predictor_prediction_max_batch_size - ) - self._max_tokens_per_request = ( - config.sklearn_execution_time_predictor_prediction_max_tokens_per_request - ) - - if config.replica_scheduler_provider == "orca": - self._max_tokens = self._max_tokens_per_request * self._max_batch_size + if self._replica_scheduler_provider == "orca": + self._max_tokens = ( + self._config.prediction_max_tokens_per_request + * self._config.prediction_max_batch_size + ) else: - self._max_tokens = self._max_tokens_per_request - - self._prefill_chunk_size = config.replica_prefill_chunk_size + self._max_tokens = self._config.prediction_max_tokens_per_request - self._compute_input_file = ( - config.sklearn_execution_time_predictor_compute_input_file + num_workers = ( + self._replica_config.num_pipeline_stages + * self._replica_config.tensor_parallel_size ) - self._attention_input_file = ( - config.sklearn_execution_time_predictor_attention_input_file - ) - self._all_reduce_input_file = ( - config.sklearn_execution_time_predictor_all_reduce_input_file - ) - self._send_recv_input_file = ( - config.sklearn_execution_time_predictor_send_recv_input_file - ) - self._cpu_overhead_input_file = ( - config.sklearn_execution_time_predictor_cpu_overhead_input_file - ) - self._kv_cache_prediction_granularity = ( - config.sklearn_execution_time_predictor_kv_cache_prediction_granularity - ) - self._prediction_max_prefill_chunk_size = ( - config.sklearn_execution_time_predictor_prediction_max_prefill_chunk_size - ) - - self._device_memory = config.replica_total_memory_gb - self._num_training_job_threads = ( - config.sklearn_execution_time_predictor_num_training_job_threads - ) - - devices_per_node = config.replica_num_devices_per_node - num_workers = self._num_pipeline_stages * self._num_tensor_parallel_workers + devices_per_node = self._replica_config.node_config.num_devices_per_node assert ( num_workers < devices_per_node or num_workers % devices_per_node == 0 ), "Number of workers should be less than devices per node or a multiple of devices per node" self._is_multi_node = num_workers > devices_per_node - self._max_batch_tokens = config.vllm_scheduler_max_tokens_in_batch - self._skip_cpu_overhead_modeling = ( - config.sklearn_execution_time_predictor_skip_cpu_overhead_modeling - ) + ( + self._compute_input_file, + self._attention_input_file, + self._all_reduce_input_file, + self._send_recv_input_file, + self._cpu_overhead_input_file, + ) = self._get_input_files() self._models = self._train_models() self._predictions = self._predict_from_models() + def _get_input_files(self) -> Tuple[str, str, str, str, str]: + input_files = [ + self._config.compute_input_file, + self._config.attention_input_file, + self._config.all_reduce_input_file, + self._config.send_recv_input_file, + self._config.cpu_overhead_input_file, + ] + for i in range(len(input_files)): + input_files[i] = ( + input_files[i] + .replace("{DEVICE}", self._replica_config.device) + .replace("{MODEL}", self._model_config.get_name()) + .replace("{NETWORK_DEVICE}", self._replica_config.network_device) + ) + + return tuple(input_files) + def _load_compute_df(self, file_path: str) -> pd.DataFrame: df = self._read_input_file(file_path) df = df.drop_duplicates() logger.debug(f"Length of complete compute df: {len(df)} {file_path}") - logger.debug(f"self._num_q_heads: {self._num_q_heads}") - logger.debug(f"self._embedding_dim: {self._embedding_dim}") - logger.debug(f"self._mlp_hidden_dim: {self._mlp_hidden_dim}") - logger.debug(f"self._use_gated_mlp: {self._use_gated_mlp}") - logger.debug(f"self._vocab_size: {self._vocab_size}") + logger.debug(f"self._num_q_heads: {self._model_config.num_q_heads}") + logger.debug(f"self._embedding_dim: {self._model_config.embedding_dim}") + logger.debug(f"self._mlp_hidden_dim: {self._model_config.mlp_hidden_dim}") + logger.debug(f"self._use_gated_mlp: {self._model_config.use_gated_mlp}") + logger.debug(f"self._vocab_size: {self._model_config.vocab_size}") logger.debug( - f"self._num_tensor_parallel_workers: {self._num_tensor_parallel_workers}" + f"self._num_tensor_parallel_workers: {self._replica_config.tensor_parallel_size}" ) df = df[ - (df["n_head"] == self._num_q_heads) - & (df["n_kv_head"] == self._num_kv_heads) - & (df["n_embd"] == self._embedding_dim) - & (df["n_expanded_embd"] == self._mlp_hidden_dim) - & (df["use_gated_mlp"] == self._use_gated_mlp) - & (df["vocab_size"] == self._vocab_size) - & (df["num_tensor_parallel_workers"] == self._num_tensor_parallel_workers) + (df["n_head"] == self._model_config.num_q_heads) + & (df["n_kv_head"] == self._model_config.num_kv_heads) + & (df["n_embd"] == self._model_config.embedding_dim) + & (df["n_expanded_embd"] == self._model_config.mlp_hidden_dim) + & (df["use_gated_mlp"] == self._model_config.use_gated_mlp) + & (df["vocab_size"] == self._model_config.vocab_size) + & ( + df["num_tensor_parallel_workers"] + == self._replica_config.tensor_parallel_size + ) ] for column in [ @@ -174,18 +153,21 @@ def _load_attention_df(self, file_path: str) -> pd.DataFrame: df.fillna({column: 0}, inplace=True) return df[ - (df["n_embd"] == self._embedding_dim) - & (df["n_q_head"] == self._num_q_heads) - & (df["n_kv_head"] == self._num_kv_heads) + (df["n_embd"] == self._model_config.embedding_dim) + & (df["n_q_head"] == self._model_config.num_q_heads) + & (df["n_kv_head"] == self._model_config.num_kv_heads) & (df["block_size"] == self._block_size) - & (df["num_tensor_parallel_workers"] == self._num_tensor_parallel_workers) + & ( + df["num_tensor_parallel_workers"] + == self._replica_config.tensor_parallel_size + ) ] def _load_all_reduce_df(self, file_path: str) -> pd.DataFrame: df = self._read_input_file(file_path) return df[ - (df["num_workers"] == self._num_tensor_parallel_workers) - & (df["devices_per_node"] == self._num_tensor_parallel_workers) + (df["num_workers"] == self._replica_config.tensor_parallel_size) + & (df["devices_per_node"] == self._replica_config.tensor_parallel_size) & (df["collective"] == "all_reduce") ] @@ -205,8 +187,11 @@ def _load_send_recv_df(self, file_path: str) -> pd.DataFrame: def _load_cpu_overhead_df(self, file_path: str) -> pd.DataFrame: df = self._read_input_file(file_path) filtered_df = df[ - (df["model_name"] == self._model_name) - & (df["tensor_parallel_degree"] == self._num_tensor_parallel_workers) + (df["model_name"] == self._model_config.get_name()) + & ( + df["tensor_parallel_degree"] + == self._replica_config.tensor_parallel_size + ) ] return filtered_df @@ -239,14 +224,14 @@ def _get_all_reduce_df_with_derived_features( # convert bytes to num tokens # each token is of size 2 * h bytes df_with_derived_features["num_tokens"] = ( - df_with_derived_features["size"] / self._embedding_dim / 2 + df_with_derived_features["size"] / self._model_config.embedding_dim / 2 ) return df_with_derived_features def _get_send_recv_df_with_derived_features(self, df: pd.DataFrame) -> pd.DataFrame: df_with_derived_features = df.copy() df_with_derived_features["num_tokens"] = ( - df_with_derived_features["size"] / self._embedding_dim / 2 + df_with_derived_features["size"] / self._model_config.embedding_dim / 2 ) return df_with_derived_features @@ -308,7 +293,7 @@ def _load_model_from_cache(self, model_name: str, model_hash: str) -> BaseEstima with InterProcessReaderWriterLock( f"{self._cache_dir}/{model_hash}_model_lock.file" ).read_lock(): - if self._no_cache: + if self._config.no_cache: return # check if model is in cache cache_file = f"{self._cache_dir}/{model_name}_{model_hash}.pkl" @@ -368,17 +353,17 @@ def _train_model( model = self._get_estimator() grid_search_params = self._get_grid_search_params() - if len(df) < self._k_fold_cv_splits: + if len(df) < self._config.k_fold_cv_splits: cv = 2 else: - cv = self._k_fold_cv_splits + cv = self._config.k_fold_cv_splits grid_search = GridSearchCV( estimator=model, param_grid=grid_search_params, scoring=self._get_scorer(), cv=cv, - n_jobs=self._num_training_job_threads, + n_jobs=self._config.num_training_job_threads, ) # we don't create a train/test split, because we want to use all data for training @@ -422,7 +407,7 @@ def _load_model_predication_cache( with InterProcessReaderWriterLock( f"{self._cache_dir}/{model_hash}_prediction_lock.file" ).read_lock(): - if self._no_cache: + if self._config.no_cache: return cache_file = f"{self._cache_dir}/{model_name}_{model_hash}_predictions.pkl" @@ -506,7 +491,7 @@ def _train_compute_models(self) -> Dict[str, BaseEstimator]: target_col=f"time_stats.{model_name}.median", ) - if self._num_pipeline_stages > 1: + if self._replica_config.num_pipeline_stages > 1: send_recv_df = self._load_send_recv_df(self._send_recv_input_file) send_recv_df = self._get_send_recv_df_with_derived_features(send_recv_df) @@ -517,7 +502,7 @@ def _train_compute_models(self) -> Dict[str, BaseEstimator]: target_col="time_stats.send_recv.median", ) - if self._num_tensor_parallel_workers > 1: + if self._replica_config.tensor_parallel_size > 1: all_reduce_df = self._load_all_reduce_df(self._all_reduce_input_file) all_reduce_df = self._get_all_reduce_df_with_derived_features(all_reduce_df) @@ -531,7 +516,7 @@ def _train_compute_models(self) -> Dict[str, BaseEstimator]: return models def _train_cpu_overhead_models(self) -> Dict[str, BaseEstimator]: - if self._skip_cpu_overhead_modeling: + if self._config.skip_cpu_overhead_modeling: return {} models = {} @@ -616,10 +601,10 @@ def _predict_for_compute_models(self) -> Dict[str, Any]: "add", ] - if self._num_pipeline_stages > 1: + if self._replica_config.num_pipeline_stages > 1: model_names.append("send_recv") - if self._num_tensor_parallel_workers > 1: + if self._replica_config.tensor_parallel_size > 1: model_names.append("all_reduce") num_token_range = np.arange(1, self._max_tokens + 1) @@ -632,7 +617,7 @@ def _predict_for_compute_models(self) -> Dict[str, Any]: return predictions def _predict_for_cpu_overhead_models(self) -> Dict[str, Any]: - if self._skip_cpu_overhead_modeling: + if self._config.skip_cpu_overhead_modeling: return {} predictions = {} @@ -645,7 +630,7 @@ def _predict_for_cpu_overhead_models(self) -> Dict[str, Any]: "ray_comm_time", ] - batch_size_range = np.arange(1, self._max_batch_size + 1) + batch_size_range = np.arange(1, self._config.prediction_max_batch_size + 1) X = pd.DataFrame({"batch_size": batch_size_range}) for model_name in model_names: @@ -657,9 +642,13 @@ def _predict_for_cpu_overhead_models(self) -> Dict[str, Any]: def _predict_for_attention_layer_models(self) -> Dict[str, Any]: predictions = {} - decode_batch_size_range = np.arange(1, self._max_batch_size + 1) + decode_batch_size_range = np.arange( + 1, self._config.prediction_max_batch_size + 1 + ) decode_kv_cache_size_range = np.arange( - 0, self._max_tokens_per_request + 1, self._kv_cache_prediction_granularity + 0, + self._config.prediction_max_tokens_per_request + 1, + self._config.kv_cache_prediction_granularity, ) decode_prefill_chunk_size_range = [0] decode_batch_size, decode_kv_cache_size, decode_prefill_chunk_size = zip( @@ -672,10 +661,12 @@ def _predict_for_attention_layer_models(self) -> Dict[str, Any]: prefill_batch_size_range = [1] prefill_kv_cache_size_range = np.arange( - 0, self._max_tokens_per_request + 1, self._kv_cache_prediction_granularity + 0, + self._config.prediction_max_tokens_per_request + 1, + self._config.kv_cache_prediction_granularity, ) prefill_prefill_chunk_size_range = np.arange( - 1, self._prediction_max_prefill_chunk_size + 1 + 1, self._config.prediction_max_prefill_chunk_size + 1 ) prefill_batch_size, prefill_kv_cache_size, prefill_prefill_chunk_size = zip( *product( @@ -748,9 +739,13 @@ def _get_batch_decode_attention_params(self, batch: Batch) -> Tuple[int, int]: decode_batch_size = len(decode_kv_cache_sizes) decode_avg_kv_cache_size = int(np.mean(decode_kv_cache_sizes)) decode_avg_kv_cache_size = ( - (decode_avg_kv_cache_size + self._kv_cache_prediction_granularity - 1) - // self._kv_cache_prediction_granularity - ) * self._kv_cache_prediction_granularity + ( + decode_avg_kv_cache_size + + self._config.kv_cache_prediction_granularity + - 1 + ) + // self._config.kv_cache_prediction_granularity + ) * self._config.kv_cache_prediction_granularity batch._decode_params = (decode_batch_size, decode_avg_kv_cache_size) @@ -772,11 +767,11 @@ def _get_batch_prefill_attention_params( kv_cache_size = ( ( request.num_processed_tokens - + self._kv_cache_prediction_granularity + + self._config.kv_cache_prediction_granularity - 1 ) - // self._kv_cache_prediction_granularity - ) * self._kv_cache_prediction_granularity + // self._config.kv_cache_prediction_granularity + ) * self._config.kv_cache_prediction_granularity prefill_params.append((kv_cache_size, prefill_chunk_size)) @@ -803,7 +798,7 @@ def _get_attn_norm_layer_act_execution_time(self, batch: Batch) -> float: return self._predictions["input_layernorm"][(batch._total_num_tokens_rounded,)] def _get_mlp_norm_layer_act_execution_time(self, batch: Batch) -> float: - if not self._post_attn_norm: + if not self._model_config.post_attn_norm: return 0 return self._predictions["post_attention_layernorm"][ @@ -816,9 +811,9 @@ def _get_add_layer_act_execution_time(self, batch: Batch) -> float: def _get_tensor_parallel_communication_time(self, batch: Batch) -> float: return ( self._predictions["all_reduce"][(batch._total_num_tokens_rounded,)] - + self._nccl_cpu_launch_overhead_ms - + self._nccl_cpu_skew_overhead_per_device_ms - * self._num_tensor_parallel_workers**1.25 + + self._config.nccl_cpu_launch_overhead_ms + + self._config.nccl_cpu_skew_overhead_per_device_ms + * self._replica_config.tensor_parallel_size**1.25 ) def _get_pipeline_parallel_communication_time(self, batch: Batch) -> float: @@ -874,52 +869,52 @@ def _get_attention_prefill_execution_time(self, batch: Batch) -> float: ) def _get_schedule_time(self, batch: Batch) -> float: - if self._skip_cpu_overhead_modeling: + if self._config.skip_cpu_overhead_modeling: return 0 return self._predictions["schedule"][(batch.size,)] def _get_sampler_e2e_time(self, batch: Batch) -> float: - if self._skip_cpu_overhead_modeling: + if self._config.skip_cpu_overhead_modeling: return 0 return self._predictions["sampler_e2e"][(batch.size,)] def _get_prepare_inputs_e2e_time(self, batch: Batch) -> float: - if self._skip_cpu_overhead_modeling: + if self._config.skip_cpu_overhead_modeling: return 0 return self._predictions["prepare_inputs_e2e"][(batch.size,)] def _get_process_model_outputs_time(self, batch: Batch) -> float: - if self._skip_cpu_overhead_modeling: + if self._config.skip_cpu_overhead_modeling: return 0 return self._predictions["process_model_outputs"][(batch.size,)] def _get_ray_comm_time(self, batch: Batch) -> float: - if self._skip_cpu_overhead_modeling: + if self._config.skip_cpu_overhead_modeling: return 0 return self._predictions["ray_comm_time"][(batch.size,)] def to_dict(self) -> dict: return { - "model_provider": self._model_provider, - "num_tensor_parallel_workers": self._num_tensor_parallel_workers, - "k_fold_cv_splits": self._k_fold_cv_splits, - "num_q_heads": self._num_q_heads, - "num_kv_heads": self._num_kv_heads, - "embedding_dim": self._embedding_dim, - "mlp_hidden_dim": self._mlp_hidden_dim, - "use_gated_mlp": self._use_gated_mlp, - "vocab_size": self._vocab_size, + "model_provider": str(self._config.get_type()), + "num_tensor_parallel_workers": self._replica_config.tensor_parallel_size, + "k_fold_cv_splits": self._config.k_fold_cv_splits, + "num_q_heads": self._model_config.num_q_heads, + "num_kv_heads": self._model_config.num_kv_heads, + "embedding_dim": self._model_config.embedding_dim, + "mlp_hidden_dim": self._model_config.mlp_hidden_dim, + "use_gated_mlp": self._model_config.use_gated_mlp, + "vocab_size": self._model_config.vocab_size, "block_size": self._block_size, "max_tokens": self._max_tokens, "compute_input_file": self._compute_input_file, "all_reduce_input_file": self._all_reduce_input_file, "send_recv_input_file": self._send_recv_input_file, "cpu_overhead_input_file": self._cpu_overhead_input_file, - "prediction_max_prefill_chunk_size": self._prediction_max_prefill_chunk_size, - "max_batch_size": self._max_batch_size, + "prediction_max_prefill_chunk_size": self._config.prediction_max_prefill_chunk_size, + "max_batch_size": self._config.prediction_max_batch_size, } diff --git a/vidur/main.py b/vidur/main.py index e5def45f..18406fbb 100644 --- a/vidur/main.py +++ b/vidur/main.py @@ -1,10 +1,10 @@ -from vidur.config import Config +from vidur.config import SimulationConfig from vidur.simulator import Simulator from vidur.utils.random import set_seeds -def main(): - config = Config() +def main() -> None: + config: SimulationConfig = SimulationConfig.create_from_cli_args() set_seeds(config.seed) diff --git a/vidur/metrics/cdf_sketch.py b/vidur/metrics/cdf_sketch.py index 3083ce17..50aeebf2 100644 --- a/vidur/metrics/cdf_sketch.py +++ b/vidur/metrics/cdf_sketch.py @@ -1,9 +1,9 @@ import numpy as np import pandas as pd import plotly_express as px +import wandb from ddsketch.ddsketch import DDSketch -import wandb from vidur.logger import init_logger logger = init_logger(__name__) diff --git a/vidur/metrics/data_series.py b/vidur/metrics/data_series.py index 77704967..51be848d 100644 --- a/vidur/metrics/data_series.py +++ b/vidur/metrics/data_series.py @@ -4,8 +4,8 @@ import numpy as np import pandas as pd import plotly_express as px - import wandb + from vidur.logger import init_logger logger = init_logger(__name__) diff --git a/vidur/metrics/metrics_store.py b/vidur/metrics/metrics_store.py index b02bd648..c1e76b0e 100644 --- a/vidur/metrics/metrics_store.py +++ b/vidur/metrics/metrics_store.py @@ -4,9 +4,9 @@ import pandas as pd import plotly_express as px - import wandb -from vidur.config import Config + +from vidur.config import ClusterConfig, MetricsConfig from vidur.entities import Batch, BatchStage, ExecutionTime, Request from vidur.logger import init_logger from vidur.metrics.cdf_sketch import CDFSketch @@ -30,7 +30,7 @@ def if_write_metrics(func): def wrapper(self, *args, **kwargs): - if self._should_write_metrics: + if self._config.write_metrics: return func(self, *args, **kwargs) return wrapper @@ -48,37 +48,14 @@ def wrapper(self, *args, **kwargs): class MetricsStore: - def __init__(self, config: Config): - self._config = config - self._num_replicas = config.cluster_num_replicas - self._num_stages = config.replica_num_pipeline_stages - self._should_write_metrics = config.write_metrics - self._subsamples = config.metrics_store_subsamples - self._save_table_to_wandb = config.metrics_store_save_table_to_wandb - self._save_plots = config.metrics_store_store_plots - self._keep_individual_batch_metrics = ( - config.metrics_store_keep_individual_batch_metrics - ) - - self._wandb_project = config.metrics_store_wandb_project - self._wandb_group = config.metrics_store_wandb_group - self._wandb_run_name = config.metrics_store_wandb_run_name - - self._min_batch_idx = config.metrics_store_min_batch_idx - self._max_batch_idx = config.metrics_store_max_batch_idx + def __init__(self, config: MetricsConfig, cluster_config: ClusterConfig) -> None: + self._config = config self._last_request_arrived_at = None - self._should_store_token_completion_metrics = ( - config.metrics_store_store_token_completion_metrics - ) - self._should_store_utilization_metrics = ( - config.metrics_store_store_utilization_metrics - ) - self._should_store_batch_metrics = config.metrics_store_store_batch_metrics - self._should_store_operation_metrics = ( - config.metrics_store_store_operation_metrics - ) - self._should_store_request_metrics = config.metrics_store_store_request_metrics + + # copy config + self._num_replicas = cluster_config.num_replicas + self._num_pipeline_stages = cluster_config.replica_config.num_pipeline_stages # Initialise request metrics self._request_metrics_time_distributions: Dict[ @@ -88,9 +65,9 @@ def __init__(self, config: Config): self._request_metrics_time_distributions[metric_name] = DataSeries( REQUEST_ID_STR, metric_name.value, - self._subsamples, - self._save_table_to_wandb, - self._save_plots, + self._config.subsamples, + self._config.save_table_to_wandb, + self._config.store_plots, ) self._token_metrics_time_distribution: Dict[ @@ -99,8 +76,8 @@ def __init__(self, config: Config): for metric_name in TokenMetricsTimeDistribution: self._token_metrics_time_distribution[metric_name] = CDFSketch( metric_name.value, - self._save_table_to_wandb, - self._save_plots, + self._config.save_table_to_wandb, + self._config.store_plots, ) self._request_metrics_histogram: Dict[RequestMetricsHistogram, DataSeries] = {} @@ -108,9 +85,9 @@ def __init__(self, config: Config): self._request_metrics_histogram[metric_name] = DataSeries( REQUEST_ID_STR, metric_name.value, - self._subsamples, - self._save_table_to_wandb, - self._save_plots, + self._config.subsamples, + self._config.save_table_to_wandb, + self._config.store_plots, ) # Initialise batch metrics @@ -123,15 +100,15 @@ def __init__(self, config: Config): for metric_name in BatchMetricsCountDistribution: self._batch_metrics_count_distribution[metric_name] = CDFSketch( metric_name.value, - self._save_table_to_wandb, - self._save_plots, + self._config.save_table_to_wandb, + self._config.store_plots, ) self._batch_metrics_count_distribution_per_batch[metric_name] = DataSeries( BATCH_ID_STR, metric_name.value, - self._subsamples, - self._save_table_to_wandb, - self._save_plots, + self._config.subsamples, + self._config.save_table_to_wandb, + self._config.store_plots, ) self._batch_metrics_time_distribution: Dict[ @@ -143,15 +120,15 @@ def __init__(self, config: Config): for metric_name in BatchMetricsTimeDistribution: self._batch_metrics_time_distribution[metric_name] = CDFSketch( metric_name.value, - self._save_table_to_wandb, - self._save_plots, + self._config.save_table_to_wandb, + self._config.store_plots, ) self._batch_metrics_time_distribution_per_batch[metric_name] = DataSeries( BATCH_ID_STR, metric_name.value, - self._subsamples, - self._save_table_to_wandb, - self._save_plots, + self._config.subsamples, + self._config.save_table_to_wandb, + self._config.store_plots, ) # Initialise completion metrics @@ -162,9 +139,9 @@ def __init__(self, config: Config): self._request_completion_metrics_time_series[metric_name] = DataSeries( TIME_STR, metric_name.value, - self._subsamples, - self._save_table_to_wandb, - self._save_plots, + self._config.subsamples, + self._config.save_table_to_wandb, + self._config.store_plots, ) self._token_completion_metrics_time_series: Dict[ TokenCompletionMetricsTimeSeries, DataSeries @@ -173,9 +150,9 @@ def __init__(self, config: Config): self._token_completion_metrics_time_series[metric_name] = DataSeries( TIME_STR, metric_name.value, - self._subsamples, - self._save_table_to_wandb, - self._save_plots, + self._config.subsamples, + self._config.save_table_to_wandb, + self._config.store_plots, ) # Initialise operation metrics @@ -184,15 +161,15 @@ def __init__(self, config: Config): for metric_name in OperationMetrics: self._operation_metrics[metric_name] = CDFSketch( metric_name.value, - self._save_table_to_wandb, - self._save_plots, + self._config.save_table_to_wandb, + self._config.store_plots, ) self._operation_metrics_per_batch[metric_name] = DataSeries( BATCH_ID_STR, metric_name.value, - self._subsamples, - self._save_table_to_wandb, - self._save_plots, + self._config.subsamples, + self._config.save_table_to_wandb, + self._config.store_plots, ) self._cpu_operation_metrics: Dict[CpuOperationMetrics, CDFSketch] = {} @@ -202,15 +179,15 @@ def __init__(self, config: Config): for metric_name in CpuOperationMetrics: self._cpu_operation_metrics[metric_name] = CDFSketch( metric_name.value, - self._save_table_to_wandb, - self._save_plots, + self._config.save_table_to_wandb, + self._config.store_plots, ) self._cpu_operation_metrics_per_batch[metric_name] = DataSeries( BATCH_ID_STR, metric_name.value, - self._subsamples, - self._save_table_to_wandb, - self._save_plots, + self._config.subsamples, + self._config.save_table_to_wandb, + self._config.store_plots, ) # per replica metrics @@ -218,14 +195,14 @@ def __init__(self, config: Config): # per replica stage metrics self._replica_busy_time = [] self._replica_mfu = [] - self._mfu_calculator = MFUCalculator(config) + self._mfu_calculator = MFUCalculator(cluster_config.replica_config) for replica_idx in range(self._num_replicas): self._replica_memory_usage.append( SeriesAverageMeter( TIME_STR, MEMORY_USAGE_STR, - self._save_table_to_wandb, + self._config.save_table_to_wandb, ) ) self._replica_memory_usage[replica_idx].put(0, 0) @@ -233,12 +210,12 @@ def __init__(self, config: Config): self._replica_busy_time.append([]) self._replica_mfu.append([]) - for stage_idx in range(self._num_stages): + for stage_idx in range(self._num_pipeline_stages): self._replica_busy_time[replica_idx].append( SeriesAverageMeter( TIME_STR, BUSY_TIME_PERCENT, - save_table_to_wandb=self._save_table_to_wandb, + save_table_to_wandb=self._config.save_table_to_wandb, ) ) self._replica_busy_time[replica_idx][stage_idx].put(0, 0) @@ -247,7 +224,7 @@ def __init__(self, config: Config): SeriesAverageMeter( TIME_STR, UTILIZATION_STR, - save_table_to_wandb=self._save_table_to_wandb, + save_table_to_wandb=self._config.save_table_to_wandb, ) ) self._replica_mfu[replica_idx][stage_idx].put(0, 0) @@ -256,16 +233,16 @@ def __init__(self, config: Config): def _init_wandb(self): if ( - not self._should_write_metrics - or not self._wandb_project - or not self._wandb_group + not self._config.write_metrics + or not self._config.wandb_project + or not self._config.wandb_group ): return wandb.init( - project=self._wandb_project, - group=self._wandb_group, - name=self._wandb_run_name, + project=self._config.wandb_project, + group=self._config.wandb_group, + name=self._config.wandb_run_name, config=self._config.to_dict(), ) @@ -283,7 +260,7 @@ def _save_as_csv( [dataseries._to_df() for dataseries in dataseries_list], ) merged_df.to_csv(f"{base_path}/{file_name}.csv", index=False) - if wandb.run and self._save_table_to_wandb: + if wandb.run and self._config.save_table_to_wandb: wand_table = wandb.Table(dataframe=merged_df) wandb.log({f"{file_name}_table": wand_table}, step=0) @@ -311,7 +288,7 @@ def _store_bar_plot( }, step=0, ) - if self._save_plots: + if self._config.store_plots: fig = px.bar( x=list(data.keys()), y=list(data.values()), @@ -320,7 +297,7 @@ def _store_bar_plot( fig.write_image(f"{base_path}/{plot_name}.png") def _store_operation_metrics(self, base_plot_path: str): - if not self._should_store_operation_metrics: + if not self._config.store_operation_metrics: return total_operation_runtimes: Dict[str, float] = {} @@ -347,7 +324,7 @@ def _store_operation_metrics(self, base_plot_path: str): total_operation_runtimes, ) - if not self._keep_individual_batch_metrics: + if not self._config.keep_individual_batch_metrics: return for dataseries in self._operation_metrics_per_batch.values(): @@ -385,7 +362,7 @@ def _store_operation_metrics(self, base_plot_path: str): ) def _store_request_metrics(self, base_plot_path: str): - if not self._should_store_request_metrics: + if not self._config.store_request_metrics: return all_request_metrics = list( @@ -406,7 +383,7 @@ def _store_request_metrics(self, base_plot_path: str): dataseries.plot_cdf(base_plot_path, dataseries._y_name, TIME_STR) def _store_batch_metrics(self, base_plot_path: str): - if not self._should_store_batch_metrics: + if not self._config.store_batch_metrics: return for dataseries in self._batch_metrics_time_distribution.values(): @@ -420,7 +397,7 @@ def _store_batch_metrics(self, base_plot_path: str): for dataseries in self._batch_metrics_count_distribution.values(): dataseries.plot_cdf(base_plot_path, dataseries._metric_name, COUNT_STR) - if not self._keep_individual_batch_metrics: + if not self._config.keep_individual_batch_metrics: return for dataseries in self._batch_metrics_time_distribution_per_batch.values(): @@ -456,13 +433,13 @@ def _store_batch_metrics(self, base_plot_path: str): ) def _store_completion_metrics(self, base_plot_path: str): - if self._should_store_request_metrics: + if self._config.store_request_metrics: for dataseries in self._request_completion_metrics_time_series.values(): dataseries.plot_step( base_plot_path, f"{dataseries._y_name}_time_series", COUNT_STR ) - if not self._should_store_token_completion_metrics: + if not self._config.store_token_completion_metrics: return for dataseries in self._token_metrics_time_distribution.values(): @@ -474,14 +451,14 @@ def _store_completion_metrics(self, base_plot_path: str): ) def _store_utilization_metrics(self, base_plot_path: str): - if not self._should_store_utilization_metrics: + if not self._config.store_utilization_metrics: return for replica_idx in range(self._num_replicas): self._replica_memory_usage[replica_idx].print_stats( f"replica_{replica_idx + 1}_memory_usage", base_plot_path ) - for stage_idx in range(self._num_stages): + for stage_idx in range(self._num_pipeline_stages): self._replica_busy_time[replica_idx][stage_idx].print_stats( f"replica_{replica_idx + 1}_stage_{stage_idx + 1}_busy_time_percent", base_plot_path, @@ -504,7 +481,7 @@ def plot(self) -> None: @if_write_metrics def on_request_arrival(self, time: float, request: Request) -> None: - if not self._should_store_request_metrics: + if not self._config.store_request_metrics: return self._request_completion_metrics_time_series[ @@ -531,7 +508,7 @@ def on_request_arrival(self, time: float, request: Request) -> None: @if_write_metrics def _on_request_end(self, time: float, request: Request) -> None: - if not self._should_store_request_metrics: + if not self._config.store_request_metrics: return self._request_completion_metrics_time_series[ @@ -603,7 +580,7 @@ def _update_per_token_execution_times( # if prefill has just finished in this iteration, update the prefill completion time series if ( time == request.prefill_completed_at - and self._should_store_token_completion_metrics + and self._config.store_token_completion_metrics ): self._token_completion_metrics_time_series[ TokenCompletionMetricsTimeSeries.PREFILL_COMPLETIONS @@ -616,7 +593,7 @@ def _update_per_token_execution_times( if not request.has_started_decode: return - if not self._should_store_token_completion_metrics: + if not self._config.store_token_completion_metrics: return self._token_metrics_time_distribution[ @@ -655,21 +632,21 @@ def _push_metric( def on_batch_end( self, time: float, batch: Batch, replica_id: int, memory_usage_percent: int ) -> None: - if (self._min_batch_idx and batch.id < self._min_batch_idx) or ( - self._max_batch_idx and batch.id > self._max_batch_idx - ): + if ( + self._config.min_batch_index and batch.id < self._config.min_batch_index + ) or (self._config.max_batch_index and batch.id > self._config.max_batch_index): return for request in batch.completed_requests: self._on_request_end(time, request) - if self._should_store_utilization_metrics: + if self._config.store_utilization_metrics: self._replica_memory_usage[replica_id - 1].put(time, memory_usage_percent) for request in batch.requests: self._update_per_token_execution_times(time, request, batch) - if not self._should_store_batch_metrics: + if not self._config.store_batch_metrics: return self._push_metric( @@ -700,7 +677,7 @@ def on_batch_end( def on_replica_schedule( self, time: float, replica_id: int, memory_usage_percent: int ) -> None: - if not self._should_store_utilization_metrics: + if not self._config.store_utilization_metrics: return self._replica_memory_usage[replica_id - 1].put(time, memory_usage_percent) @@ -714,14 +691,14 @@ def on_replica_stage_schedule( batch_stage: BatchStage, execution_time: ExecutionTime, ) -> None: - if not self._should_store_utilization_metrics: + if not self._config.store_utilization_metrics: return self._replica_busy_time[replica_id - 1][stage_id - 1].put(time, 100) mfu = self._mfu_calculator.get_mfu(batch_stage) self._replica_mfu[replica_id - 1][stage_id - 1].put(time, mfu) - if not self._should_store_operation_metrics: + if not self._config.store_operation_metrics: return batch_id = batch_stage._batch_id @@ -833,7 +810,7 @@ def on_replica_stage_schedule( def on_batch_stage_end( self, batch_stage: BatchStage, time: float, replica_id: int, stage_id: int ) -> None: - if not self._should_store_utilization_metrics: + if not self._config.store_utilization_metrics: return self._replica_busy_time[replica_id - 1][stage_id - 1].put(time, 0) self._replica_mfu[replica_id - 1][stage_id - 1].put(time, 0) diff --git a/vidur/metrics/series_average_meter.py b/vidur/metrics/series_average_meter.py index 92b04986..8f6679e1 100644 --- a/vidur/metrics/series_average_meter.py +++ b/vidur/metrics/series_average_meter.py @@ -1,6 +1,7 @@ import json import wandb + from vidur.logger import init_logger logger = init_logger(__name__) diff --git a/vidur/profiling/common/model_config.py b/vidur/profiling/common/model_config.py index 585590d8..7ee132bb 100644 --- a/vidur/profiling/common/model_config.py +++ b/vidur/profiling/common/model_config.py @@ -1,10 +1,11 @@ +from dataclasses import asdict from typing import Any, Dict, Optional import torch -import yaml from sarathi.config import ParallelConfig -from vidur.constants import MODEL_CONFIG_DIR +from vidur.config.model_config import BaseModelConfig +from vidur.types import ActivationType, NormType class ModelConfig: @@ -20,8 +21,8 @@ def __init__( use_gated_mlp: bool, use_bias: bool, use_qkv_bias: bool, - activation: str, - norm: str, + activation: ActivationType, + norm: NormType, post_attn_norm: bool, vocab_size: int, is_neox_style: Optional[bool] = True, @@ -41,8 +42,8 @@ def __init__( self.vocab_size = vocab_size self.use_bias = use_bias self.use_qkv_bias = use_qkv_bias - self.activation = activation - self.norm = norm + self.activation = str(activation) + self.norm = str(norm) self.post_attn_norm = post_attn_norm self.no_tensor_parallel = no_tensor_parallel self.partial_rotary_factor = partial_rotary_factor @@ -60,11 +61,10 @@ def __init__( @staticmethod def from_model_name(model_name: str): - model_config_path = f"{MODEL_CONFIG_DIR}/{model_name}.yml" - with open(model_config_path, "r") as f: - model_config = yaml.safe_load(f) + model_config: BaseModelConfig = BaseModelConfig.create_from_name(model_name) + model_config_dict = asdict(model_config) - return ModelConfig(model_name, **model_config) + return ModelConfig(model_name, **model_config_dict) def get_num_q_heads(self, parallel_config: ParallelConfig): return self.num_q_heads // parallel_config.tensor_parallel_size diff --git a/vidur/request_generator/base_request_generator.py b/vidur/request_generator/base_request_generator.py index e04b34b1..912e7a6e 100644 --- a/vidur/request_generator/base_request_generator.py +++ b/vidur/request_generator/base_request_generator.py @@ -2,20 +2,14 @@ from abc import ABC, abstractmethod from typing import List -from vidur.config import Config +from vidur.config import BaseRequestGeneratorConfig from vidur.entities import Request class BaseRequestGenerator(ABC): - def __init__(self, config: Config): - self._config = config - self._should_write_json_trace = config.write_json_trace - def _write_requests_to_file(self, requests: List[Request]) -> None: - request_dicts = [request.to_dict() for request in requests] - request_file = f"{self._config.output_dir}/requests.json" - with open(request_file, "w") as f: - json.dump(request_dicts, f) + def __init__(self, config: BaseRequestGeneratorConfig): + self.config = config @abstractmethod def generate_requests(self) -> List[Request]: @@ -23,8 +17,4 @@ def generate_requests(self) -> List[Request]: def generate(self) -> List[Request]: requests = self.generate_requests() - - if self._should_write_json_trace: - self._write_requests_to_file(requests) - return requests diff --git a/vidur/request_generator/base_request_interval_generator.py b/vidur/request_generator/base_request_interval_generator.py index ca8a1756..d0370e81 100644 --- a/vidur/request_generator/base_request_interval_generator.py +++ b/vidur/request_generator/base_request_interval_generator.py @@ -1,11 +1,12 @@ from abc import ABC, abstractmethod -from vidur.config import Config +from vidur.config import BaseRequestIntervalGeneratorConfig class BaseRequestIntervalGenerator(ABC): - def __init__(self, config: Config): - self._config = config + + def __init__(self, config: BaseRequestIntervalGeneratorConfig): + self.config = config @abstractmethod def get_next_inter_request_time(self) -> float: diff --git a/vidur/request_generator/base_request_length_generator.py b/vidur/request_generator/base_request_length_generator.py index 3d31028d..7162ffc8 100644 --- a/vidur/request_generator/base_request_length_generator.py +++ b/vidur/request_generator/base_request_length_generator.py @@ -1,12 +1,13 @@ from abc import ABC, abstractmethod from typing import Tuple -from vidur.config import Config +from vidur.config import BaseRequestLengthGeneratorConfig class BaseRequestLengthGenerator(ABC): - def __init__(self, config: Config): - self._config = config + + def __init__(self, config: BaseRequestLengthGeneratorConfig): + self.config = config @abstractmethod def get_next_num_tokens(self) -> Tuple[float, float]: diff --git a/vidur/request_generator/fixed_request_length_generator.py b/vidur/request_generator/fixed_request_length_generator.py index 4d52a94a..05eb11e9 100644 --- a/vidur/request_generator/fixed_request_length_generator.py +++ b/vidur/request_generator/fixed_request_length_generator.py @@ -6,8 +6,9 @@ class FixedRequestLengthGenerator(BaseRequestLengthGenerator): + def get_next_num_tokens(self) -> Tuple[float, float]: return ( - self._config.fixed_request_generator_prefill_tokens, - self._config.fixed_request_generator_decode_tokens, + self.config.prefill_tokens, + self.config.decode_tokens, ) diff --git a/vidur/request_generator/gamma_request_interval_generator.py b/vidur/request_generator/gamma_request_interval_generator.py index fc02a508..f85ca130 100644 --- a/vidur/request_generator/gamma_request_interval_generator.py +++ b/vidur/request_generator/gamma_request_interval_generator.py @@ -1,18 +1,20 @@ from scipy.stats import gamma +from vidur.config import GammaRequestIntervalGeneratorConfig from vidur.request_generator.base_request_interval_generator import ( BaseRequestIntervalGenerator, ) class GammaRequestIntervalGenerator(BaseRequestIntervalGenerator): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - cv = self._config.gamma_request_interval_generator_cv - self._qps = self._config.gamma_request_interval_generator_qps - self._gamma_shape = 1.0 / (cv**2) + def __init__(self, config: GammaRequestIntervalGeneratorConfig): + super().__init__(config) + + cv = self.config.cv + self.qps = self.config.qps + self.gamma_shape = 1.0 / (cv**2) def get_next_inter_request_time(self) -> float: - gamma_scale = 1.0 / (self._qps * self._gamma_shape) - return gamma.rvs(self._gamma_shape, scale=gamma_scale) + gamma_scale = 1.0 / (self.qps * self.gamma_shape) + return gamma.rvs(self.gamma_shape, scale=gamma_scale) diff --git a/vidur/request_generator/poisson_request_interval_generator.py b/vidur/request_generator/poisson_request_interval_generator.py index 2be7b31d..53a067ce 100644 --- a/vidur/request_generator/poisson_request_interval_generator.py +++ b/vidur/request_generator/poisson_request_interval_generator.py @@ -1,20 +1,23 @@ import math import random +from vidur.config import PoissonRequestIntervalGeneratorConfig from vidur.request_generator.base_request_interval_generator import ( BaseRequestIntervalGenerator, ) class PoissonRequestIntervalGenerator(BaseRequestIntervalGenerator): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._qps = self._config.poisson_request_interval_generator_qps - self._std = 1.0 / self._qps - self._max_interval = self._std * 3.0 + def __init__(self, config: PoissonRequestIntervalGeneratorConfig): + super().__init__(config) + + self.qps = self.config.qps + self.std = 1.0 / self.qps + self.max_interval = self.std * 3.0 def get_next_inter_request_time(self) -> float: - next_interval = -math.log(1.0 - random.random()) / self._qps - next_interval = min(next_interval, self._max_interval) + next_interval = -math.log(1.0 - random.random()) / self.qps + next_interval = min(next_interval, self.max_interval) + return next_interval diff --git a/vidur/request_generator/request_generator_registry.py b/vidur/request_generator/request_generator_registry.py index 9e79008e..44c920f1 100644 --- a/vidur/request_generator/request_generator_registry.py +++ b/vidur/request_generator/request_generator_registry.py @@ -9,9 +9,7 @@ class RequestGeneratorRegistry(BaseRegistry): - @classmethod - def get_key_from_str(cls, key_str: str) -> RequestGeneratorType: - return RequestGeneratorType.from_str(key_str) + pass RequestGeneratorRegistry.register( diff --git a/vidur/request_generator/request_interval_generator_registry.py b/vidur/request_generator/request_interval_generator_registry.py index 4d1e5706..42961610 100644 --- a/vidur/request_generator/request_interval_generator_registry.py +++ b/vidur/request_generator/request_interval_generator_registry.py @@ -15,9 +15,7 @@ class RequestIntervalGeneratorRegistry(BaseRegistry): - @classmethod - def get_key_from_str(cls, key_str: str) -> RequestIntervalGeneratorType: - return RequestIntervalGeneratorType.from_str(key_str) + pass RequestIntervalGeneratorRegistry.register( diff --git a/vidur/request_generator/request_length_generator_registry.py b/vidur/request_generator/request_length_generator_registry.py index 7cdec9ad..12775cc5 100644 --- a/vidur/request_generator/request_length_generator_registry.py +++ b/vidur/request_generator/request_length_generator_registry.py @@ -15,9 +15,7 @@ class RequestLengthGeneratorRegistry(BaseRegistry): - @classmethod - def get_key_from_str(cls, key_str: str) -> RequestLengthGeneratorType: - return RequestLengthGeneratorType.from_str(key_str) + pass RequestLengthGeneratorRegistry.register( diff --git a/vidur/request_generator/static_request_interval_generator.py b/vidur/request_generator/static_request_interval_generator.py index 87ad49aa..57eae727 100644 --- a/vidur/request_generator/static_request_interval_generator.py +++ b/vidur/request_generator/static_request_interval_generator.py @@ -4,5 +4,6 @@ class StaticRequestIntervalGenerator(BaseRequestIntervalGenerator): + def get_next_inter_request_time(self) -> float: return 0 diff --git a/vidur/request_generator/synthetic_request_generator.py b/vidur/request_generator/synthetic_request_generator.py index aa8684ce..46ebcb46 100644 --- a/vidur/request_generator/synthetic_request_generator.py +++ b/vidur/request_generator/synthetic_request_generator.py @@ -1,5 +1,6 @@ from typing import List +from vidur.config import SyntheticRequestGeneratorConfig from vidur.entities import Request from vidur.request_generator.base_request_generator import BaseRequestGenerator from vidur.request_generator.request_interval_generator_registry import ( @@ -12,23 +13,22 @@ class SyntheticRequestGenerator(BaseRequestGenerator): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._seed = self._config.seed + def __init__(self, config: SyntheticRequestGeneratorConfig): + super().__init__(config) - self._request_length_generator = RequestLengthGeneratorRegistry.get_from_str( - self._config.synthetic_request_generator_length_provider, self._config + self.request_length_generator = RequestLengthGeneratorRegistry.get( + self.config.length_generator_config.get_type(), + self.config.length_generator_config, ) - self._request_interval_generator = ( - RequestIntervalGeneratorRegistry.get_from_str( - self._config.synthetic_request_generator_interval_provider, self._config - ) + self.request_interval_generator = RequestIntervalGeneratorRegistry.get( + self.config.interval_generator_config.get_type(), + self.config.interval_generator_config, ) def _generate_next_request(self, last_arrived_at: float) -> Request: inter_request_time = ( - self._request_interval_generator.get_next_inter_request_time() + self.request_interval_generator.get_next_inter_request_time() ) if inter_request_time is None: return None @@ -37,7 +37,7 @@ def _generate_next_request(self, last_arrived_at: float) -> Request: ( prefill_tokens, decode_tokens, - ) = self._request_length_generator.get_next_num_tokens() + ) = self.request_length_generator.get_next_num_tokens() if prefill_tokens is None or decode_tokens is None: return None @@ -54,18 +54,22 @@ def _generate_requests(self) -> List[Request]: current_time = 0 # first priority is duration - if self._config.synthetic_request_generator_duration is not None: - while current_time < self._config.synthetic_request_generator_duration: + if self.config.duration is not None: + while current_time < self.config.duration: request = self._generate_next_request(current_time) current_time = request.arrived_at requests.append(request) - elif self._config.synthetic_request_generator_num_requests is not None: - for _ in range(self._config.synthetic_request_generator_num_requests): + elif self.config.num_requests is not None: + for _ in range(self.config.num_requests): request = self._generate_next_request(current_time) current_time = request.arrived_at requests.append(request) else: - assert self._config.synthetic_request_generator_interval_provider == "trace" + assert ( + self.config.interval_generator_config.get_type() + == RequestLengthGeneratorRegistry.TRACE + ) + while True: request = self._generate_next_request(current_time) if request is None: @@ -77,24 +81,24 @@ def _generate_requests(self) -> List[Request]: def generate_requests(self) -> List[Request]: assert ( - self._config.synthetic_request_generator_num_requests - or self._config.synthetic_request_generator_duration - or self._config.synthetic_request_generator_interval_provider == "trace" + self.config.num_requests + or self.config.duration + or self.config.interval_generator_config.get_type() + == RequestLengthGeneratorRegistry.TRACE ) - set_seeds(self._seed) + set_seeds(self.config.seed) requests = self._generate_requests() # sort requests by arrival time - requests.sort(key=lambda x: (x.arrived_at, x.id)) + requests.sort(key=lambda x: x.arrived_at) # remove any requests that arrived after the time limit - if self._config.synthetic_request_generator_duration is not None: + if self.config.duration is not None: requests = [ request for request in requests - if request.arrived_at - < self._config.synthetic_request_generator_duration + if request.arrived_at < self.config.duration ] return requests diff --git a/vidur/request_generator/trace_replay_request_generator.py b/vidur/request_generator/trace_replay_request_generator.py index 8283b299..5ef93975 100644 --- a/vidur/request_generator/trace_replay_request_generator.py +++ b/vidur/request_generator/trace_replay_request_generator.py @@ -1,12 +1,13 @@ +import logging from typing import List import pandas as pd +from vidur.config import TraceRequestGeneratorConfig from vidur.entities import Request -from vidur.logger import init_logger from vidur.request_generator.base_request_generator import BaseRequestGenerator -logger = init_logger(__name__) +logger = logging.getLogger(__name__) class TraceReplayRequestGenerator(BaseRequestGenerator): @@ -15,70 +16,62 @@ class TraceReplayRequestGenerator(BaseRequestGenerator): inter-request times, number of tokens. """ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, config: TraceRequestGeneratorConfig): + super().__init__(config) - self._trace_file = self._config.trace_request_generator_trace_file # load into a pd dataframe - self._trace_df = pd.read_csv(self._trace_file) + self.trace_df = pd.read_csv(config.trace_file) # restrict trace_df to be a subset of rows that have the same date - self._trace_df = self._trace_df[ - self._trace_df["Date"] == self._config.trace_request_generator_date - ] + self.trace_df = self.trace_df[self.trace_df["Date"] == config.date] # scale prefill and decode tokens - self._trace_df["PromptTokenCount"] = ( - self._trace_df["PromptTokenCount"] - * self._config.trace_request_generator_prefill_scale_factor + self.trace_df["PromptTokenCount"] = ( + self.trace_df["PromptTokenCount"] * config.prefill_scale_factor ) - self._trace_df["CompletionTokenCount"] = ( - self._trace_df["CompletionTokenCount"] - * self._config.trace_request_generator_decode_scale_factor + self.trace_df["CompletionTokenCount"] = ( + self.trace_df["CompletionTokenCount"] * config.decode_scale_factor ) # make sure all the prefill and decode counts are integers - self._trace_df["PromptTokenCount"] = self._trace_df["PromptTokenCount"].astype( + self.trace_df["PromptTokenCount"] = self.trace_df["PromptTokenCount"].astype( int ) - self._trace_df["CompletionTokenCount"] = self._trace_df[ + self.trace_df["CompletionTokenCount"] = self.trace_df[ "CompletionTokenCount" ].astype(int) # make sure that there is at least one prefill and decode token - self._trace_df["PromptTokenCount"] = self._trace_df["PromptTokenCount"].clip( + self.trace_df["PromptTokenCount"] = self.trace_df["PromptTokenCount"].clip( lower=1 ) - self._trace_df["CompletionTokenCount"] = self._trace_df[ + self.trace_df["CompletionTokenCount"] = self.trace_df[ "CompletionTokenCount" ].clip(lower=1) # make sure the total does not exceed the max tokens, adjust the prefill tokens if needed total_tokens = ( - self._trace_df["PromptTokenCount"] + self._trace_df["CompletionTokenCount"] + self.trace_df["PromptTokenCount"] + self.trace_df["CompletionTokenCount"] ) - diff_tokens = total_tokens - self._config.request_generator_max_tokens + diff_tokens = total_tokens - config.max_tokens diff_tokens = diff_tokens.clip(lower=0) - self._trace_df["PromptTokenCount"] = ( - self._trace_df["PromptTokenCount"] - diff_tokens + self.trace_df["PromptTokenCount"] = ( + self.trace_df["PromptTokenCount"] - diff_tokens ) assert all( - self._trace_df["PromptTokenCount"] + self._trace_df["CompletionTokenCount"] - <= self._config.request_generator_max_tokens + self.trace_df["PromptTokenCount"] + self.trace_df["CompletionTokenCount"] + <= config.max_tokens ) # rescale the time to change QPS - self._trace_df["Time"] = ( - self._trace_df["Time"] - * self._config.trace_request_generator_time_scale_factor - ) + self.trace_df["Time"] = self.trace_df["Time"] * config.time_scale_factor # compute pd ratio and log the 25, 50, 75, 90, 95, 99 percentiles pd_ratio = ( - self._trace_df["PromptTokenCount"] / self._trace_df["CompletionTokenCount"] + self.trace_df["PromptTokenCount"] / self.trace_df["CompletionTokenCount"] ) logger.info( - f"Loaded trace file {self._trace_file} with {len(self._trace_df)} requests" + f"Loaded trace file {config.trace_file} with {len(self.trace_df)} requests" ) logger.info( f"Prompt/decode token ratio stats\n:{pd_ratio.describe(percentiles=[0.25, 0.5, 0.75, 0.9, 0.95, 0.99])}" @@ -87,7 +80,7 @@ def __init__(self, *args, **kwargs): def generate_requests(self) -> List[Request]: requests = [] - for _, row in self._trace_df.iterrows(): + for _, row in self.trace_df.iterrows(): request = Request( arrived_at=row["Time"], num_prefill_tokens=row["PromptTokenCount"], diff --git a/vidur/request_generator/trace_request_interval_generator.py b/vidur/request_generator/trace_request_interval_generator.py index 59ddf39c..f5ad0ff3 100644 --- a/vidur/request_generator/trace_request_interval_generator.py +++ b/vidur/request_generator/trace_request_interval_generator.py @@ -1,11 +1,13 @@ +import logging + import pandas as pd -from vidur.logger import init_logger +from vidur.config import TraceRequestIntervalGeneratorConfig from vidur.request_generator.base_request_interval_generator import ( BaseRequestIntervalGenerator, ) -logger = init_logger(__name__) +logger = logging.getLogger(__name__) class TraceRequestIntervalGenerator(BaseRequestIntervalGenerator): @@ -14,52 +16,45 @@ class TraceRequestIntervalGenerator(BaseRequestIntervalGenerator): inter-request times, number of tokens. """ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, config: TraceRequestIntervalGeneratorConfig): + super().__init__(config) - trace_file = self._config.trace_request_interval_generator_trace_file # load into a pd dataframe - self._trace_df = pd.read_csv(trace_file) + self.trace_df = pd.read_csv(config.trace_file) - self._trace_df["arrival_time"] = pd.to_datetime(self._trace_df["arrival_time"]) + self.trace_df["arrival_time"] = pd.to_datetime(self.trace_df["arrival_time"]) # restrict trace_df to be a subset of rows that have the same date - self._trace_df = self._trace_df[ - ( - self._trace_df["arrival_time"] - > self._config.trace_request_interval_generator_start_time - ) - & ( - self._trace_df["arrival_time"] - < self._config.trace_request_interval_generator_end_time - ) + self.trace_df = self.trace_df[ + (self.trace_df["arrival_time"] > config.start_time) + & (self.trace_df["arrival_time"] < config.end_time) ] # change back to seconds - self._trace_df["arrival_time"] = ( - self._trace_df["arrival_time"] - self._trace_df["arrival_time"].min() + self.trace_df["arrival_time"] = ( + self.trace_df["arrival_time"] - self.trace_df["arrival_time"].min() ) // pd.Timedelta("1s") # rescale the time to change QPS - self._trace_df["arrival_time"] = ( - self._trace_df["arrival_time"] - * self._config.trace_request_interval_generator_time_scale_factor + self.trace_df["arrival_time"] = ( + self.trace_df["arrival_time"] * config.time_scale_factor ) # compute the inter-request time - self._trace_df["inter_request_time"] = self._trace_df["arrival_time"].diff() + self.trace_df["inter_request_time"] = self.trace_df["arrival_time"].diff() - self._next_request_idx = 1 + self.next_request_idx = 1 logger.info( - f"Loaded interval trace file {trace_file} with {len(self._trace_df)} requests" + f"Loaded interval trace file {config.trace_file} with {len(self.trace_df)} requests" ) def get_next_inter_request_time(self) -> float: - if self._next_request_idx >= len(self._trace_df): + if self.next_request_idx >= len(self.trace_df): return None - inter_request_time = self._trace_df.iloc[self._next_request_idx][ + inter_request_time = self.trace_df.iloc[self.next_request_idx][ "inter_request_time" ] - self._next_request_idx += 1 + self.next_request_idx += 1 + return inter_request_time diff --git a/vidur/request_generator/trace_request_length_generator.py b/vidur/request_generator/trace_request_length_generator.py index 69315580..e21c6e93 100644 --- a/vidur/request_generator/trace_request_length_generator.py +++ b/vidur/request_generator/trace_request_length_generator.py @@ -1,106 +1,98 @@ +import logging from typing import Tuple import numpy as np import pandas as pd -from vidur.logger import init_logger +from vidur.config import TraceRequestLengthGeneratorConfig from vidur.request_generator.base_request_length_generator import ( BaseRequestLengthGenerator, ) -logger = init_logger(__name__) +logger = logging.getLogger(__name__) class TraceRequestLengthGenerator(BaseRequestLengthGenerator): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - trace_file = self._config.trace_request_length_generator_trace_file - self._trace_df = pd.read_csv(trace_file) + def __init__(self, config: TraceRequestLengthGeneratorConfig): + super().__init__(config) + + self.trace_df = pd.read_csv(config.trace_file) # scale prefill and decode tokens - self._trace_df["num_prefill_tokens"] = ( - self._trace_df["num_prefill_tokens"] - * self._config.trace_request_length_generator_prefill_scale_factor + self.trace_df["num_prefill_tokens"] = ( + self.trace_df["num_prefill_tokens"] * config.prefill_scale_factor ) - self._trace_df["num_decode_tokens"] = ( - self._trace_df["num_decode_tokens"] - * self._config.trace_request_length_generator_decode_scale_factor + self.trace_df["num_decode_tokens"] = ( + self.trace_df["num_decode_tokens"] * config.decode_scale_factor ) # make sure all the prefill and decode counts are integers - self._trace_df["num_prefill_tokens"] = self._trace_df[ + self.trace_df["num_prefill_tokens"] = self.trace_df[ "num_prefill_tokens" ].astype(int) - self._trace_df["num_decode_tokens"] = self._trace_df[ - "num_decode_tokens" - ].astype(int) + self.trace_df["num_decode_tokens"] = self.trace_df["num_decode_tokens"].astype( + int + ) # make sure the total does not exceed the max tokens, adjust the prefill tokens if needed total_tokens = ( - self._trace_df["num_prefill_tokens"] + self._trace_df["num_decode_tokens"] + self.trace_df["num_prefill_tokens"] + self.trace_df["num_decode_tokens"] ) - diff_tokens = total_tokens - self._config.request_generator_max_tokens + diff_tokens = total_tokens - config.max_tokens diff_tokens = diff_tokens.clip(lower=0) - # dedcut the diff tokens from the prefill and decode tokens proportionally - prefill_tokens_ratio = self._trace_df["num_prefill_tokens"] / total_tokens - decode_tokens_ratio = self._trace_df["num_decode_tokens"] / total_tokens + # deduct the diff tokens from the prefill and decode tokens proportionally + prefill_tokens_ratio = self.trace_df["num_prefill_tokens"] / total_tokens + decode_tokens_ratio = self.trace_df["num_decode_tokens"] / total_tokens - self._trace_df["num_prefill_tokens"] -= ( + self.trace_df["num_prefill_tokens"] -= ( np.ceil(diff_tokens * prefill_tokens_ratio) ).astype(int) - self._trace_df["num_decode_tokens"] -= ( + self.trace_df["num_decode_tokens"] -= ( np.ceil(diff_tokens * decode_tokens_ratio) ).astype(int) # make sure that there is at least one prefill and decode token - self._trace_df["num_prefill_tokens"] = self._trace_df[ - "num_prefill_tokens" - ].clip(lower=1) - self._trace_df["num_decode_tokens"] = self._trace_df["num_decode_tokens"].clip( + self.trace_df["num_prefill_tokens"] = self.trace_df["num_prefill_tokens"].clip( + lower=1 + ) + self.trace_df["num_decode_tokens"] = self.trace_df["num_decode_tokens"].clip( lower=1 ) assert all( - self._trace_df["num_prefill_tokens"] + self._trace_df["num_decode_tokens"] - <= self._config.request_generator_max_tokens + self.trace_df["num_prefill_tokens"] + self.trace_df["num_decode_tokens"] + <= self.config.max_tokens ) - assert all(self._trace_df["num_prefill_tokens"] > 0) + assert all(self.trace_df["num_prefill_tokens"] > 0) - assert all(self._trace_df["num_decode_tokens"] > 0) + assert all(self.trace_df["num_decode_tokens"] > 0) # compute pd ratio and log the 25, 50, 75, 90, 95, 99 percentiles pd_ratio = ( - self._trace_df["num_prefill_tokens"] / self._trace_df["num_decode_tokens"] + self.trace_df["num_prefill_tokens"] / self.trace_df["num_decode_tokens"] ) - percentiles = [0.25, 0.5, 0.75, 0.9, 0.95, 0.99] - logger.info( - f"Loaded request length trace file {trace_file} with {len(self._trace_df)} requests" - ) - logger.debug( - f"Prompt token stats\n:{self._trace_df['num_prefill_tokens'].describe(percentiles=percentiles)}" - ) - logger.debug( - f"Decode token stats\n:{self._trace_df['num_decode_tokens'].describe(percentiles=percentiles)}" + f"Loaded request length trace file {config.trace_file} with {len(self.trace_df)} requests" ) - logger.debug( - f"Prompt/decode token ratio stats\n:{pd_ratio.describe(percentiles=percentiles)}" + pd_distribution = pd_ratio.describe( + percentiles=[0.25, 0.5, 0.75, 0.9, 0.95, 0.99] ) + logger.debug(f"Prompt/decode token ratio stats\n: {pd_distribution}") # randomly shuffle the df based on the seed - self._trace_df = self._trace_df.sample(frac=1, random_state=self._config.seed) - self._next_request_idx = 0 + self.trace_df = self.trace_df.sample(frac=1, random_state=self.config.seed) + self.next_request_idx = 0 def get_next_num_tokens(self) -> Tuple[float, float]: - if self._next_request_idx >= len(self._trace_df): + if self.next_request_idx >= len(self.trace_df): return None, None - row = self._trace_df.iloc[self._next_request_idx] - self._next_request_idx += 1 + row = self.trace_df.iloc[self.next_request_idx] + self.next_request_idx += 1 return ( row["num_prefill_tokens"], diff --git a/vidur/request_generator/uniform_request_length_generator.py b/vidur/request_generator/uniform_request_length_generator.py index 8ad53686..4f1e6c6b 100644 --- a/vidur/request_generator/uniform_request_length_generator.py +++ b/vidur/request_generator/uniform_request_length_generator.py @@ -8,15 +8,15 @@ class UniformRequestLengthGenerator(BaseRequestLengthGenerator): + def get_next_num_tokens(self) -> Tuple[float, float]: total_tokens = random.uniform( - self._config.synthetic_request_generator_min_tokens, - self._config.request_generator_max_tokens, + self.config.min_tokens, + self.config.max_tokens, ) decode_tokens = math.ceil( - total_tokens - / (1 + self._config.synthetic_request_generator_prefill_to_decode_ratio) + total_tokens / (1 + self.config.prefill_to_decode_ratio) ) prefill_tokens = total_tokens - decode_tokens assert prefill_tokens > 0 and decode_tokens > 0 diff --git a/vidur/request_generator/zipf_request_length_generator.py b/vidur/request_generator/zipf_request_length_generator.py index 43a68ba5..fc08ccde 100644 --- a/vidur/request_generator/zipf_request_length_generator.py +++ b/vidur/request_generator/zipf_request_length_generator.py @@ -1,5 +1,6 @@ from typing import Tuple +from vidur.config import ZipfRequestLengthGeneratorConfig from vidur.request_generator.base_request_length_generator import ( BaseRequestLengthGenerator, ) @@ -7,23 +8,22 @@ class ZipfRequestLengthGenerator(BaseRequestLengthGenerator): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - self._zipf_generator = ZipfGenerator( - self._config.synthetic_request_generator_min_tokens, - self._config.request_generator_max_tokens, - self._config.zipf_request_length_generator_theta, - self._config.zipf_request_length_generator_scramble, - self._config.seed, + + def __init__(self, config: ZipfRequestLengthGeneratorConfig): + super().__init__(config) + + self.zipf_generator = ZipfGenerator( + config.min_tokens, + config.max_tokens, + config.theta, + config.scramble, + config.seed, ) def get_next_num_tokens(self) -> Tuple[float, float]: - total_tokens = self._zipf_generator.next() + total_tokens = self.zipf_generator.next() - decode_tokens = total_tokens / ( - 1 + self._config.synthetic_request_generator_prefill_to_decode_ratio - ) + decode_tokens = total_tokens / (1 + self.config.prefill_to_decode_ratio) prefill_tokens = total_tokens - decode_tokens return prefill_tokens, decode_tokens diff --git a/vidur/scheduler/global_scheduler/base_global_scheduler.py b/vidur/scheduler/global_scheduler/base_global_scheduler.py index be047c9c..493a3e51 100644 --- a/vidur/scheduler/global_scheduler/base_global_scheduler.py +++ b/vidur/scheduler/global_scheduler/base_global_scheduler.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Dict, List, Tuple -from vidur.config import Config +from vidur.config import SimulationConfig from vidur.entities import Replica, Request from vidur.execution_time_predictor import ExecutionTimePredictorRegistry from vidur.scheduler.replica_scheduler.replica_scheduler_registry import ( @@ -10,23 +10,28 @@ class BaseGlobalScheduler(ABC): - def __init__(self, config: Config, replicas: Dict[int, Replica]): + def __init__(self, config: SimulationConfig, replicas: Dict[int, Replica]): self._config = config self._replicas = replicas self._num_replicas = len(self._replicas) - execution_time_predictor = ExecutionTimePredictorRegistry.get_from_str( - self._config.execution_time_predictor_provider, - self._config, + execution_time_predictor = ExecutionTimePredictorRegistry.get( + config.execution_time_predictor_config.get_type(), + predictor_config=config.execution_time_predictor_config, + replica_config=config.cluster_config.replica_config, + replica_scheduler_config=config.cluster_config.replica_scheduler_config, + metrics_config=config.metrics_config, ) self._replica_schedulers = { - replica_id: ReplicaSchedulerRegistry.get_from_str( - config.replica_scheduler_provider, - config, - replica, - replica.num_pipeline_stages, - execution_time_predictor, + replica_id: ReplicaSchedulerRegistry.get( + config.cluster_config.replica_scheduler_config.get_type(), + replica_config=config.cluster_config.replica_config, + replica_scheduler_config=config.cluster_config.replica_scheduler_config, + request_generator_config=config.request_generator_config, + replica=replica, + num_stages=replica.num_pipeline_stages, + execution_time_predictor=execution_time_predictor, ) for replica_id, replica in replicas.items() } diff --git a/vidur/scheduler/replica_scheduler/base_replica_scheduler.py b/vidur/scheduler/replica_scheduler/base_replica_scheduler.py index a0b2f6c9..2db5c33c 100644 --- a/vidur/scheduler/replica_scheduler/base_replica_scheduler.py +++ b/vidur/scheduler/replica_scheduler/base_replica_scheduler.py @@ -1,7 +1,11 @@ from abc import ABC, abstractmethod from typing import List -from vidur.config import Config +from vidur.config import ( + BaseReplicaSchedulerConfig, + BaseRequestGeneratorConfig, + ReplicaConfig, +) from vidur.entities import Batch, Replica, Request from vidur.execution_time_predictor import BaseExecutionTimePredictor from vidur.logger import init_logger @@ -14,33 +18,32 @@ class BaseReplicaScheduler(ABC): def __init__( self, - config: Config, + replica_config: ReplicaConfig, + replica_scheduler_config: BaseReplicaSchedulerConfig, + request_generator_config: BaseRequestGeneratorConfig, replica: Replica, num_stages: int, execution_time_predictor: BaseExecutionTimePredictor, ) -> None: - self._config = config + self._config = replica_scheduler_config + self._replica_config = replica_config + self._request_generator_config = request_generator_config self._replica_id = replica.id self._num_stages = num_stages - # store config variables - self._block_size = self._config.replica_block_size - self._max_blocks_per_sequence = ( - self._config.request_generator_max_tokens // self._block_size + self._request_generator_config.max_tokens // self._config.block_size ) - memory_planner = MemoryPlanner(config, replica) - - self._num_total_blocks = config.replica_scheduler_num_blocks + memory_planner = MemoryPlanner(self._replica_config, replica) - if not self._num_total_blocks: - self._num_total_blocks = ( + if not self._config.num_blocks: + self._config.num_blocks = ( self._max_blocks_per_sequence * memory_planner.get_max_request_slots() ) self._max_batch_size = min( memory_planner.get_max_batch_size(), - config.replica_scheduler_batch_size_cap, + self._config.batch_size_cap, ) logger.debug( @@ -75,7 +78,7 @@ def num_allocated_blocks(self) -> int: @property def memory_usage_percent(self) -> int: - return (self._num_allocated_blocks * 100) / self._num_total_blocks + return (self._num_allocated_blocks * 100) / self._config.num_blocks def is_empty(self) -> bool: return ( @@ -102,7 +105,7 @@ def get_replica_stage_scheduler(self, stage_id: int): return self._replica_stage_schedulers[stage_id] def can_allocate(self, num_blocks: int) -> bool: - return self._num_total_blocks - self._num_allocated_blocks >= num_blocks + return self._config.num_blocks - self._num_allocated_blocks >= num_blocks def allocate(self, request_id: int, num_blocks: int) -> None: self._num_allocated_blocks += num_blocks @@ -111,7 +114,7 @@ def allocate(self, request_id: int, num_blocks: int) -> None: else: self._allocation_map[request_id] += num_blocks - assert self._num_allocated_blocks <= self._num_total_blocks + assert self._num_allocated_blocks <= self._config.num_blocks def free(self, *request_ids: List[int]) -> None: for request_id in request_ids: diff --git a/vidur/scheduler/replica_scheduler/lightllm_replica_scheduler.py b/vidur/scheduler/replica_scheduler/lightllm_replica_scheduler.py index 526d58c2..34263bf2 100644 --- a/vidur/scheduler/replica_scheduler/lightllm_replica_scheduler.py +++ b/vidur/scheduler/replica_scheduler/lightllm_replica_scheduler.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Tuple import numpy as np @@ -14,14 +14,9 @@ def __init__(self, *args, **kwargs): self._preempted_requests: List[Request] = [] self._num_running_batches = 0 - self._max_tokens_in_batch = self._config.lightllm_scheduler_max_tokens_in_batch - self._max_waiting_iters = self._config.lightllm_scheduler_max_waiting_iters - self._max_batch_size = self._config.replica_scheduler_batch_size_cap - self._max_micro_batch_size = ( - self._config.replica_scheduler_batch_size_cap // self._num_stages - ) + self._max_micro_batch_size = self._config.batch_size_cap // self._num_stages assert ( - self._block_size == 1 + self._config.block_size == 1 ), "LightLLM scheduler only supports block size of 1." assert ( self._num_stages == 1 @@ -39,7 +34,7 @@ def on_batch_end(self, batch: Batch) -> None: else: self._preempted_requests.append(request) - def _get_tuple_tokens(self, request: Request) -> (int, int): + def _get_tuple_tokens(self, request: Request) -> Tuple[int, int]: if request.scheduled: num_processed_tokens = request.num_processed_tokens remaining_tokens = ( @@ -66,7 +61,7 @@ def _can_allocate_request(self, request: Request) -> bool: need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max() - return need_max_token_num < self._num_total_blocks + return need_max_token_num < self._config.num_blocks def _allocate_request(self, request: Request) -> None: if request.id not in self._allocation_map: @@ -89,10 +84,10 @@ def _get_prefill_batch(self) -> Batch: next_num_tokens = self._get_request_next_num_tokens(request) - if num_batch_tokens + next_num_tokens > self._max_tokens_in_batch: + if num_batch_tokens + next_num_tokens > self._config.max_tokens_in_batch: break - if len(self._allocation_map) == self._max_batch_size: + if len(self._allocation_map) == self._config.batch_size_cap: break if len(requests) == self._max_micro_batch_size: @@ -145,7 +140,7 @@ def _get_next_batch(self) -> Batch: self._num_waiting_iters = 0 return batch - if self._num_waiting_iters >= self._max_waiting_iters: + if self._num_waiting_iters >= self._config.max_waiting_iters: self._num_waiting_iters = 0 batch = self._get_prefill_batch() if batch: diff --git a/vidur/scheduler/replica_scheduler/orca_replica_scheduler.py b/vidur/scheduler/replica_scheduler/orca_replica_scheduler.py index d434f81b..20352bcb 100644 --- a/vidur/scheduler/replica_scheduler/orca_replica_scheduler.py +++ b/vidur/scheduler/replica_scheduler/orca_replica_scheduler.py @@ -10,9 +10,6 @@ def __init__(self, *args, **kwargs): self._preempted_requests = [] self._num_running_batches = 0 - self._use_single_prefill_per_batch = ( - self._config.orca_scheduler_use_single_prefill_per_batch - ) def on_batch_end(self, batch: Batch) -> None: self._num_running_batches -= 1 @@ -26,7 +23,6 @@ def on_batch_end(self, batch: Batch) -> None: def _get_next_batch(self) -> Batch: requests = [] num_tokens = [] - contains_prefill = False # all preempted_requests will have prefill completed while self._preempted_requests: @@ -45,16 +41,12 @@ def _get_next_batch(self) -> Batch: if not self.can_allocate(self._max_blocks_per_sequence): break - if self._use_single_prefill_per_batch and contains_prefill: - break - request = self._request_queue.pop(0) self.allocate(request.id, self._max_blocks_per_sequence) next_num_tokens = self._get_request_next_num_tokens(request) requests.append(request) num_tokens.append(next_num_tokens) - contains_prefill = True if not requests: return diff --git a/vidur/scheduler/replica_scheduler/replica_scheduler_registry.py b/vidur/scheduler/replica_scheduler/replica_scheduler_registry.py index 85d15fb9..6a9eb9bc 100644 --- a/vidur/scheduler/replica_scheduler/replica_scheduler_registry.py +++ b/vidur/scheduler/replica_scheduler/replica_scheduler_registry.py @@ -18,9 +18,7 @@ class ReplicaSchedulerRegistry(BaseRegistry): - @classmethod - def get_key_from_str(cls, key_str: str) -> ReplicaSchedulerType: - return ReplicaSchedulerType.from_str(key_str) + pass ReplicaSchedulerRegistry.register( diff --git a/vidur/scheduler/replica_scheduler/sarathi_replica_scheduler.py b/vidur/scheduler/replica_scheduler/sarathi_replica_scheduler.py index 637d25e0..c033a188 100644 --- a/vidur/scheduler/replica_scheduler/sarathi_replica_scheduler.py +++ b/vidur/scheduler/replica_scheduler/sarathi_replica_scheduler.py @@ -13,53 +13,39 @@ def __init__(self, *args, **kwargs): # sarathi config self._num_running_batches = 0 self._preempted_requests = [] - self._chunk_size = self._config.sarathi_scheduler_chunk_size - # club multiple prefills to ensure uniform chunk size - self._enable_rolling_prefills = ( - self._config.sarathi_scheduler_enable_rolling_prefills - ) - # when we are packing multiple prefills in a batch, we need to ensure - # that we don't end up packing a very small prefill chunk just to make batch full - # because that will lead to reduced number of schedulable prefill requests - self._prefill_fitting_tolerance = ( - self._config.sarathi_scheduler_prefill_fitting_tolerance - ) - # vLLM config - self._watermark_blocks_fraction = ( - self._config.sarathi_scheduler_watermark_blocks_fraction - ) # For vLLM and its derivatives, we only need to set a loose max batch size # Memory requirements are handled explicitly by the scheduler - self._max_batch_size = self._config.replica_scheduler_batch_size_cap - self._max_micro_batch_size = ( - self._config.replica_scheduler_batch_size_cap // self._num_stages - ) + self._max_micro_batch_size = self._config.batch_size_cap // self._num_stages self._watermark_blocks = int( - self._watermark_blocks_fraction * self._num_total_blocks + self._config.watermark_blocks_fraction * self._config.num_blocks ) def _can_allocate_request(self, request: Request) -> bool: if request.id not in self._allocation_map: # new request - num_required_blocks = ceil(request.num_prefill_tokens / self._block_size) + num_required_blocks = ceil( + request.num_prefill_tokens / self._config.block_size + ) return ( - self._num_total_blocks + self._config.num_blocks - self._num_allocated_blocks - num_required_blocks >= self._watermark_blocks ) # vllm requires at least one block to be available - return self._num_total_blocks - self._num_allocated_blocks >= 1 + return self._config.num_blocks - self._num_allocated_blocks >= 1 def _allocate_request(self, request: Request) -> None: if request.id not in self._allocation_map: # new request - num_required_blocks = ceil(request.num_prefill_tokens / self._block_size) + num_required_blocks = ceil( + request.num_prefill_tokens / self._config.block_size + ) self.allocate(request.id, num_required_blocks) return - num_tokens_reserved = self._allocation_map[request.id] * self._block_size + num_tokens_reserved = self._allocation_map[request.id] * self._config.block_size num_tokens_required = max(0, request.num_processed_tokens - num_tokens_reserved) assert ( @@ -90,22 +76,12 @@ def _get_request_next_num_tokens( next_num_tokens = min( request.num_prefill_tokens - request.num_processed_tokens, - self._chunk_size - num_batch_tokens, + self._config.chunk_size - num_batch_tokens, ) - if not batch_contains_prefill: - return next_num_tokens + next_num_tokens = max(0, next_num_tokens) - if self._enable_rolling_prefills and num_batch_tokens < self._chunk_size * ( - 1 - self._prefill_fitting_tolerance - ): - # we can have multiple prefills per batch - # but the total number of tokens should not exceed - # the max batch size - return next_num_tokens - else: - # we will only allow one prefill per batch - return 0 + return next_num_tokens def _get_next_batch(self) -> Batch: requests = [] @@ -178,7 +154,7 @@ def _get_next_batch(self) -> Batch: skipped_requests = [] while self._request_queue: - if len(self._allocation_map) == self._max_batch_size: + if len(self._allocation_map) == self._config.batch_size_cap: break if len(requests) == self._max_micro_batch_size: diff --git a/vidur/scheduler/replica_scheduler/vllm_replica_scheduler.py b/vidur/scheduler/replica_scheduler/vllm_replica_scheduler.py index c84f33b8..eaa871c8 100644 --- a/vidur/scheduler/replica_scheduler/vllm_replica_scheduler.py +++ b/vidur/scheduler/replica_scheduler/vllm_replica_scheduler.py @@ -13,18 +13,11 @@ def __init__(self, *args, **kwargs): self._preempted_requests: List[Request] = [] self._num_running_batches = 0 - self._watermark_blocks_fraction = ( - self._config.vllm_scheduler_watermark_blocks_fraction - ) - self._max_tokens_in_batch = self._config.vllm_scheduler_max_tokens_in_batch # For vLLM and its derivatives, we only need to set a loose max batch size # Memory requirements are handled explicitly by the scheduler - self._max_batch_size = self._config.replica_scheduler_batch_size_cap - self._max_micro_batch_size = ( - self._config.replica_scheduler_batch_size_cap // self._num_stages - ) + self._max_micro_batch_size = self._config.batch_size_cap // self._num_stages self._watermark_blocks = int( - self._watermark_blocks_fraction * self._num_total_blocks + self._config.watermark_blocks_fraction * self._config.num_blocks ) def on_batch_end(self, batch: Batch) -> None: @@ -39,25 +32,29 @@ def on_batch_end(self, batch: Batch) -> None: def _can_allocate_request(self, request: Request) -> bool: if request.id not in self._allocation_map: # new request - num_required_blocks = ceil((request.num_prefill_tokens) / self._block_size) + num_required_blocks = ceil( + (request.num_prefill_tokens) / self._config.block_size + ) return ( - self._num_total_blocks + self._config.num_blocks - self._num_allocated_blocks - num_required_blocks >= self._watermark_blocks ) # vllm requires at least one block to be available - return self._num_total_blocks - self._num_allocated_blocks >= 1 + return self._config.num_blocks - self._num_allocated_blocks >= 1 def _allocate_request(self, request: Request) -> None: if request.id not in self._allocation_map: # new request - num_required_blocks = ceil((request.num_prefill_tokens) / self._block_size) + num_required_blocks = ceil( + (request.num_prefill_tokens) / self._config.block_size + ) self.allocate(request.id, num_required_blocks) return - num_tokens_reserved = self._allocation_map[request.id] * self._block_size + num_tokens_reserved = self._allocation_map[request.id] * self._config.block_size num_tokens_required = max(0, request.num_processed_tokens - num_tokens_reserved) assert ( num_tokens_required == 0 or num_tokens_required == 1 @@ -83,10 +80,10 @@ def _get_next_batch(self) -> Batch: new_num_tokens = num_tokens + [next_num_tokens] new_num_batch_tokens = len(new_num_tokens) * max(new_num_tokens) - if new_num_batch_tokens > self._max_tokens_in_batch: + if new_num_batch_tokens > self._config.max_tokens_in_batch: break - if len(self._allocation_map) == self._max_batch_size: + if len(self._allocation_map) == self._config.batch_size_cap: break if len(requests) == self._max_micro_batch_size: diff --git a/vidur/scheduler/utils/memory_planner.py b/vidur/scheduler/utils/memory_planner.py index 874deaf3..e769e7b1 100644 --- a/vidur/scheduler/utils/memory_planner.py +++ b/vidur/scheduler/utils/memory_planner.py @@ -1,11 +1,11 @@ -from vidur.config import Config +from vidur.config import ReplicaConfig from vidur.entities.replica import Replica from vidur.utils.param_counter import ParamCounter class MemoryPlanner: - def __init__(self, config: Config, replica: Replica) -> None: - self._param_counter = ParamCounter(config) + def __init__(self, replica_config: ReplicaConfig, replica: Replica) -> None: + self._param_counter = ParamCounter(replica_config) self._replica = replica def _get_kv_cache_memory_per_layer_per_request(self) -> int: diff --git a/vidur/simulator.py b/vidur/simulator.py index 02d77adb..f20c7f89 100644 --- a/vidur/simulator.py +++ b/vidur/simulator.py @@ -3,7 +3,7 @@ import json from typing import List -from vidur.config import Config +from vidur.config import SimulationConfig from vidur.entities import Cluster from vidur.events import BaseEvent, RequestArrivalEvent from vidur.logger import init_logger @@ -15,30 +15,36 @@ class Simulator: - def __init__(self, config: Config) -> None: - self._config = config + def __init__(self, config: SimulationConfig) -> None: + self._config: SimulationConfig = config self._time = 0 self._terminate = False - self._time_limit = self._config.simulator_time_limit + self._time_limit = self._config.time_limit if not self._time_limit: self._time_limit = float("inf") self._event_queue = [] - self._should_write_json_trace = self._config.write_json_trace - self._should_write_chrome_trace = self._config.write_chrome_trace - self._event_trace = [] self._event_chrome_trace = [] - self._cluster = Cluster(self._config) - self._metric_store = MetricsStore(self._config) - self._request_generator = RequestGeneratorRegistry.get_from_str( - self._config.request_generator_provider, self._config + self._cluster = Cluster( + self._config.cluster_config, + self._config.metrics_config, + self._config.request_generator_config, + ) + self._metric_store = MetricsStore( + self._config.metrics_config, self._config.cluster_config + ) + self._request_generator = RequestGeneratorRegistry.get( + self._config.request_generator_config.get_type(), + self._config.request_generator_config, ) - self._scheduler = GlobalSchedulerRegistry.get_from_str( - self._config.global_scheduler_provider, self._config, self._cluster.replicas + self._scheduler = GlobalSchedulerRegistry.get( + self._config.cluster_config.global_scheduler_config.get_type(), + self._config, + self._cluster.replicas, ) self._init_event_queue() @@ -54,7 +60,7 @@ def metric_store(self) -> MetricsStore: def run(self) -> None: logger.info( - f"Starting simulation with cluster: {self._cluster} and {len(self._event_queue) - 1} requests" + f"Starting simulation with cluster: {self._cluster} and {len(self._event_queue)} requests" ) while self._event_queue and not self._terminate: @@ -63,10 +69,10 @@ def run(self) -> None: new_events = event.handle_event(self._scheduler, self._metric_store) self._add_events(new_events) - if self._should_write_json_trace: + if self._config.metrics_config.write_json_trace: self._event_trace.append(event.to_dict()) - if self._should_write_chrome_trace: + if self._config.metrics_config.enable_chrome_trace: chrome_trace = event.to_chrome_trace() if chrome_trace: self._event_chrome_trace.append(chrome_trace) @@ -81,12 +87,12 @@ def _write_output(self) -> None: self._metric_store.plot() logger.info("Metrics written") - if self._should_write_json_trace: + if self._config.metrics_config.write_json_trace: self._write_event_trace() self._scheduler.write_batching_history() logger.info("Json event trace written") - if self._should_write_chrome_trace: + if self._config.metrics_config.enable_chrome_trace: self._write_chrome_trace() logger.info("Chrome event trace written") @@ -112,12 +118,12 @@ def _set_time(self, time: float) -> None: self._terminate = True def _write_event_trace(self) -> None: - trace_file = f"{self._config.output_dir}/event_trace.json" + trace_file = f"{self._config.metrics_config.output_dir}/event_trace.json" with open(trace_file, "w") as f: json.dump(self._event_trace, f) def _write_chrome_trace(self) -> None: - trace_file = f"{self._config.output_dir}/chrome_trace.json" + trace_file = f"{self._config.metrics_config.output_dir}/chrome_trace.json" chrome_trace = {"traceEvents": self._event_chrome_trace} diff --git a/vidur/types/__init__.py b/vidur/types/__init__.py index 5e64c3ec..da67d952 100644 --- a/vidur/types/__init__.py +++ b/vidur/types/__init__.py @@ -1,7 +1,11 @@ +from vidur.types.activation_type import ActivationType from vidur.types.base_int_enum import BaseIntEnum +from vidur.types.device_sku_type import DeviceSKUType from vidur.types.event_type import EventType from vidur.types.execution_time_predictor_type import ExecutionTimePredictorType from vidur.types.global_scheduler_type import GlobalSchedulerType +from vidur.types.node_sku_type import NodeSKUType +from vidur.types.norm_type import NormType from vidur.types.replica_scheduler_type import ReplicaSchedulerType from vidur.types.request_generator_type import RequestGeneratorType from vidur.types.request_interval_generator_type import RequestIntervalGeneratorType @@ -15,5 +19,9 @@ RequestLengthGeneratorType, RequestIntervalGeneratorType, ReplicaSchedulerType, + DeviceSKUType, + NodeSKUType, + NormType, + ActivationType, BaseIntEnum, ] diff --git a/vidur/types/activation_type.py b/vidur/types/activation_type.py new file mode 100644 index 00000000..d622f2f2 --- /dev/null +++ b/vidur/types/activation_type.py @@ -0,0 +1,6 @@ +from vidur.types.base_int_enum import BaseIntEnum + + +class ActivationType(BaseIntEnum): + GELU = 0 + SILU = 1 diff --git a/vidur/types/device_sku_type.py b/vidur/types/device_sku_type.py new file mode 100644 index 00000000..9077efea --- /dev/null +++ b/vidur/types/device_sku_type.py @@ -0,0 +1,7 @@ +from vidur.types.base_int_enum import BaseIntEnum + + +class DeviceSKUType(BaseIntEnum): + A40 = 1 + A100 = 2 + H100 = 3 diff --git a/vidur/types/node_sku_type.py b/vidur/types/node_sku_type.py new file mode 100644 index 00000000..4bcaabf4 --- /dev/null +++ b/vidur/types/node_sku_type.py @@ -0,0 +1,9 @@ +from vidur.types.base_int_enum import BaseIntEnum + + +class NodeSKUType(BaseIntEnum): + A40_PAIRWISE_NVLINK = 1 + A100_PAIRWISE_NVLINK = 2 + H100_PAIRWISE_NVLINK = 3 + A100_DGX = 4 + H100_DGX = 5 diff --git a/vidur/types/norm_type.py b/vidur/types/norm_type.py new file mode 100644 index 00000000..b5a783d5 --- /dev/null +++ b/vidur/types/norm_type.py @@ -0,0 +1,6 @@ +from vidur.types.base_int_enum import BaseIntEnum + + +class NormType(BaseIntEnum): + LAYER_NORM = 0 + RMS_NORM = 1 diff --git a/vidur/utils/mfu_calculator.py b/vidur/utils/mfu_calculator.py index 482385d4..fecab53a 100644 --- a/vidur/utils/mfu_calculator.py +++ b/vidur/utils/mfu_calculator.py @@ -1,21 +1,24 @@ -from vidur.config import Config +from vidur.config import ReplicaConfig from vidur.entities import BatchStage from vidur.utils.param_counter import ParamCounter class MFUCalculator: - def __init__(self, config: Config): - param_counter = ParamCounter(config) + + def __init__(self, replica_config: ReplicaConfig): + param_counter = ParamCounter(replica_config) self._num_params_per_device = param_counter.get_num_parameters_per_device() + + model_config = replica_config.model_config + self._num_layers_per_device = ( - config.replica_num_layers // config.replica_num_pipeline_stages + model_config.num_layers // replica_config.num_pipeline_stages ) - self._embedding_dim = config.replica_embedding_dim self._num_heads_per_device = ( - config.replica_num_q_heads // config.replica_num_tensor_parallel_workers + model_config.num_q_heads // replica_config.tensor_parallel_size ) - self._head_dimension = self._embedding_dim // config.replica_num_q_heads - self._device_flops = config.replica_fp16_tflops * 2**40 + self._head_dimension = model_config.embedding_dim // model_config.num_q_heads + self._device_flops = replica_config.device_config.fp16_tflops * 2**40 def _get_mlp_flops(self, batch_stage: BatchStage) -> float: num_tokens = sum(batch_stage.num_tokens) diff --git a/vidur/utils/param_counter.py b/vidur/utils/param_counter.py index 1ca5b28a..5ef348f0 100644 --- a/vidur/utils/param_counter.py +++ b/vidur/utils/param_counter.py @@ -1,42 +1,45 @@ from math import ceil -from vidur.config import Config +from vidur.config import ReplicaConfig class ParamCounter: - def __init__(self, config: Config) -> None: - self._embedding_dim = config.replica_embedding_dim - self._num_pipeline_stages = config.replica_num_pipeline_stages - self._num_tensor_parallel_workers = config.replica_num_tensor_parallel_workers - self._num_layers = config.replica_num_layers - self._num_q_heads = config.replica_num_q_heads - self._num_kv_heads = config.replica_num_kv_heads - self._embedding_dim = config.replica_embedding_dim - self._mlp_hidden_dim = config.replica_mlp_hidden_dim - self._use_gated_mlp = config.replica_use_gated_mlp - self._vocab_size = config.replica_vocab_size + def __init__(self, replica_config: ReplicaConfig) -> None: + self._replica_config = replica_config + self._model_config = self._replica_config.model_config - assert self._num_q_heads % self._num_tensor_parallel_workers == 0 - assert self._num_layers % self._num_pipeline_stages == 0 - assert self._embedding_dim % self._num_tensor_parallel_workers == 0 - assert self._embedding_dim % self._num_q_heads == 0 + assert ( + self._model_config.num_q_heads % self._replica_config.tensor_parallel_size + == 0 + ) + assert ( + self._model_config.num_layers % self._replica_config.num_pipeline_stages + == 0 + ) + assert ( + self._model_config.embedding_dim % self._replica_config.tensor_parallel_size + == 0 + ) + assert self._model_config.embedding_dim % self._model_config.num_q_heads == 0 self._num_layers_per_pipeline_stage = ( - self._num_layers // self._num_pipeline_stages + self._model_config.num_layers // self._replica_config.num_pipeline_stages + ) + self._attention_head_dim = ( + self._model_config.embedding_dim // self._model_config.num_q_heads ) - self._attention_head_dim = self._embedding_dim // self._num_q_heads self._q_heads_per_tensor_parallel_worker = ( - self._num_q_heads // self._num_tensor_parallel_workers + self._model_config.num_q_heads // self._replica_config.tensor_parallel_size ) self._kv_heads_per_tensor_parallel_worker = ceil( - self._num_kv_heads / self._num_tensor_parallel_workers + self._model_config.num_kv_heads / self._replica_config.tensor_parallel_size ) def get_num_parameters_per_layer(self) -> int: num_parameters = 0 # weights for attention metrics Wq, Wk, Wv num_parameters += ( - self._embedding_dim + self._model_config.embedding_dim * self._attention_head_dim * ( self._q_heads_per_tensor_parallel_worker @@ -45,24 +48,24 @@ def get_num_parameters_per_layer(self) -> int: ) # weights for attention metrics Wo num_parameters += ( - self._embedding_dim + self._model_config.embedding_dim * self._attention_head_dim * self._q_heads_per_tensor_parallel_worker ) # fc layer weights - if self._use_gated_mlp: + if self._model_config.use_gated_mlp: num_parameters += ( 3 - * self._embedding_dim - * self._mlp_hidden_dim - // self._num_tensor_parallel_workers + * self._model_config.embedding_dim + * self._model_config.mlp_hidden_dim + // self._replica_config.tensor_parallel_size ) else: num_parameters += ( 2 - * self._embedding_dim - * self._mlp_hidden_dim - // self._num_tensor_parallel_workers + * self._model_config.embedding_dim + * self._model_config.mlp_hidden_dim + // self._replica_config.tensor_parallel_size ) return num_parameters