#0: TG llama3-70b prefill_decode demo
djordje-tt committed Sep 30, 2024
1 parent 08a123c commit 767ebba
Showing 13 changed files with 724 additions and 481 deletions.
18 changes: 13 additions & 5 deletions models/demos/t3000/llama2_70b/tt/llama_common.py
@@ -20,6 +20,7 @@
 )
 import pytest
 from models.demos.t3000.llama2_70b.tt.model_config import get_model_config
+from models.demos.tg.llama3_70b.tt.model_config import get_model_config as get_galaxy_model_config
 
 MAX_SEQ_LEN = 4096
 MAX_SEQ_LEN_LLAMA3 = 8192
@@ -174,11 +175,18 @@ def setup_llama_env(llama_version="llama3", max_batch_size=32, max_context_len=4
     logger.info(f"Tokenizer file: {tokenizer_path}")
     logger.info(f"Cache directory: {cache_path}")
 
-    model_config = get_model_config(
-        llama_version=llama_version,
-        max_batch_size=max_batch_size,
-        max_context_len=max_context_len,
-    )
+    if llama_version == "llama3-tg":
+        model_config = get_galaxy_model_config(
+            llama_version=llama_version,
+            max_batch_size=max_batch_size,
+            max_context_len=max_context_len,
+        )
+    else:
+        model_config = get_model_config(
+            llama_version=llama_version,
+            max_batch_size=max_batch_size,
+            max_context_len=max_context_len,
+        )
 
     return model_config, ckpt_dir, tokenizer_path, cache_path

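Note: with this change, setup_llama_env dispatches on the llama_version string, routing "llama3-tg" to the Galaxy config. A minimal usage sketch of the new entry point (the batch and context values below are illustrative, not taken from this commit):

    # Sketch: select the Galaxy (TG) model config via the new dispatch.
    from models.demos.t3000.llama2_70b.tt.llama_common import setup_llama_env

    model_config, ckpt_dir, tokenizer_path, cache_path = setup_llama_env(
        llama_version="llama3-tg",  # routes to get_galaxy_model_config
        max_batch_size=32,          # illustrative value
        max_context_len=4096,       # illustrative value
    )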
14 changes: 11 additions & 3 deletions models/demos/tg/llama3_70b/demo/demo.py
@@ -267,8 +267,14 @@ def run_decode(
             break
 
         # Decode the entire sequence generated so far and log it
-        for user_id in range(max(0, bsz - 3), bsz):
-            text = tokenizer.decode(tokens[user_id, : cur_pos + 1].tolist())
+        for user_id in range(max(0, bsz - 5), bsz):
+            eos_found = False
+            for eos_idx, tk in enumerate(tokens[user_id, : cur_pos + 1].tolist()):
+                if tk == tokenizer.eos_id:
+                    text = tokenizer.decode(tokens[user_id, :eos_idx].tolist())
+                    eos_found = True
+            if not eos_found:
+                text = tokenizer.decode(tokens[user_id, : cur_pos + 1].tolist())
             if data_args.print_output_as_generated:
                 logger.info(f"Loop {cur_pos} user {user_id}: {text}\n")
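Note: the new logging loop trims each user's transcript at an EOS token before decoding (the committed loop scans every position without breaking, so it keeps text up to the last EOS it sees). A self-contained sketch of the underlying trim-at-EOS idea, using a hypothetical helper rather than the demo's exact code:

    def trim_at_eos(token_ids, eos_id):
        # Keep tokens strictly before the first EOS; pass through if no EOS.
        if eos_id in token_ids:
            return token_ids[: token_ids.index(eos_id)]
        return token_ids

    assert trim_at_eos([10, 11, 2, 12], eos_id=2) == [10, 11]
    assert trim_at_eos([10, 11, 12], eos_id=2) == [10, 11, 12]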

@@ -364,7 +370,7 @@ def top_pk_logits_efficient(logits, p=0.9, k=10, temperature=1.0, return_probs=F
     ),
     ids=("chat_completion", "text_completion"),
 )
-@pytest.mark.parametrize("decode_only", (True,), ids=("decode_only",))
+@pytest.mark.parametrize("decode_only", (True, False), ids=("decode_only", "prefill_decode"))
 @pytest.mark.parametrize("num_layers", (1, 2, 10, 80), ids=("1L", "2L", "10L", "80L"))
 @pytest.mark.parametrize(
     "implementation, skip_model_load, n_devices",
@@ -432,6 +438,8 @@ def test_LlamaModel_demo(
     ## Get model config
     model_config, ckpt_dir, tokenizer_path, cache_path = setup_llama_env(
         llama_version=llama_version,
+        max_batch_size=max_batch_size,
+        max_context_len=max_context_len,
     )
 
     check_mesh_device(mesh_device, model_config)
23 changes: 2 additions & 21 deletions models/demos/tg/llama3_70b/tests/test_llama_attention_galaxy.py
@@ -31,31 +31,13 @@
     get_rot_transformation_mat,
     should_skip_model_load,
     check_kv_cache,
 )
 
-from models.utility_functions import skip_for_grayskull
-from models.demos.t3000.llama2_70b.tt.llama_common import (
-    setup_llama_env,
-    check_mesh_device,
-    extract_pcc_from_log,
-    generate_rot_emb,
-    get_rotation_mat,
-    MAX_SEQ_LEN,
-    MAX_SEQ_LEN_LLAMA3,
-    BASE_URL,
-    UNIT_TEST_N_LAYER,
-    UNIT_TEST_LAYER_NUM,
-    UNIT_TEST_START_POS,
-    UNIT_TEST_GENERATION_LENGTH,
-    comp_pcc,
-    get_rot_transformation_mat,
-    should_skip_model_load,
-    check_kv_cache,
-    num_to_corerange,
-    ConcatMesh2DToTensor,
-    ShardTensor2dMesh,
-)
+from models.utility_functions import skip_for_grayskull
 
 
 class PytorchLlamaAttentionModel(torch.nn.Module):
     def __init__(self, hf_reference_model, layer_num, rope_theta):
@@ -476,7 +458,6 @@ def test_LlamaAttention_inference(
         max_batch_size=max_batch_size,
         max_context_len=max_context_len,
     )
-
     check_mesh_device(mesh_device, model_config)
     run_test_LlamaAttention_inference(
         mesh_device,
models/demos/tg/llama3_70b/tests/test_llama_decoder_galaxy.py
@@ -6,7 +6,7 @@
 from loguru import logger
 import torch
 import ttnn
-from ttnn import ReplicateTensorToMesh, ListMeshToTensor
+from ttnn import ReplicateTensorToMesh
 
 from models.demos.t3000.llama2_70b.reference.llama.llama import Llama
 from models.demos.tg.llama3_70b.tt.llama_decoder_galaxy import TtLlamaDecoder_galaxy
1 change: 0 additions & 1 deletion models/demos/tg/llama3_70b/tests/test_llama_mlp_galaxy.py
@@ -6,7 +6,6 @@
 from loguru import logger
 import torch
 import ttnn
-from ttnn import ListMeshToTensor
 
 from models.demos.t3000.llama2_70b.reference.llama.llama import Llama
 from models.demos.tg.llama3_70b.tt.llama_mlp_galaxy import TtLlamaMLP_galaxy
4 changes: 2 additions & 2 deletions models/demos/tg/llama3_70b/tests/test_llama_model_galaxy.py
@@ -242,11 +242,11 @@ def run_test_LlamaModel_inference(
     "batch, seq_len",
     [
         (32, 1),
-        # (1, 256), (1, 8192), (1, 32768), (1, 128 * 1024)
+        # (1, 32), (1, 256), (1, 8192), (1, 32768), (1, 128 * 1024)
     ],
     ids=[
         "decode",
-        # "prefill_256", "prefill_8k", "prefill_32k", "prefill_128k"
+        # "prefill_32", "prefill_256", "prefill_8k", "prefill_32k", "prefill_128k"
     ],
 )
 @pytest.mark.parametrize(