From a47d70fe4ffafea82df7cb8aad8ac328a3347614 Mon Sep 17 00:00:00 2001 From: Djordje Date: Wed, 25 Sep 2024 13:48:18 +0000 Subject: [PATCH] #0: Clean the code --- models/demos/t3000/llama2_70b/tt/model_config.py | 1 - models/demos/tg/llama3_70b/demo/demo.py | 2 +- .../tg/llama3_70b/tests/test_llama_attention_galaxy.py | 2 +- .../tg/llama3_70b/tests/test_llama_model_galaxy.py | 10 ++++++++-- 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/models/demos/t3000/llama2_70b/tt/model_config.py b/models/demos/t3000/llama2_70b/tt/model_config.py index 1de2342a6ce..88b4c36b200 100644 --- a/models/demos/t3000/llama2_70b/tt/model_config.py +++ b/models/demos/t3000/llama2_70b/tt/model_config.py @@ -72,7 +72,6 @@ def get_model_config(llama_version="llama3", max_batch_size=32, max_context_len= "NUM_DEVICES": num_devices, "llama3-tg": MAX_SEQ_LEN_LLAMA3, "llama3.1-tg": MAX_SEQ_LEN_LLAMA3_1, - "PADDING_LENGTH": 32, "COMPUTE_KERNEL_CONFIG": ttnn.WormholeComputeKernelConfig( math_fidelity=ttnn.MathFidelity.HiFi2, math_approx_mode=True, diff --git a/models/demos/tg/llama3_70b/demo/demo.py b/models/demos/tg/llama3_70b/demo/demo.py index 2c259d41590..b2794f9b2b1 100644 --- a/models/demos/tg/llama3_70b/demo/demo.py +++ b/models/demos/tg/llama3_70b/demo/demo.py @@ -7,7 +7,7 @@ import json import torch import torch.nn.functional as F -import ttnn + from time import time import pytest from loguru import logger diff --git a/models/demos/tg/llama3_70b/tests/test_llama_attention_galaxy.py b/models/demos/tg/llama3_70b/tests/test_llama_attention_galaxy.py index 7fe48df012b..4eb2ff4ed83 100644 --- a/models/demos/tg/llama3_70b/tests/test_llama_attention_galaxy.py +++ b/models/demos/tg/llama3_70b/tests/test_llama_attention_galaxy.py @@ -458,7 +458,7 @@ def test_LlamaAttention_inference( max_batch_size=max_batch_size, max_context_len=max_context_len, ) - + check_mesh_device(mesh_device, model_config) run_test_LlamaAttention_inference( mesh_device, cluster_shape, diff --git a/models/demos/tg/llama3_70b/tests/test_llama_model_galaxy.py b/models/demos/tg/llama3_70b/tests/test_llama_model_galaxy.py index ae303acc3b6..c1f13f7d819 100644 --- a/models/demos/tg/llama3_70b/tests/test_llama_model_galaxy.py +++ b/models/demos/tg/llama3_70b/tests/test_llama_model_galaxy.py @@ -240,8 +240,14 @@ def run_test_LlamaModel_inference( ) @pytest.mark.parametrize( "batch, seq_len", - [(32, 1), (1, 32), (1, 256), (1, 8192), (1, 32768), (1, 128 * 1024)], - ids=["decode", "prefill_32", "prefill_256", "prefill_8k", "prefill_32k", "prefill_128k"], + [ + (32, 1), + # (1, 32), (1, 256), (1, 8192), (1, 32768), (1, 128 * 1024) + ], + ids=[ + "decode", + # "prefill_32", "prefill_256", "prefill_8k", "prefill_32k", "prefill_128k" + ], ) @pytest.mark.parametrize( "max_batch_size, max_context_len",