[Profiling][Model][Doc] Support Llama3-8B and 70B on A100s #22

Merged · 6 commits · Jul 24, 2024 · Changes from 1 commit
Merged PR 1873: Support Llama3 8B and 70B for 32k context length on a100_pairwise_nvlink

# Changelog

* Support Llama3 8B and 70B (https://llama.meta.com/llama3/).
* Maximum supported context length is 32k, available only on 4xA100.
* Pipeline parallelism is not yet profiled beyond a 4k context length.
* Attention profiling enhancements:
  * Reduce the number of input combinations by dropping batches whose KV cache would require more blocks than fit in GPU memory (see the sketch below).
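The block budget behind this filter comes from `get_max_num_blocks` in `vidur/profiling/utils.py`, whose body is not shown in this diff. A minimal sketch, reusing the per-block memory formula the previous `AttentionWrapper` code used, could look like the following; the function name, signature, and defaults here are assumptions rather than the repository's actual helper:

```python
from math import floor

import torch


def estimate_max_num_blocks(
    num_layers: int,
    num_worker_kv_heads: int,
    head_dim: int,
    block_size: int,
    dtype: torch.dtype = torch.float16,
    gpu_memory_utilization: float = 0.9,
) -> int:
    # Bytes per element for the chosen dtype (2 for float16).
    element_size = torch.tensor([], dtype=dtype).element_size()
    # One block stores K and V for `block_size` tokens across this worker's KV heads.
    block_memory_size = 2 * block_size * num_worker_kv_heads * head_dim * element_size
    # Usable fraction of total GPU memory divided by the per-layer cost of one block.
    total_gpu_memory = torch.cuda.mem_get_info()[1]
    return floor(
        (total_gpu_memory * gpu_memory_utilization)
        / (block_memory_size * num_layers)
    )
```

The actual helper takes a `ModelConfig`, a `ParallelConfig`, the block size, and the dtype, as the `main.py` changes further down show.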
nitinkedia7 committed Jul 24, 2024
commit 1478250bb8be249631e2b7353948faf662e28665
3 changes: 3 additions & 0 deletions README.md
@@ -92,12 +92,15 @@ The metrics will be logged to wandb directly and a copy will be stored in the `s

| Model / Device | A100 80GB DGX | H100 DGX | 4xA100 80GB Pairwise NVLink Node | 8xA40 Pairwise NVLink Node |
| --- | --- | --- | --- | --- |
| `meta-llama/Meta-Llama-3-8B` | ✅ | ❌ | ✅* | ❌ |
| `meta-llama/Meta-Llama-3-70B` | ✅ | ❌ | ✅ | ❌ |
| `meta-llama/Llama-2-7b-hf` | ✅ | ✅ | ✅ | ✅ |
| `codellama/CodeLlama-34b-Instruct-hf` | ✅ | ✅ | ✅ | ✅ |
| `meta-llama/Llama-2-70b-hf` | ✅ | ✅ | ✅ | ✅ |
| `internlm/internlm-20b` | ✅ | ✅ | ✅ | ✅ |
| `Qwen/Qwen-72B` | ✅ | ✅ | ✅ | ✅ |

* Maximum supported context length is 4k, except for `Llama3-8B` and `Llama3-70B`, which support a 32k context length on the 4xA100 80GB Pairwise NVLink node.
* Pipeline parallelism is supported for all models. The PP dimension must divide the number of layers in the model.
* DGX nodes have 8 GPUs, fully connected via NVLink, so TP1, TP2, TP4 and TP8 are supported.
* In 4x pairwise NVLink nodes, there are 4 GPUs, so TP1, TP2 and TP4 are supported. TP4 here is less performant than TP4 in DGX nodes because (GPU1, GPU2) and (GPU3, GPU4) are each connected via NVLink, but the interconnect between these pairs is slower. A short validity sketch for these TP/PP constraints follows.
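The TP/PP constraints listed above amount to a simple validity check. The sketch below is illustrative only and not code from this repository; the helper name and node labels are assumptions, though `a100_pair_nvlink` mirrors the profiling directory name:

```python
SUPPORTED_TP = {
    "dgx": {1, 2, 4, 8},             # 8 GPUs, fully connected via NVLink
    "a100_pair_nvlink": {1, 2, 4},   # 4 GPUs, NVLink only within pairs
}


def is_valid_parallel_config(num_layers: int, tp: int, pp: int, node: str) -> bool:
    # TP must match what the node's GPU count and topology support.
    if tp not in SUPPORTED_TP[node]:
        return False
    # The PP dimension must evenly divide the model's layer count.
    return num_layers % pp == 0


# Example: Llama3-70B has 80 layers, so TP2 with PP4 is valid on a DGX node.
assert is_valid_parallel_config(num_layers=80, tp=2, pp=4, node="dgx")
```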
16 changes: 16 additions & 0 deletions data/model_configs/meta-llama/Meta-Llama-3-70B.yml
@@ -0,0 +1,16 @@
num_layers: 80
num_q_heads: 64
num_kv_heads: 8
embedding_dim: 8192
mlp_hidden_dim: 28672
max_position_embeddings: 8192
use_gated_mlp: true
use_bias: false
use_qkv_bias: false
activation: silu
norm: rms_norm
post_attn_norm: true
rope_theta: 500000.0
rope_scaling: null
vocab_size: 128256
is_neox_style: true
16 changes: 16 additions & 0 deletions data/model_configs/meta-llama/Meta-Llama-3-8B.yml
@@ -0,0 +1,16 @@
num_layers: 32
num_q_heads: 32
num_kv_heads: 8
embedding_dim: 4096
mlp_hidden_dim: 14336
max_position_embeddings: 4096
use_gated_mlp: true
use_bias: false
use_qkv_bias: false
activation: silu
norm: rms_norm
post_attn_norm: true
rope_theta: 500000.0
rope_scaling: null
vocab_size: 128256
is_neox_style: true
68,617 changes: 68,617 additions & 0 deletions data/profiling/compute/a100/meta-llama/Meta-Llama-3-70B/attention.csv
1,825 changes: 1,825 additions & 0 deletions data/profiling/compute/a100/meta-llama/Meta-Llama-3-70B/mlp.csv
91,513 changes: 91,513 additions & 0 deletions data/profiling/compute/a100/meta-llama/Meta-Llama-3-8B/attention.csv
1,825 changes: 1,825 additions & 0 deletions data/profiling/compute/a100/meta-llama/Meta-Llama-3-8B/mlp.csv
25,492 changes: 4,496 additions & 20,996 deletions data/profiling/network/a100_pair_nvlink/all_reduce.csv
3 changes: 3 additions & 0 deletions vidur/profiling/attention/attention_input.py
@@ -28,5 +28,8 @@ def is_valid(self, max_seq_len: int):
return False
return True

def is_under_memory_limit(self, max_num_tokens: int):
return self.batch_size * (self.kv_cache_size + self.prefill_chunk_size) <= max_num_tokens

def __str__(self):
return f"prefill_chunk_size: {self.prefill_chunk_size}, kv_cache_size: {self.kv_cache_size}, batch_size: {self.batch_size}, is_prefill: {self.is_prefill}"
71 changes: 31 additions & 40 deletions vidur/profiling/attention/attention_wrapper.py
@@ -1,4 +1,4 @@
from math import ceil, floor
from math import ceil
from typing import List

import numpy as np
@@ -10,17 +10,19 @@
# monkey patching the CudaTimer class to use the sarathi implementation
sarathi.metrics.cuda_timer.CudaTimer = CudaTimer

from sarathi.config import ParallelConfig
from sarathi.model_executor.attention import (
AttentionBackend,
get_attention_wrapper,
set_attention_backend,
)

from vidur.profiling.attention.attention_input import AttentionInput
from vidur.profiling.attention.sequence_proxy import SequenceMetadataProxy
from vidur.profiling.attention.sequence_proxy import (
SequenceMetadataProxy,
)
from vidur.profiling.common.model_config import ModelConfig
from vidur.profiling.common.timer_stats_store import TimerStatsStore
from vidur.profiling.utils import ProfileMethod

WARMUP_STEPS = 2
ACTIVE_STEPS = 5
@@ -30,55 +32,44 @@ class AttentionWrapper:
def __init__(
self,
model_config: ModelConfig,
num_tensor_parallel_workers: int,
parallel_config: ParallelConfig,
max_num_blocks: int,
max_model_len: int,
block_size: int,
attention_backend: AttentionBackend,
dtype: torch.dtype,
):
self.time_stats_store = TimerStatsStore(profile_method="kineto")

self._n_embd = model_config.embedding_dim
self._n_q_head = model_config.num_q_heads
self._n_kv_head = model_config.num_kv_heads
self._num_tensor_parallel_workers = num_tensor_parallel_workers
assert self._n_embd % self._n_q_head == 0
self._head_dim = self._n_embd // self._n_q_head
self._max_model_len = max_model_len
self._block_size = block_size
self._model_config = model_config
self._parallel_config = parallel_config
self._dtype = dtype
self._device = torch.device("cuda")

assert self._n_q_head % num_tensor_parallel_workers == 0
self._n_worker_q_heads = self._n_q_head // num_tensor_parallel_workers
assert self._n_kv_head % num_tensor_parallel_workers == 0
self._n_worker_kv_heads = self._n_kv_head // num_tensor_parallel_workers
self._max_model_len = max_model_len
self._n_worker_q_heads = self._model_config.get_num_q_heads(
self._parallel_config
)
self._n_worker_kv_heads = self._model_config.get_num_kv_heads(
self._parallel_config
)
self._head_dim = self._model_config.get_head_size()

self._dtype = torch.float16
self._device = torch.device("cuda")
self._block_size = block_size

self._attention_backend = attention_backend
set_attention_backend(attention_backend)
get_attention_wrapper().init(
self._n_worker_q_heads,
self._n_worker_kv_heads,
self._head_dim,
self._model_config,
self._parallel_config,
self._block_size,
self._device,
)
self._max_blocks_per_sequence = ceil(max_model_len / self._block_size)
# We create (big) KV tensors and reuse them
element_size = torch.randn(1, dtype=self._dtype).element_size()
block_memory_size = (
2
* self._block_size
* self._n_worker_kv_heads
* self._head_dim
* element_size
)
self.total_num_blocks = floor(
(torch.cuda.mem_get_info()[1] * 0.9)
/ (block_memory_size * model_config.num_layers)
)
self.max_num_blocks = max_num_blocks
self.kv_cache = get_attention_wrapper().get_cache_block(
self.total_num_blocks, dtype=self._dtype, device=self._device
self.max_num_blocks, dtype=self._dtype, device=self._device
)

def _get_input_tensors(
@@ -113,13 +104,13 @@ def _get_input_tensors(
num_blocks = ceil(
(num_tokens_per_seq + attention_input.kv_cache_size) / self._block_size
)
# TODO(nitinkedia7): Investigate why high=total_num_blocks fails with a CUDA illegal memory access
# TODO(nitinkedia7): Investigate why high=max_num_blocks fails with a CUDA illegal memory access
seq_metadata = SequenceMetadataProxy(
is_prompt=attention_input.is_prefill,
total_len=num_tokens_per_seq + attention_input.kv_cache_size,
processed_len=attention_input.kv_cache_size,
block_table=np.random.default_rng()
.integers(low=0, high=self.total_num_blocks - 1, size=num_blocks)
.integers(low=0, high=self.max_num_blocks - 1, size=num_blocks)
.tolist(),
)
seq_metadata_list.append(seq_metadata)
@@ -152,11 +143,11 @@ def profile(

return {
"time_stats": self.time_stats_store.get_stats(),
"n_embd": self._n_embd,
"n_q_head": self._n_q_head,
"n_kv_head": self._n_kv_head,
"n_embd": self._model_config.embedding_dim,
"n_q_head": self._model_config.num_q_heads,
"n_kv_head": self._model_config.num_kv_heads,
"block_size": self._block_size,
"num_tensor_parallel_workers": self._num_tensor_parallel_workers,
"num_tensor_parallel_workers": self._parallel_config.tensor_parallel_size,
"max_model_len": self._max_model_len,
"batch_size": attention_input.batch_size,
"prefill_chunk_size": attention_input.prefill_chunk_size,
104 changes: 68 additions & 36 deletions vidur/profiling/attention/main.py
@@ -1,18 +1,19 @@
import argparse
import datetime
import itertools
import os
from typing import Any, List

import pandas as pd
import ray
import torch
from sarathi.config import ParallelConfig
from sarathi.model_executor.attention import AttentionBackend
from tqdm import tqdm

from vidur.profiling.attention.attention_input import AttentionInput
from vidur.profiling.attention.attention_wrapper import AttentionWrapper
from vidur.profiling.common.model_config import ModelConfig
from vidur.profiling.utils import get_attention_input_combinations
from vidur.profiling.utils import get_attention_input_combinations, get_max_num_blocks


def parse_args():
@@ -59,7 +60,7 @@ def parse_args():
"--max_model_len",
type=int,
default=4096,
help="Maximum context length model can server",
help="Maximum context length model can serve",
)
parser.add_argument(
"--max_seq_len",
@@ -91,7 +92,7 @@ def parse_args():
)
parser.add_argument(
"--attention_backend",
default=AttentionBackend.FLASH_ATTENTION,
default=AttentionBackend.FLASHINFER,
choices=[e.value for e in AttentionBackend],
help="The attention backend to profile (default: %(default)s)",
)
@@ -112,10 +113,14 @@ def parse_args():
def profile_model(
args: argparse.Namespace,
model: str,
num_tensor_parallel_workers: int,
input_combinations: List[AttentionInput],
max_num_blocks: int,
dtype: torch.dtype,
pbar: Any,
):
model_config = ModelConfig.from_model_name(model)
parallel_config = ParallelConfig(num_tensor_parallel_workers, 1)

promises = []
all_results = []
@@ -127,29 +132,30 @@ def profile_model(
AttentionWrapper,
).options(runtime_env={"env_vars": {"KINETO_LOG_LEVEL": "5"}})

for num_tensor_parallel_workers in args.num_tensor_parallel_workers:
model_wrappers = [
model_wrapper_actor.remote(
model_config,
num_tensor_parallel_workers,
args.max_model_len,
args.block_size,
args.attention_backend,
)
for _ in range(args.num_gpus)
]
model_wrappers = [
model_wrapper_actor.remote(
model_config,
parallel_config,
max_num_blocks,
args.max_model_len,
args.block_size,
args.attention_backend,
dtype,
)
for _ in range(args.num_gpus)
]

for attention_input in input_combinations:
worker_id = len(promises)
promise = model_wrappers[worker_id].profile.remote(attention_input)
promises.append(promise)
for attention_input in input_combinations:
worker_id = len(promises)
promise = model_wrappers[worker_id].profile.remote(attention_input)
promises.append(promise)

if len(promises) >= args.num_gpus:
results = ray.get(promises)
all_results.extend(results)
promises = []
if len(promises) >= args.num_gpus:
results = ray.get(promises)
all_results.extend(results)
promises = []

pbar.update(1)
pbar.update(1)

results = ray.get(promises)
all_results.extend(results)
@@ -170,6 +176,7 @@ def profile_model(
def main():
args = parse_args()

dtype = torch.float16
input_combinations = get_attention_input_combinations(
args.max_seq_len,
args.min_batch_size,
@@ -178,21 +185,46 @@ def main():
args.profile_only_decode,
)

total_combos = itertools.product(
args.models,
args.num_tensor_parallel_workers,
input_combinations,
)
total_combos = {}
max_num_blocks_dict = {}
for model in args.models:
model_config = ModelConfig.from_model_name(model)
for num_tensor_parallel_workers in args.num_tensor_parallel_workers:
max_num_blocks = get_max_num_blocks(
model_config,
ParallelConfig(num_tensor_parallel_workers, 1),
args.block_size,
dtype,
)
max_num_blocks_dict[(model, num_tensor_parallel_workers)] = max_num_blocks
total_combos[(model, num_tensor_parallel_workers)] = list(
filter(
lambda input_combination: input_combination.is_under_memory_limit(
max_num_blocks * args.block_size
),
input_combinations,
)
)

pbar = tqdm(total=len(list(total_combos)))
pbar = tqdm(total=sum(len(v) for v in total_combos.values()))

for model in args.models:
result_df = profile_model(
args,
model,
input_combinations,
pbar,
)
result_df = pd.DataFrame()
for num_tensor_parallel_workers in args.num_tensor_parallel_workers:
result_df = pd.concat(
[
result_df,
profile_model(
args,
model,
num_tensor_parallel_workers,
total_combos[(model, num_tensor_parallel_workers)],
max_num_blocks_dict[(model, num_tensor_parallel_workers)],
dtype,
pbar,
),
]
)
# model name would contain '/', so create a directory as required
os.makedirs(f"{args.output_dir}/{model}", exist_ok=True)
result_df.to_csv(f"{args.output_dir}/{model}/attention.csv", index=False)
2 changes: 0 additions & 2 deletions vidur/profiling/attention/sequence_proxy.py
@@ -1,7 +1,5 @@
from typing import List

import numpy as np


class SequenceProxy:
def __init__(self, total_len: int, processed_len: int):