diff --git a/models/demos/llama3/tt/llama_attention.py b/models/demos/llama3/tt/llama_attention.py index 67f1be1e04b..e9645e1eaa9 100644 --- a/models/demos/llama3/tt/llama_attention.py +++ b/models/demos/llama3/tt/llama_attention.py @@ -136,7 +136,7 @@ def __init__( dtype=self.dtype, memory_config=ttnn.DRAM_MEMORY_CONFIG, layout=ttnn.TILE_LAYOUT, - # cache_file_name=cache_name("wqkv_bias_sharded"), + cache_file_name=cache_name("wqkv_bias_sharded"), ) self.wqkv_bias_prefill = ttnn.reshape(self.wqkv_bias, ttnn.Shape([1, 1, 1, self.wqkv_bias.shape[-1]]))