#16948: Add fixes to support unpadded inner dim for pf + mm and add support for DRAM in1 inputs for ring matmul. (#17311)

### Ticket
- #17060
- #16948

### Problem description
With the new validation added to tensor_spec, the current implementation of
ring matmul fails for unpadded shapes.

Also, the matmul currently does not support DRAM-interleaved in1 weights,
which are required for matmuls whose weights cannot fit in L1 (such as the LM
head in Llama).

### What's changed
- Internally round up the inner dim so that unpadded shapes pass the new tensor_spec validation
- Add support for DRAM_INTERLEAVED in1 weights (a usage sketch follows below)
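
For illustration, a minimal sketch of how in1 weights can now be kept DRAM interleaved instead of width-sharded in L1. The shape, dtype, and device handle below are hypothetical; only the memory-config choice mirrors the test change in this commit:

```python
import torch
import ttnn

device = ttnn.open_device(device_id=0)

# Hypothetical LM-head-sized weights (K x N) that would not fit in L1.
torch_in1 = torch.randn(1, 1, 2048, 16 * 1024, dtype=torch.bfloat16)

# Previously in1 had to be WIDTH_SHARDED in L1; with this change it can stay
# DRAM interleaved and the in1 reader kernel fetches blocks on the fly.
tt_in1 = ttnn.from_torch(
    torch_in1,
    dtype=ttnn.bfloat8_b,
    layout=ttnn.TILE_LAYOUT,  # TILE layout is required for DRAM-interleaved in1 (see validation change)
    device=device,
    memory_config=ttnn.DRAM_MEMORY_CONFIG,
)
```
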
avoraTT authored Feb 2, 2025
1 parent 69592bb commit 5db3a16
Showing 7 changed files with 243 additions and 32 deletions.
@@ -139,6 +139,7 @@ def run_multi_core_matmul_1d(
pcc_threshold=0.98,
use_physical_to_logical_mapping=True,
hop_grid=None,
in1_is_dram_interleaved=False,
):
assert not has_bias, "Bias not supported for gather_in0 mode."
if not isinstance(grid, tuple) and not use_arbitrary_cores:
@@ -237,14 +238,18 @@ def run_multi_core_matmul_1d(
),
)

- in1_sharded_mem_config = ttnn.MemoryConfig(
-     ttnn.TensorMemoryLayout.WIDTH_SHARDED,
-     ttnn.BufferType.L1,
-     ttnn.ShardSpec(
-         core_range_set,
-         [K_padded, N_per_shard],
-         ttnn.ShardOrientation.ROW_MAJOR,
-     ),
+ in1_sharded_mem_config = (
+     ttnn.MemoryConfig(
+         ttnn.TensorMemoryLayout.WIDTH_SHARDED,
+         ttnn.BufferType.L1,
+         ttnn.ShardSpec(
+             core_range_set,
+             [K, N_per_shard],
+             ttnn.ShardOrientation.ROW_MAJOR,
+         ),
+     )
+     if not in1_is_dram_interleaved
+     else ttnn.DRAM_MEMORY_CONFIG
)

output_sharded_mem_config = ttnn.MemoryConfig(
@@ -328,6 +333,85 @@ def run_multi_core_matmul_1d(
assert device.num_program_cache_entries() == 1 # Only 1 op


@pytest.mark.skipif(is_grayskull(), reason="GS does not support fp32")
@pytest.mark.skipif(is_blackhole(), reason="Test suite for WH only")
@pytest.mark.parametrize("has_bias", [False], ids=["no_bias"])
@pytest.mark.parametrize(
"B, M, K, N, in0_dtype, in1_dtype, fidelity, packer_l1_acc, fp32_acc_mode, grid",
[
(1, 32, 2048, 3584, ttnn.bfloat8_b, ttnn.bfloat4_b, ttnn.MathFidelity.LoFi, True, True, (8, 3)),
(1, 32, 2048, 16 * 1024, ttnn.bfloat8_b, ttnn.bfloat8_b, ttnn.MathFidelity.HiFi2, False, False, (8, 4)),
(1, 32, 7520, 8192, ttnn.bfloat8_b, ttnn.bfloat16, ttnn.MathFidelity.HiFi4, True, True, (6, 7)),
],
)
@pytest.mark.parametrize(
"activation",
[
None,
],
)
@pytest.mark.parametrize(
"use_arbitrary_cores, hop_grid",
[
(False, None),
(False, [(3, 6)]),
(True, None),
],
)
@pytest.mark.parametrize(
"in1_is_dram_interleaved",
[
True,
],
)
@pytest.mark.parametrize(
"num_iters",
[
3,
],
)
def test_multi_core_matmul_1d_in1_dram_wh(
device,
in0_dtype,
in1_dtype,
fidelity,
has_bias,
fp32_acc_mode,
packer_l1_acc,
B,
M,
K,
N,
activation,
grid,
hop_grid,
use_arbitrary_cores,
in1_is_dram_interleaved,
num_iters,
use_program_cache,
function_level_defaults,
):
run_multi_core_matmul_1d(
device,
in0_dtype,
in1_dtype,
fidelity,
has_bias,
fp32_acc_mode,
packer_l1_acc,
B,
M,
K,
N,
activation,
grid,
use_arbitrary_cores,
num_iters,
hop_grid=hop_grid,
in1_is_dram_interleaved=in1_is_dram_interleaved,
)


@pytest.mark.skipif(is_grayskull(), reason="GS does not support fp32")
@pytest.mark.skipif(is_blackhole(), reason="Test suite for GS only")
@pytest.mark.parametrize("has_bias", [False], ids=["no_bias"])
3 changes: 2 additions & 1 deletion tests/ttnn/unit_tests/operations/prefetcher_common.py
@@ -238,7 +238,8 @@ def run_prefetcher_mm(

tt_tensors_all = []
for tid in range(num_tensors * num_layers):
- K, N = padded_shapes[tid % num_tensors]
+ K, _ = input_shapes[tid % num_tensors]
+ _, N = padded_shapes[tid % num_tensors]
input_sharded_mem_config = ttnn.MemoryConfig(
ttnn.TensorMemoryLayout.WIDTH_SHARDED,
ttnn.BufferType.DRAM,
@@ -150,6 +150,7 @@ void MAIN {
constexpr uint32_t batch = get_compile_time_arg_val(13); // batch dim
constexpr uint32_t out_block_num_tiles = get_compile_time_arg_val(14); // number of tiles in out_block
constexpr bool untilize_out = get_compile_time_arg_val(15); // untilize output
constexpr bool in1_is_dram_interleaved = get_compile_time_arg_val(16); // in1 is in dram
constexpr uint32_t ring_size = num_blocks;

// Runtime args
@@ -223,6 +224,11 @@ void MAIN {
const uint32_t curr_ring_idx = (ring_idx + block) % ring_size;
uint32_t unpadded_in0_block_w = unpadded_in0_shard_widths_in_tiles[curr_ring_idx];

// Wait for in1 block
if constexpr (in1_is_dram_interleaved) {
cb_wait_front(in1_cb_id, in1_block_num_tiles);
}

const uint32_t input0_cb_id = block == 0 ? in0_cb_id : in2_cb_id;
bool last_out = block == (num_blocks - 1);
// Configure packer once for pack out without Bias
@@ -258,7 +264,8 @@ void MAIN {
#ifdef ENABLE_GLOBAL_CB
int in1_index_subblock_offset = 0;
#else
- int in1_index_subblock_offset = in1_block_num_tiles * (curr_ring_idx);
+ // This should always be 0 when reading in1 from DRAM
+ int in1_index_subblock_offset = in1_is_dram_interleaved ? 0 : in1_block_num_tiles * (curr_ring_idx);
#endif
for (uint32_t in1_subblock = 0; in1_subblock < in1_num_subblocks; in1_subblock++) {
tile_regs_acquire();
@@ -377,6 +384,9 @@ void MAIN {
#endif

cb_pop_front(input0_cb_id, in0_block_num_tiles);
if constexpr (in1_is_dram_interleaved) {
cb_pop_front(in1_cb_id, in1_block_num_tiles);
}
#ifdef ENABLE_GLOBAL_CB
curr_in1_block_index = next_in1_block_index;
UNPACK((update_local_cb_rd_ptr(in1_cb_id, next_in1_rd_ptr_addr)));
@@ -8,20 +8,58 @@
#include "hostdevcommon/common_values.hpp"
#include "remote_circular_buffer_api.h"
#include "debug/dprint.h"
#include "debug/dprint_tile.h"

template <bool DRAM, uint32_t tile_hw>
void read_block_from_dram(
uint32_t cb_id,
InterleavedAddrGenFast<DRAM, tile_hw> s1,
uint32_t tensor_width_in_tiles,
uint32_t block_w_idx,
uint32_t block_h_idx,
uint32_t block_w_t,
uint32_t block_h_t,
uint32_t tile_size_bytes) {
uint32_t l1_write_addr = get_write_ptr(cb_id);

// Horizontal idx + vertical idx * width = row major index
uint32_t block_tile_id = block_w_idx * block_w_t + (block_h_idx * block_h_t) * tensor_width_in_tiles;
for (uint32_t h = 0; h < block_h_t; ++h) {
uint32_t tile_id = block_tile_id + h * tensor_width_in_tiles;
for (uint32_t w = 0; w < block_w_t; ++w) {
noc_async_read_tile(tile_id + w, s1, l1_write_addr);
l1_write_addr += tile_size_bytes;
}
}
noc_async_read_barrier();
}

void kernel_main() {
// Compile time args
- constexpr uint32_t shard_width_in_tiles = get_compile_time_arg_val(0);
- constexpr uint32_t shard_height_in_tiles = get_compile_time_arg_val(1);
- constexpr uint32_t num_blocks = get_compile_time_arg_val(2);
- constexpr uint32_t in1_block_num_tiles = get_compile_time_arg_val(3);
- constexpr uint32_t batch = get_compile_time_arg_val(4);
+ constexpr const bool in1_is_dram_interleaved = get_compile_time_arg_val(0);
+ constexpr uint32_t in1_block_height_in_tiles = get_compile_time_arg_val(1);  // Padded block shape
+ constexpr uint32_t in1_block_width_in_tiles = get_compile_time_arg_val(2);
+ constexpr uint32_t in1_tensor_width_in_tiles = get_compile_time_arg_val(3);
+ constexpr uint32_t num_blocks = get_compile_time_arg_val(4);
+ constexpr uint32_t batch = get_compile_time_arg_val(5);

uint32_t rt_args_idx = 0;
const uint32_t in1_tensor_addr = get_arg_val<uint32_t>(rt_args_idx++);
const uint32_t ring_idx = get_arg_val<uint32_t>(rt_args_idx++);

constexpr uint32_t cb_id_in1 = tt::CBIndex::c_1;
constexpr uint32_t sync_cb = tt::CBIndex::c_3;
constexpr uint32_t sync_cb2 = tt::CBIndex::c_4;
constexpr uint32_t remote_cb_id = tt::CBIndex::c_31;
- constexpr uint32_t shard_size_in_tiles = shard_width_in_tiles * shard_height_in_tiles;

+ const uint32_t in1_block_num_tiles = in1_block_height_in_tiles * in1_block_width_in_tiles;

// Address setup
constexpr const uint32_t in1_tile_hw = get_tile_hw(cb_id_in1);
constexpr uint32_t in1_single_tile_size_bytes = get_tile_size(cb_id_in1);
constexpr DataFormat in1_data_format = get_dataformat(cb_id_in1);
const InterleavedAddrGenFast<in1_is_dram_interleaved, in1_tile_hw> s1 = {
.bank_base_address = in1_tensor_addr, .page_size = in1_single_tile_size_bytes, .data_format = in1_data_format};

for (uint32_t b = 0; b < batch; ++b) {
cb_reserve_back(sync_cb2, 1);
@@ -31,6 +69,24 @@ void kernel_main() {

cb_push_back(sync_cb2, 1);

if constexpr (in1_is_dram_interleaved) {
for (uint32_t block = 0; block < num_blocks; ++block) {
uint32_t block_idx = (ring_idx + block) % num_blocks;

cb_reserve_back(cb_id_in1, in1_block_num_tiles);
read_block_from_dram(
cb_id_in1,
s1,
in1_tensor_width_in_tiles,
ring_idx,
block_idx,
in1_block_width_in_tiles,
in1_block_height_in_tiles,
in1_single_tile_size_bytes);
cb_push_back(cb_id_in1, in1_block_num_tiles);
}
}

#ifdef ENABLE_GLOBAL_CB
cb_wait_front(sync_cb, 1);
experimental::remote_cb_pop_front(remote_cb_id, num_blocks);
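
As a side note, here is a small standalone sketch (plain Python, hypothetical tile counts) of the row-major tile indexing that `read_block_from_dram` above uses to pick out one in1 block per ring step:

```python
def block_tile_ids(tensor_width_in_tiles, block_w_t, block_h_t, block_w_idx, block_h_idx):
    """Row-major tile ids covered by one block_h_t x block_w_t block of in1."""
    start = block_w_idx * block_w_t + (block_h_idx * block_h_t) * tensor_width_in_tiles
    return [
        start + h * tensor_width_in_tiles + w
        for h in range(block_h_t)
        for w in range(block_w_t)
    ]

# Hypothetical example: in1 is 64 tiles wide, each core owns an 8-tile-wide
# column (block_w_t=8), and each ring step covers 8 tile rows (block_h_t=8).
# Core 0 (block_w_idx = ring_idx = 0) at ring step 2 starts at tile id 1024.
print(block_tile_ids(64, 8, 8, 0, 2)[:3])  # [1024, 1025, 1026]
```
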
29 changes: 26 additions & 3 deletions ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp
@@ -1563,13 +1563,17 @@ void Matmul::validate(

// Gather in0 specific validation
if (program_config.gather_in0) {
+ TT_FATAL(
+     program_config.num_global_cb_receivers > 0, "Num global CB receivers must be greater than 0.");
TT_FATAL(
input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED,
"Input tensor A must be width sharded when using gather_in0.");
TT_FATAL(
- input_tensor_b.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED,
- "Input tensor B must be width sharded when using gather_in0.");
- if (!this->global_cb.has_value()) {
+ input_tensor_b.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED ||
+     (input_tensor_b.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED &&
+      input_tensor_b.buffer()->buffer_type() == tt_metal::BufferType::DRAM),
+ "Input tensor B must be width sharded or DRAM interleaved when using gather_in0.");
+ if (!this->global_cb.has_value() && input_tensor_b.is_sharded()) {
TT_FATAL(
input_tensor_a.shard_spec().value().grid == input_tensor_b.shard_spec().value().grid,
"Input tensor A and B must be sharded on the same cores "
@@ -1581,6 +1585,25 @@ void Matmul::validate(
this->output_mem_config.shard_spec.has_value(),
"Output shard spec must be provided when using gather_in0.");

if (!input_tensor_b.is_sharded()) {
TT_FATAL(
!this->global_cb.has_value(),
"Global CB is not supported for DRAM_INTERLEAVED in1 when using gather_in0.");
TT_FATAL(
input_tensor_b.get_layout() == Layout::TILE,
"Input tensor B must be TILE_LAYOUT when DRAM_INTERLEAVED when using gather_in0.");
TT_FATAL(
input_tensor_a.shard_spec().value().grid == this->output_mem_config.shard_spec.value().grid,
"Input tensor A and output tensor must be sharded on the same cores when using gather_in0 "
"and in1 is DRAM_INTERLEAVED.");
}

if (!this->global_cb.has_value()) {
TT_FATAL(
program_config.num_global_cb_receivers == 1,
"Num global CB receivers must be 1 when global CB is not provided.");
}

TT_FATAL(!optional_bias.has_value(), "Bias is not supported when using gather_in0.");
}
if (program_config.mcast_in0 || program_config.gather_in0) {