diff --git a/tests/ttnn/unit_tests/operations/eltwise/test_binary_bcast.py b/tests/ttnn/unit_tests/operations/eltwise/test_binary_bcast.py index e6049715511..cb1248efbd0 100644 --- a/tests/ttnn/unit_tests/operations/eltwise/test_binary_bcast.py +++ b/tests/ttnn/unit_tests/operations/eltwise/test_binary_bcast.py @@ -339,9 +339,14 @@ def test_01_volume_tensors(device, a, b, c_golden, memory_config): ((torch.Size([5, 7, 64, 128]), torch.Size([5, 7, 64, 128])),), ) @pytest.mark.parametrize( - "sharded_config", [height_sharded_memory_config, width_sharded_memory_config, block_sharded_memory_config] + "sharded_config", + [ + height_sharded_memory_config, + width_sharded_memory_config, + block_sharded_memory_config, + ], ) -def test_binary_bcast_sharded(a_shape, b_shape, sharded_config, device): +def test_binary_sharded(a_shape, b_shape, sharded_config, device): input_combinations = ( (ttnn.DRAM_MEMORY_CONFIG, sharded_config), (sharded_config, ttnn.DRAM_MEMORY_CONFIG), @@ -407,10 +412,18 @@ def test_binary_sfpu_ops(input_shapes, dtype, ttnn_fn, device): b_pt = gen_func_with_cast_tt(partial(torch_random, low=-50, high=50, dtype=torch.float32), dtype)(b_shape) a_tt = ttnn.from_torch( - a_pt, dtype=dtype, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG + a_pt, + dtype=dtype, + device=device, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, ) b_tt = ttnn.from_torch( - b_pt, dtype=dtype, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG + b_pt, + dtype=dtype, + device=device, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, ) cq_id = 0 out_tt = ttnn_fn(a_tt, b_tt, queue_id=cq_id) @@ -468,13 +481,25 @@ def test_binary_sfpu_opt_out(input_shapes, dtype, ttnn_fn, device): out = gen_func_with_cast_tt(partial(torch_random, low=0, high=1, dtype=torch.float32), dtype)(out_shape) a_tt = ttnn.from_torch( - a_pt, dtype=dtype, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG + a_pt, + dtype=dtype, + device=device, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, ) b_tt = ttnn.from_torch( - b_pt, dtype=dtype, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG + b_pt, + dtype=dtype, + device=device, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, ) out_tt = ttnn.from_torch( - out, dtype=dtype, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG + out, + dtype=dtype, + device=device, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, ) cq_id = 0 ttnn_fn(a_tt, b_tt, queue_id=cq_id, output_tensor=out_tt) @@ -517,10 +542,18 @@ def test_binary_sfpu_bitwise_ops(input_shapes, dtype, ttnn_fn, device): b_pt = gen_func_with_cast_tt(partial(torch_random, low=0, high=31, dtype=torch.int32), dtype)(b_shape) a_tt = ttnn.from_torch( - a_pt, dtype=dtype, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG + a_pt, + dtype=dtype, + device=device, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, ) b_tt = ttnn.from_torch( - b_pt, dtype=dtype, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG + b_pt, + dtype=dtype, + device=device, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, ) cq_id = 0 out_tt = ttnn_fn(a_tt, b_tt, queue_id=cq_id) @@ -565,13 +598,25 @@ def test_bitwise_opt_output(input_shapes, dtype, ttnn_fn, device): out = gen_func_with_cast_tt(partial(torch_random, low=0, high=1, dtype=torch.int32), 
dtype)(out_shape) a_tt = ttnn.from_torch( - a_pt, dtype=dtype, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG + a_pt, + dtype=dtype, + device=device, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, ) b_tt = ttnn.from_torch( - b_pt, dtype=dtype, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG + b_pt, + dtype=dtype, + device=device, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, ) out_tt = ttnn.from_torch( - out, dtype=dtype, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG + out, + dtype=dtype, + device=device, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, ) cq_id = 0 ttnn_fn(a_tt, b_tt, queue_id=cq_id, output_tensor=out_tt) @@ -650,10 +695,16 @@ def test_inplace_binary_ops_with_tensor(a_shape, b_shape, ttnn_fn, activations, torch_input_tensor_b, input_tensor_b = rand_bf16_gen(b_shape, device, min=min, max=max) input_tensor_a = ttnn.from_torch( - torch_input_tensor_a, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG + torch_input_tensor_a, + device=device, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, ) input_tensor_b = ttnn.from_torch( - torch_input_tensor_b, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG + torch_input_tensor_b, + device=device, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, ) for golden_activation in golden_lhs: @@ -668,14 +719,21 @@ def test_inplace_binary_ops_with_tensor(a_shape, b_shape, ttnn_fn, activations, for golden_activation in golden_post: torch_output_tensor = golden_activation(torch_output_tensor).bfloat16() - ttnn_op(input_tensor_a, input_tensor_b, lhs_activations=lhs, rhs_activations=rhs, post_activations=post) + ttnn_op( + input_tensor_a, + input_tensor_b, + lhs_activations=lhs, + rhs_activations=rhs, + post_activations=post, + ) output_tensor = ttnn.to_torch(input_tensor_a) assert output_tensor.shape == torch_output_tensor.shape def compare(output_tensor, torch_output_tensor): imprecise_cases = { *parameters( - {"logaddexp2_"}, {exp_floor_lhs_exp_rhs, no_activations, sin_rhs, log_lhs_sqrt_abs_post, square_lhs} + {"logaddexp2_"}, + {exp_floor_lhs_exp_rhs, no_activations, sin_rhs, log_lhs_sqrt_abs_post, square_lhs}, ), *parameters({"bias_gelu_"}, {no_activations, sin_rhs, square_lhs}), *parameters({"gt_", "lte_", "gte_", "lt_"}, {sin_rhs, square_lhs}), @@ -736,7 +794,11 @@ def test_inplace_bf4b_bf8b(a_shape, b_shape, input_dtype, ttnn_fn, device): assert output_tensor.shape == torch_output_tensor.shape def compare(output_tensor, torch_output_tensor, ttnn_fn, input_dtype): - imprecise_cases = {"add_": {ttnn.bfloat4_b}, "sub_": {ttnn.bfloat4_b}, "mul_": {ttnn.bfloat4_b}} + imprecise_cases = { + "add_": {ttnn.bfloat4_b}, + "sub_": {ttnn.bfloat4_b}, + "mul_": {ttnn.bfloat4_b}, + } if ttnn_fn in imprecise_cases and input_dtype in imprecise_cases[ttnn_fn]: return ttnn.pearson_correlation_coefficient(torch_output_tensor, output_tensor) >= 0.97 else: @@ -797,10 +859,9 @@ def test_inplace_binary_ops_fp32(input_shapes, ttnn_fn, device): @pytest.mark.parametrize( "a_shape, b_shape", ( - (torch.Size([1, 3, 128, 1]), torch.Size([5, 3, 128, 64])), - (torch.Size([1, 1, 1, 1]), torch.Size([5, 3, 32, 32])), - (torch.Size([5, 1, 1, 64]), torch.Size([1, 3, 128, 1])), - (torch.Size([16, 1]), torch.Size([1, 1, 32])), + (torch.Size([1, 1, 31, 32]), torch.Size([5, 3, 32, 32])), + (torch.Size([5, 2, 64, 1]), torch.Size([1, 3, 1, 
128])), + (torch.Size([5, 1, 1, 64]), torch.Size([2, 3, 128, 1])), ), ) @pytest.mark.parametrize( @@ -814,9 +875,7 @@ def test_inplace_binary_ops_invalid_bcast(a_shape, b_shape, ttnn_fn, device): _, input_tensor_a = rand_bf16_gen(a_shape, device) _, input_tensor_b = rand_bf16_gen(b_shape, device) - with pytest.raises( - RuntimeError, match=r"Shape of Output tensor.+ provided does not match the broadcasted output shape .+" - ): + with pytest.raises(RuntimeError): cq_id = 0 ttnn_op(input_tensor_a, input_tensor_b, queue_id=cq_id) @@ -898,3 +957,43 @@ def test_binary_opt_output_invalid_bcast(a_shape, b_shape, out_shape, ttnn_fn, d ): cq_id = 0 ttnn_op(input_tensor_a, input_tensor_b, queue_id=cq_id, output_tensor=out_tt) + + +def test_binary_sharded_bcast_w(device): + a_shape = torch.Size([5, 7, 2 * 32, 4 * 32]) + b_shape = torch.Size([5, 7, 2 * 32, 1]) + + a_sharded_config = ttnn.create_sharded_memory_config( + [10 * 32, 4 * 32], + core_grid=ttnn.CoreRangeSet({ttnn.CoreRange((0, 0), (0, 6))}), + strategy=ttnn.ShardStrategy.HEIGHT, + orientation=ttnn.ShardOrientation.ROW_MAJOR, + use_height_and_width_as_shard_shape=True, + ) + + b_sharded_config = ttnn.create_sharded_memory_config( + [10 * 32, 32], + core_grid=ttnn.CoreRangeSet({ttnn.CoreRange((0, 0), (0, 6))}), + strategy=ttnn.ShardStrategy.HEIGHT, + orientation=ttnn.ShardOrientation.ROW_MAJOR, + use_height_and_width_as_shard_shape=True, + ) + + input_combinations = ( + (ttnn.DRAM_MEMORY_CONFIG, b_sharded_config), + (a_sharded_config, ttnn.DRAM_MEMORY_CONFIG), + (a_sharded_config, b_sharded_config), + ) + + for a_config, b_config in input_combinations: + a_pt, a_tt = rand_bf16_gen(a_shape, device, memory_config=a_config) + b_pt, b_tt = rand_bf16_gen(b_shape, device, memory_config=b_config) + + out_pt = torch.add(a_pt, b_pt) + out_tt_sharded = ttnn.experimental.add(a_tt, b_tt, memory_config=ttnn.DRAM_MEMORY_CONFIG) + out_tt_sharded = ttnn.to_torch(out_tt_sharded) + torch.testing.assert_close(out_tt_sharded, out_pt) + + out_tt_sharded = ttnn.experimental.add(a_tt, b_tt, memory_config=a_sharded_config) + out_tt_sharded = ttnn.to_torch(out_tt_sharded) + torch.testing.assert_close(out_tt_sharded, out_pt) diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.cpp index db994629b2d..6dfdcc53a72 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.cpp @@ -119,6 +119,35 @@ DataType BinaryNgDeviceOperation::operation_attributes_t::get_dtype() const { return this->dtype.value_or(this->input_dtype); } +void validate_sharding( + TensorMemoryLayout memory_layout_x, + const ShardSpec& shard_spec_x, + TensorMemoryLayout memory_layout_y, + const ShardSpec& shard_spec_y, + SubtileBroadcastType subtile_broadcast_type) { + TT_FATAL(memory_layout_x == memory_layout_y, "Operands to eltwise binary need to have the same memory layout"); + + switch (subtile_broadcast_type) { + case SubtileBroadcastType::NONE: + TT_FATAL(shard_spec_x == shard_spec_y, "Operands to eltwise binary need to have the same shard spec"); + break; + case SubtileBroadcastType::COL_A: + case SubtileBroadcastType::COL_B: + TT_FATAL( + memory_layout_x == TensorMemoryLayout::HEIGHT_SHARDED, + "Operands to eltwise binary must be height sharded when broadcasting on W"); + TT_FATAL( + memory_layout_y == TensorMemoryLayout::HEIGHT_SHARDED, + "Operands to
eltwise binary must be height sharded when broadcasting on W"); + TT_FATAL( + shard_spec_x.shape[0] == shard_spec_y.shape[0], + "Operands to eltwise binary need to have the same " "shard height when broadcasting on W"); + break; + default: TT_THROW("Invalid subtile broadcast type for sharding validation"); + } +} + void BinaryNgDeviceOperation::validate_on_program_cache_miss( const operation_attributes_t& attributes, const tensor_args_t& tensor_args) { // We don't support sharding for now @@ -172,28 +201,28 @@ void BinaryNgDeviceOperation::validate_on_program_cache_miss( // Validate that all shard specs match if (tensor_a_sharded) { if (tensor_b_sharded) { - TT_FATAL( - input_tensor_a.memory_config().memory_layout == input_tensor_b->memory_config().memory_layout, - "Operands to eltwise binary need to have the same memory layout"); - TT_FATAL( - input_tensor_a.shard_spec().value() == input_tensor_b->shard_spec().value(), - "Operands to eltwise binary need to have the same shard spec"); + validate_sharding( + input_tensor_a.memory_config().memory_layout, + *input_tensor_a.shard_spec(), + input_tensor_b->memory_config().memory_layout, + *input_tensor_b->shard_spec(), + attributes.subtile_broadcast_type); } if (output_sharded) { - TT_FATAL( - input_tensor_a.memory_config().memory_layout == attributes.memory_config.memory_layout, - "LHS operand and output to eltwise binary need to have the same memory layout"); - TT_FATAL( - input_tensor_a.shard_spec().value() == attributes.memory_config.shard_spec.value(), - "LHS operand and output to eltwise binary need to have the same shard spec"); + validate_sharding( + input_tensor_a.memory_config().memory_layout, + *input_tensor_a.shard_spec(), + attributes.memory_config.memory_layout, + *attributes.memory_config.shard_spec, + attributes.subtile_broadcast_type); } } else if (tensor_b_sharded and output_sharded) { - TT_FATAL( - input_tensor_b->memory_config().memory_layout == attributes.memory_config.memory_layout, - "RHS operand and output to eltwise binary need to have the same memory layout"); - TT_FATAL( - input_tensor_b->shard_spec().value() == attributes.memory_config.shard_spec.value(), - "RHS operand and output to eltwise binary need to have the same shard spec"); + validate_sharding( + input_tensor_b->memory_config().memory_layout, + *input_tensor_b->shard_spec(), + attributes.memory_config.memory_layout, + *attributes.memory_config.shard_spec, + attributes.subtile_broadcast_type); } } @@ -227,10 +256,10 @@ void BinaryNgDeviceOperation::validate_on_program_cache_hit( a_dim, b_dim); - if (has_shard_spec) { + if (has_shard_spec and i != -1) { TT_FATAL( a_dim == b_dim, - "Cannot broadcast sharded tensors, violation for rank {}, dim a: {}, dim b: {}", + "Cannot broadcast sharded tensors on dims other than W, violation for rank {}, dim a: {}, dim b: {}", i, a_dim, b_dim); @@ -284,17 +313,8 @@ BinaryNgDeviceOperation::spec_return_value_t BinaryNgDeviceOperation::compute_ou } if (attributes.memory_config.is_sharded()) { - ShardSpec shard_spec{CoreRangeSet(), {0, 0}}; - if (input_tensor_a.memory_config().is_sharded()) { - shard_spec = input_tensor_a.shard_spec().value(); - } else if (tensor_b.has_value() and tensor_b->memory_config().is_sharded()) { - shard_spec = tensor_b->shard_spec().value(); - } else { - shard_spec = attributes.memory_config.shard_spec.value(); - } - auto memory_config = attributes.memory_config; - memory_config.shard_spec = shard_spec; - return TensorSpec(output_shape, TensorLayout(attributes.get_dtype(), PageConfig(Layout::TILE),
memory_config)); + return TensorSpec( + output_shape, TensorLayout(attributes.get_dtype(), PageConfig(Layout::TILE), attributes.memory_config)); } return TensorSpec( @@ -381,8 +401,7 @@ BinaryNgDeviceOperation::invoke( {rhs_activations.begin(), rhs_activations.end()}, {post_activations.begin(), post_activations.end()}, std::nullopt, - memory_config.value_or( - output_tensor.has_value() ? output_tensor->memory_config() : input_tensor_a.memory_config()), + memory_config.value_or(output_tensor.has_value() ? output_tensor->memory_config() : MemoryConfig{}), input_tensor_a.get_dtype(), output_dtype, get_worker_grid(input_tensor_a, &input_tensor_b, output_tensor), @@ -413,8 +432,7 @@ BinaryNgDeviceOperation::invoke( {rhs_activations.begin(), rhs_activations.end()}, {post_activations.begin(), post_activations.end()}, scalar, - memory_config.value_or( - output_tensor.has_value() ? output_tensor->memory_config() : input_tensor_a.memory_config()), + memory_config.value_or(output_tensor.has_value() ? output_tensor->memory_config() : MemoryConfig{}), input_tensor_a.get_dtype(), output_dtype, get_worker_grid(input_tensor_a, nullptr, output_tensor), diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_program_factory.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_program_factory.cpp index ab2df755dd4..92bb3c8ea55 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_program_factory.cpp @@ -39,22 +39,60 @@ std::tuple calculate_compute_kernel_args( } } -std::tuple, TensorMemoryLayout> get_shard_spec( - const Tensor& a, const std::optional& b, const Tensor& c) { +struct AllShardSpecs { + ShardSpec a_shard_spec; + ShardSpec b_shard_spec; + ShardSpec c_shard_spec; +}; + +ShardSpec adjust_to_shape( + const ShardSpec& shard_spec, const ttnn::Shape& from_shape, const ttnn::Shape& to_shape) { + auto ret = shard_spec; + + ret.shape[0] = (ret.shape[0] * to_shape[-2]) / from_shape[-2]; + ret.shape[1] = (ret.shape[1] * to_shape[-1]) / from_shape[-1]; + + return ret; +} + +TensorMemoryLayout get_memory_layout(const Tensor& a, const std::optional& b, const Tensor& c) { if (a.memory_config().is_sharded()) { - return {a.shard_spec().value(), a.memory_config().memory_layout}; - } else if (b.has_value() && b->memory_config().is_sharded()) { - return {b->shard_spec().value(), b->memory_config().memory_layout}; - } else if (c.memory_config().is_sharded()) { - return {c.shard_spec().value(), c.memory_config().memory_layout}; + return a.memory_config().memory_layout; + } + if (b.has_value() && b->memory_config().is_sharded()) { + return b->memory_config().memory_layout; } + if (c.memory_config().is_sharded()) { + return c.memory_config().memory_layout; + } + return TensorMemoryLayout::INTERLEAVED; +} - return {std::nullopt, TensorMemoryLayout::INTERLEAVED}; +std::optional get_shard_specs(const Tensor& a, const std::optional& b, const Tensor& c) { + bool a_sharded = a.memory_config().is_sharded(); + bool b_sharded = b.has_value() && b->memory_config().is_sharded(); + bool c_sharded = c.memory_config().is_sharded(); + + if (!a_sharded && !b_sharded && !c_sharded) { + return std::nullopt; + } + + auto a_shape = a.padded_shape(); + auto b_shape = b.has_value() ? b->padded_shape() : ttnn::Shape{1, 1}; + auto c_shape = c.padded_shape(); + + ShardSpec c_shard_spec = c_sharded ? *c.shard_spec() + : a_sharded ? 
adjust_to_shape(*a.shard_spec(), a_shape, c_shape) + : adjust_to_shape(*b->shard_spec(), b_shape, c_shape); + + return AllShardSpecs{ + a_sharded ? *a.shard_spec() : adjust_to_shape(c_shard_spec, c_shape, a_shape), + b_sharded ? *b->shard_spec() : adjust_to_shape(c_shard_spec, c_shape, b_shape), + c_shard_spec}; } -uint32_t get_shards_per_width( - const CoreRangeSet& all_cores, TensorMemoryLayout memory_layout, ShardOrientation orientation) { - auto num_cores = all_cores.num_cores(); +uint32_t get_shards_per_width(const ShardSpec& shard_spec, TensorMemoryLayout memory_layout) { + auto num_cores = shard_spec.grid.num_cores(); if (memory_layout == tt::tt_metal::TensorMemoryLayout::HEIGHT_SHARDED) { return 1; } @@ -63,12 +101,55 @@ uint32_t get_shards_per_width( return num_cores; } - const auto& bbox = all_cores.bounding_box(); + const auto& bbox = shard_spec.grid.bounding_box(); const auto& start = bbox.start_coord; const auto& end = bbox.end_coord; - return (orientation == ShardOrientation::ROW_MAJOR ? end.x - start.x : end.y - start.y) + 1; + return (shard_spec.orientation == ShardOrientation::ROW_MAJOR ? end.x - start.x : end.y - start.y) + 1; } +class ShardShapeGenerator { + CoreCoord end_core; + bool row_major; + std::array shard_shape; + std::array last_shard_shape; + +public: + ShardShapeGenerator() = default; + + ShardShapeGenerator(const ShardSpec& shard_spec, const Tensor& tensor) : + end_core(shard_spec.grid.ranges().begin()->end_coord), + row_major(shard_spec.orientation == ShardOrientation::ROW_MAJOR) { + auto tile_height = tensor.tensor_spec().tile().get_height(); + auto tile_width = tensor.tensor_spec().tile().get_width(); + + shard_shape = { + tt::round_up(shard_spec.shape[0], tile_height) / tile_height, + tt::round_up(shard_spec.shape[1], tile_width) / tile_width}; + + const auto [N, C, Ht, Wt] = get_shape_dims(tensor); + const auto unrolled_Ht = N * C * Ht; + last_shard_shape = { + shard_shape[0] - (tt::round_up(unrolled_Ht, shard_shape[0]) - unrolled_Ht), + shard_shape[1] - (tt::round_up(Wt, shard_shape[1]) - Wt), + }; + } + + std::array operator()(CoreCoord core) const { + const unsigned majorDim = row_major ? 1 : 0; + const unsigned minorDim = row_major ? 0 : 1; + + auto current_shape = shard_shape; + if (core.x == end_core.x) { + current_shape[majorDim] = last_shard_shape[majorDim]; + } + if (core.y == end_core.y) { + current_shape[minorDim] = last_shard_shape[minorDim]; + } + + return current_shape; + } +}; + template void set_or_update_runtime_arguments( Program& program, @@ -82,18 +163,15 @@ void set_or_update_runtime_arguments( const auto& a = tensor_args.input_tensor_a; const auto& b = tensor_args.input_tensor_b; - const auto ashape = a.padded_shape(); - const auto bshape = b.has_value() ? b->padded_shape() : Shape{1, 1}; - const auto cshape = c.padded_shape(); - const auto [aN, aC, aHt, aWt] = get_shape_dims(a); const auto [bN, bC, bHt, bWt] = b.has_value() ? get_shape_dims(*b) : std::tuple{1u, 1u, 1u, 1u}; const auto [cN, cC, cHt, cWt] = get_shape_dims(c); const uint32_t cHt_unrolled = cN * cC * cHt; bool row_major = true; - const auto [shard_spec, memory_layout] = get_shard_spec(a, b, c); - const bool has_sharding = shard_spec.has_value(); + const auto shard_specs = get_shard_specs(a, b, c); + const bool has_sharding = shard_specs.has_value(); + auto grid = has_sharding ? 
shard_specs->a_shard_spec.grid : CoreRangeSet{}; // zero_start_grid is a flag to indicate that we are using a single rectangular grid that starts at (0, 0) // as well as having the sharded tensors (if any) start at (0, 0) @@ -106,7 +184,7 @@ void set_or_update_runtime_arguments( const auto& cr = *all_device_cores.ranges().begin(); if (cr.start_coord.x == 0 && cr.start_coord.y == 0) { if (has_sharding) { - const auto& shard_start_coord = shard_spec->grid.ranges()[0].start_coord; + const auto& shard_start_coord = grid.ranges()[0].start_coord; if (shard_start_coord.x == 0 && shard_start_coord.y == 0) { zero_start_grid = true; compute_with_storage_grid = CoreCoord(cr.end_coord.x + 1, cr.end_coord.y + 1); @@ -123,26 +201,28 @@ void set_or_update_runtime_arguments( uint32_t num_tiles_per_core_group_1{}, num_tiles_per_core_group_2{}; CoreRangeSet all_cores, core_group_1, core_group_2; uint32_t num_cores; - CoreCoord end_core; std::vector cores; const uint32_t tile_height = c.tensor_spec().tile().get_height(); const uint32_t tile_width = c.tensor_spec().tile().get_width(); const uint32_t tile_hw = tile_height * tile_width; - const uint32_t num_output_tiles = c.volume() / tile_hw; + const uint32_t c_num_tiles = c.volume() / tile_hw; + uint32_t c_shard_height, c_shard_width, num_shards_per_width; - uint32_t shard_height = cHt_unrolled, shard_width = cWt; - uint32_t last_shard_height = shard_height, last_shard_width = shard_width; + ShardShapeGenerator a_shard_shape_generator; + ShardShapeGenerator b_shard_shape_generator; + ShardShapeGenerator c_shard_shape_generator; if (has_sharding) { - core_group_1 = shard_spec->grid; - num_tiles_per_core_group_1 = shard_spec->numel() / tile_hw; - row_major = shard_spec->orientation == ShardOrientation::ROW_MAJOR; - shard_height = shard_spec->shape[0] / tile_height; - shard_width = shard_spec->shape[1] / tile_width; - end_core = (*shard_spec->grid.ranges().begin()).end_coord; - last_shard_height = shard_height - (tt::round_up(cHt_unrolled, shard_height) - cHt_unrolled); - last_shard_width = shard_width - (tt::round_up(cWt, shard_width) - cWt); + core_group_1 = grid; + a_shard_shape_generator = ShardShapeGenerator(shard_specs->a_shard_spec, a); + if (b.has_value()) { + b_shard_shape_generator = ShardShapeGenerator(shard_specs->b_shard_spec, *b); + } + c_shard_shape_generator = ShardShapeGenerator(shard_specs->c_shard_spec, c); + c_shard_height = shard_specs->c_shard_spec.shape[0] / tile_height; + c_shard_width = shard_specs->c_shard_spec.shape[1] / tile_width; + num_shards_per_width = get_shards_per_width(shard_specs->c_shard_spec, get_memory_layout(a, b, c)); if (zero_start_grid) { auto bbox = core_group_1.bounding_box(); @@ -158,65 +238,52 @@ void set_or_update_runtime_arguments( } else if (zero_start_grid) { std::tie( num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2) = - tt::tt_metal::split_work_to_cores(compute_with_storage_grid, num_output_tiles, row_major); + tt::tt_metal::split_work_to_cores(compute_with_storage_grid, c_num_tiles, row_major); cores = grid_to_cores(num_cores_total, compute_with_storage_grid.x, compute_with_storage_grid.y, row_major); } else { std::tie( num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2) = - tt::tt_metal::split_work_to_cores(all_device_cores, num_output_tiles, row_major); + tt::tt_metal::split_work_to_cores(all_device_cores, c_num_tiles, row_major); cores = corerange_to_cores(all_device_cores, {}, row_major); } - 
auto num_shards_per_width = - has_sharding ? get_shards_per_width(shard_spec->grid, memory_layout, shard_spec->orientation) : 0u; - for (uint32_t i = 0, start_tile_id = 0; i < num_cores_total; i++) { const auto& core = cores[i]; - uint32_t num_tiles_per_core = 0; + uint32_t a_num_tiles = 0; + uint32_t b_num_tiles = 0; + uint32_t c_num_tiles = 0; if (core_group_1.contains(core)) { - num_tiles_per_core = num_tiles_per_core_group_1; + c_num_tiles = num_tiles_per_core_group_1; } else if (core_group_2.contains(core)) { - num_tiles_per_core = num_tiles_per_core_group_2; + c_num_tiles = num_tiles_per_core_group_2; } else { - handle_args(program, reader_kernel_id, core, std::array{0}); - handle_args(program, writer_kernel_id, core, std::array{0}); + handle_args(program, reader_kernel_id, core, std::array{0}); + handle_args(program, writer_kernel_id, core, std::array{0}); handle_args(program, compute_kernel_id, core, std::array{0}); continue; } - uint32_t start_id = 0; - uint32_t current_shard_height = 0; - uint32_t current_shard_width = 0; + uint32_t c_start_id = 0; + uint32_t c_current_shard_width = 0; if (has_sharding) { - current_shard_height = shard_height; - current_shard_width = shard_width; - if (row_major) { - if (core.x == end_core.x) { - current_shard_width = last_shard_width; - } - if (core.y == end_core.y) { - current_shard_height = last_shard_height; - } - } else { - if (core.y == end_core.y) { - current_shard_width = last_shard_width; - } - if (core.x == end_core.x) { - current_shard_height = last_shard_height; - } - } - start_id = (i / num_shards_per_width) * (shard_height * cWt) + (i % num_shards_per_width) * shard_width; - num_tiles_per_core = current_shard_height * current_shard_width; + auto c_shard_shape = c_shard_shape_generator(core); + c_num_tiles = c_shard_shape[0] * c_shard_shape[1]; + c_current_shard_width = c_shard_shape[1]; + auto a_shard_shape = a_shard_shape_generator(core); + a_num_tiles = a_shard_shape[0] * a_shard_shape[1]; + c_start_id = + (i / num_shards_per_width) * (c_shard_height * cWt) + (i % num_shards_per_width) * c_shard_width; } else { - start_id = start_tile_id; + c_start_id = start_tile_id; } std::array reader_runtime_args = { a.buffer()->address(), - start_id, - num_tiles_per_core, - current_shard_width, + c_start_id, + a_num_tiles, + c_num_tiles, + c_current_shard_width, aHt * aWt * aC * (aN > 1), aHt * aWt * (aC > 1), cN, @@ -226,12 +293,17 @@ void set_or_update_runtime_arguments( handle_args(program, reader_kernel_id, core, reader_runtime_args); if (b.has_value()) { + if (has_sharding) { + auto b_shard_shape = b_shard_shape_generator(core); + b_num_tiles = b_shard_shape[0] * b_shard_shape[1]; + } std::array writer_runtime_args = { b->buffer()->address(), c.buffer()->address(), - start_id, - num_tiles_per_core, - current_shard_width, + c_start_id, + b_num_tiles, + c_num_tiles, + c_current_shard_width, bHt * bWt * bC * (bN > 1), bHt * bWt * (bC > 1), cN, @@ -241,8 +313,8 @@ void set_or_update_runtime_arguments( handle_args(program, writer_kernel_id, core, writer_runtime_args); auto [freq, counter] = - calculate_compute_kernel_args(operation_attributes.subtile_broadcast_type, start_id, cHt, cWt); - std::array compute_runtime_args = {num_tiles_per_core, freq, counter}; + calculate_compute_kernel_args(operation_attributes.subtile_broadcast_type, c_start_id, cHt, cWt); + std::array compute_runtime_args = {c_num_tiles, freq, counter}; handle_args(program, compute_kernel_id, core, compute_runtime_args); } else { const auto scalar = 
*operation_attributes.scalar; @@ -253,22 +325,23 @@ void set_or_update_runtime_arguments( std::array writer_runtime_args = { packed_scalar, c.buffer()->address(), - start_id, - num_tiles_per_core, - current_shard_width, + c_start_id, + c_num_tiles, + c_current_shard_width, cN, cC, cHt, cWt, 0u, + 0u, 0u}; handle_args(program, writer_kernel_id, core, writer_runtime_args); - std::array compute_runtime_args = {num_tiles_per_core, 0u, 0u}; + std::array compute_runtime_args = {c_num_tiles, 0u, 0u}; handle_args(program, compute_kernel_id, core, compute_runtime_args); } - start_tile_id += num_tiles_per_core; + start_tile_id += c_num_tiles; } } @@ -290,9 +363,13 @@ BinaryNgDeviceOperation::ProgramFactory::cached_program_t BinaryNgDeviceOperatio auto program = CreateProgram(); auto* device = a.device(); - auto [shard_spec, memory_layout] = CMAKE_UNIQUE_NAMESPACE::get_shard_spec(a, b, c); - const bool has_sharding = shard_spec.has_value(); - uint32_t num_tiles_per_shard = has_sharding ? shard_spec->numel() / a.tensor_spec().tile().get_tile_hw() : 0; + const auto shard_specs = CMAKE_UNIQUE_NAMESPACE::get_shard_specs(a, b, c); + const bool has_sharding = shard_specs.has_value(); + + auto tile_hw = c.tensor_spec().tile().get_tile_hw(); + uint32_t a_num_tiles_per_shard = has_sharding ? shard_specs->a_shard_spec.numel() / tile_hw : 0; + uint32_t b_num_tiles_per_shard = has_sharding ? shard_specs->b_shard_spec.numel() / tile_hw : 0; + uint32_t c_num_tiles_per_shard = has_sharding ? shard_specs->c_shard_spec.numel() / tile_hw : 0; auto a_data_format = datatype_to_dataformat_converter(a.get_dtype()); auto b_data_format = b.has_value() ? datatype_to_dataformat_converter(b->get_dtype()) @@ -304,8 +381,6 @@ BinaryNgDeviceOperation::ProgramFactory::cached_program_t BinaryNgDeviceOperatio uint32_t b_single_tile_size = tt_metal::detail::TileSize(b_data_format); uint32_t c_single_tile_size = tt_metal::detail::TileSize(c_data_format); - uint32_t num_output_tiles = c.volume() / c.tensor_spec().tile().get_tile_hw(); - // we parallelize the computation across the output tiles constexpr bool row_major = true; const auto& all_device_cores = operation_attributes.worker_grid; @@ -364,7 +439,7 @@ BinaryNgDeviceOperation::ProgramFactory::cached_program_t BinaryNgDeviceOperatio program, all_device_cores, a_single_tile_size, - a_sharded ? num_tiles_per_shard : 2, + a_sharded ? a_num_tiles_per_shard : 2, a_data_format, a_sharded ? a_buffer : nullptr); @@ -383,7 +458,7 @@ BinaryNgDeviceOperation::ProgramFactory::cached_program_t BinaryNgDeviceOperatio program, all_device_cores, b_single_tile_size, - b_buffer == nullptr ? 1 : (b_sharded ? num_tiles_per_shard : 2), + b_buffer == nullptr ? 1 : (b_sharded ? b_num_tiles_per_shard : 2), b_data_format, b_sharded ? b_buffer : nullptr); @@ -401,7 +476,7 @@ BinaryNgDeviceOperation::ProgramFactory::cached_program_t BinaryNgDeviceOperatio program, all_device_cores, c_single_tile_size, - c_sharded ? num_tiles_per_shard : 2, + c_sharded ? c_num_tiles_per_shard : 2, c_data_format, c_sharded ? 
c_buffer : nullptr); diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/reader_interleaved_col_bcast.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/reader_interleaved_col_bcast.cpp index 9bad94b52d6..4a9021000db 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/reader_interleaved_col_bcast.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/reader_interleaved_col_bcast.cpp @@ -10,26 +10,21 @@ void kernel_main() { const uint32_t src_addr = get_arg_val(0); const uint32_t start_tile_id = get_arg_val(1); - const uint32_t num_tiles = get_arg_val(2); - const uint32_t shard_width = get_arg_val(3); - const uint32_t n_stride = get_arg_val(4); - const uint32_t c_stride = get_arg_val(5); - const uint32_t N = get_arg_val(6); - const uint32_t C = get_arg_val(7); - const uint32_t Ht = get_arg_val(8); - const uint32_t Wt = get_arg_val(9); + const uint32_t src_num_tiles = get_arg_val(2); + const uint32_t dst_num_tiles = get_arg_val(3); + const uint32_t dst_shard_width = get_arg_val(4); + const uint32_t n_stride = get_arg_val(5); + const uint32_t c_stride = get_arg_val(6); + const uint32_t N = get_arg_val(7); + const uint32_t C = get_arg_val(8); + const uint32_t Ht = get_arg_val(9); + const uint32_t Wt = get_arg_val(10); const uint32_t HtWt = Ht * Wt; constexpr bool src_is_dram = get_compile_time_arg_val(0) == 1; - constexpr auto cb_id_src = tt::CBIndex::c_0; constexpr uint32_t onetile = 1; - const uint32_t src_tile_bytes = get_tile_size(cb_id_src); - const DataFormat src_data_format = get_dataformat(cb_id_src); - const InterleavedAddrGenFast src = { - .bank_base_address = src_addr, .page_size = src_tile_bytes, .data_format = src_data_format}; - uint32_t tiles_per_batch = HtWt * C; uint32_t start_n = start_tile_id / tiles_per_batch; uint32_t start_remaining = start_tile_id % tiles_per_batch; @@ -38,25 +33,39 @@ void kernel_main() { uint32_t start_th = start_t / Wt; uint32_t start_tw = start_t % Wt; + constexpr auto cb_id_src = tt::CBIndex::c_0; +#if !SRC_SHARDED + const uint32_t src_tile_bytes = get_tile_size(cb_id_src); + const DataFormat src_data_format = get_dataformat(cb_id_src); + const InterleavedAddrGenFast src = { + .bank_base_address = src_addr, .page_size = src_tile_bytes, .data_format = src_data_format}; + // this is the INPUT tile offset uint32_t tile_offset = start_n * n_stride + start_c * c_stride; uint32_t next_batch_shift = n_stride - c_stride * C; +#endif uint32_t num_tiles_read = 0; - for (uint32_t n = start_n; n < N && num_tiles_read < num_tiles; ++n, start_c = 0) { - for (uint32_t c = start_c; c < C && num_tiles_read < num_tiles; ++c, start_th = 0) { - for (uint32_t th = start_th; th < Ht && num_tiles_read < num_tiles; ++th, start_tw = 0) { + for (uint32_t n = start_n; n < N && num_tiles_read < dst_num_tiles; ++n, start_c = 0) { + for (uint32_t c = start_c; c < C && num_tiles_read < dst_num_tiles; ++c, start_th = 0) { + for (uint32_t th = start_th; th < Ht && num_tiles_read < dst_num_tiles; ++th, start_tw = 0) { cb_reserve_back(cb_id_src, onetile); +#if !SRC_SHARDED uint32_t l1_write_addr = get_write_ptr(cb_id_src); noc_async_read_tile(tile_offset + th, src, l1_write_addr); noc_async_read_barrier(); +#endif FILL_TILE_WITH_FIRST_COLUMN(cb_id_src); cb_push_back(cb_id_src, onetile); num_tiles_read += Wt - start_tw; } +#if !SRC_SHARDED tile_offset += c_stride; +#endif } +#if !SRC_SHARDED tile_offset += next_batch_shift; +#endif } } diff --git 
a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/reader_interleaved_no_bcast.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/reader_interleaved_no_bcast.cpp index 54ce4655c5a..e73be2e7c6b 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/reader_interleaved_no_bcast.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/reader_interleaved_no_bcast.cpp @@ -9,20 +9,21 @@ void kernel_main() { const uint32_t src_addr = get_arg_val(0); const uint32_t start_tile_id = get_arg_val(1); - const uint32_t num_tiles = get_arg_val(2); - const uint32_t shard_width = get_arg_val(3); - const uint32_t n_stride = get_arg_val(4); - const uint32_t c_stride = get_arg_val(5); - const uint32_t N = get_arg_val(6); - const uint32_t C = get_arg_val(7); - const uint32_t Ht = get_arg_val(8); - const uint32_t Wt = get_arg_val(9); + const uint32_t src_num_tiles = get_arg_val(2); + const uint32_t dst_num_tiles = get_arg_val(3); + const uint32_t dst_shard_width = get_arg_val(4); + const uint32_t n_stride = get_arg_val(5); + const uint32_t c_stride = get_arg_val(6); + const uint32_t N = get_arg_val(7); + const uint32_t C = get_arg_val(8); + const uint32_t Ht = get_arg_val(9); + const uint32_t Wt = get_arg_val(10); constexpr auto cb_id_src = tt::CBIndex::c_0; #if SRC_SHARDED - cb_reserve_back(cb_id_src, num_tiles); - cb_push_back(cb_id_src, num_tiles); + cb_reserve_back(cb_id_src, src_num_tiles); + cb_push_back(cb_id_src, src_num_tiles); #else constexpr uint32_t onetile = 1; constexpr bool src_is_dram = get_compile_time_arg_val(0) == 1; @@ -40,7 +41,7 @@ void kernel_main() { uint32_t start_t = start_remaining % HtWt; uint32_t start_th = start_t / Wt; uint32_t start_tw = start_t % Wt; - uint32_t end_tw = has_sharding ? start_tw + shard_width : Wt; + uint32_t end_tw = has_sharding ? 
start_tw + dst_shard_width : Wt; // this is the INPUT tile offset uint32_t tile_offset = start_n * n_stride + start_c * c_stride + start_th * Wt; @@ -48,10 +49,10 @@ void kernel_main() { uint32_t next_batch_shift = n_stride - c_stride * C; uint32_t num_tiles_read = 0; - for (uint32_t n = start_n; n < N && num_tiles_read < num_tiles; ++n, start_c = 0) { - for (uint32_t c = start_c; c < C && num_tiles_read < num_tiles; ++c, start_th = 0) { - for (uint32_t th = start_th; th < Ht && num_tiles_read < num_tiles; ++th, tile_offset += Wt) { - for (uint32_t tw = start_tw; tw < end_tw && num_tiles_read < num_tiles; ++tw, ++num_tiles_read) { + for (uint32_t n = start_n; n < N && num_tiles_read < dst_num_tiles; ++n, start_c = 0) { + for (uint32_t c = start_c; c < C && num_tiles_read < dst_num_tiles; ++c, start_th = 0) { + for (uint32_t th = start_th; th < Ht && num_tiles_read < dst_num_tiles; ++th, tile_offset += Wt) { + for (uint32_t tw = start_tw; tw < end_tw && num_tiles_read < dst_num_tiles; ++tw, ++num_tiles_read) { cb_reserve_back(cb_id_src, onetile); uint32_t l1_write_addr_src = get_write_ptr(cb_id_src); noc_async_read_tile(tile_offset + tw, src, l1_write_addr_src); diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/reader_interleaved_row_bcast.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/reader_interleaved_row_bcast.cpp index 71ba70e25f1..1463e4f15cb 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/reader_interleaved_row_bcast.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/reader_interleaved_row_bcast.cpp @@ -10,14 +10,15 @@ void kernel_main() { const uint32_t src_addr = get_arg_val(0); const uint32_t start_tile_id = get_arg_val(1); - const uint32_t num_tiles = get_arg_val(2); - const uint32_t shard_width = get_arg_val(3); - const uint32_t n_stride = get_arg_val(4); - const uint32_t c_stride = get_arg_val(5); - const uint32_t N = get_arg_val(6); - const uint32_t C = get_arg_val(7); - const uint32_t Ht = get_arg_val(8); - const uint32_t Wt = get_arg_val(9); + const uint32_t src_num_tiles = get_arg_val(2); + const uint32_t dst_num_tiles = get_arg_val(3); + const uint32_t dst_shard_width = get_arg_val(4); + const uint32_t n_stride = get_arg_val(5); + const uint32_t c_stride = get_arg_val(6); + const uint32_t N = get_arg_val(7); + const uint32_t C = get_arg_val(8); + const uint32_t Ht = get_arg_val(9); + const uint32_t Wt = get_arg_val(10); const uint32_t HtWt = Ht * Wt; constexpr bool src_is_dram = get_compile_time_arg_val(0) == 1; @@ -43,10 +44,10 @@ void kernel_main() { uint32_t next_batch_shift = n_stride - c_stride * C; uint32_t num_tiles_read = 0; - for (uint32_t n = start_n; n < N && num_tiles_read < num_tiles; ++n, start_c = 0) { - for (uint32_t c = start_c; c < C && num_tiles_read < num_tiles; ++c, start_th = 0) { - for (uint32_t th = start_th; th < Ht && num_tiles_read < num_tiles; ++th, start_tw = 0) { - for (uint32_t tw = start_tw; tw < Wt && num_tiles_read < num_tiles; ++tw, ++num_tiles_read) { + for (uint32_t n = start_n; n < N && num_tiles_read < dst_num_tiles; ++n, start_c = 0) { + for (uint32_t c = start_c; c < C && num_tiles_read < dst_num_tiles; ++c, start_th = 0) { + for (uint32_t th = start_th; th < Ht && num_tiles_read < dst_num_tiles; ++th, start_tw = 0) { + for (uint32_t tw = start_tw; tw < Wt && num_tiles_read < dst_num_tiles; ++tw, ++num_tiles_read) { cb_reserve_back(cb_id_src, onetile); uint32_t l1_write_addr_src = get_write_ptr(cb_id_src); 
noc_async_read_tile(tile_offset + tw, src, l1_write_addr_src); diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/reader_interleaved_scalar_bcast.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/reader_interleaved_scalar_bcast.cpp index 0d8b22a74f8..0927d168641 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/reader_interleaved_scalar_bcast.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/reader_interleaved_scalar_bcast.cpp @@ -10,14 +10,15 @@ void kernel_main() { const uint32_t src_addr = get_arg_val(0); const uint32_t start_tile_id = get_arg_val(1); - const uint32_t num_tiles = get_arg_val(2); - const uint32_t shard_width = get_arg_val(3); - const uint32_t n_stride = get_arg_val(4); - const uint32_t c_stride = get_arg_val(5); - const uint32_t N = get_arg_val(6); - const uint32_t C = get_arg_val(7); - const uint32_t Ht = get_arg_val(8); - const uint32_t Wt = get_arg_val(9); + const uint32_t src_num_tiles = get_arg_val(2); + const uint32_t dst_num_tiles = get_arg_val(3); + const uint32_t dst_shard_width = get_arg_val(4); + const uint32_t n_stride = get_arg_val(5); + const uint32_t c_stride = get_arg_val(6); + const uint32_t N = get_arg_val(7); + const uint32_t C = get_arg_val(8); + const uint32_t Ht = get_arg_val(9); + const uint32_t Wt = get_arg_val(10); const uint32_t HtWt = Ht * Wt; constexpr bool src_is_dram = get_compile_time_arg_val(0) == 1; @@ -41,8 +42,8 @@ void kernel_main() { uint32_t next_batch_shift = n_stride - c_stride * C; uint32_t num_tiles_read = 0; - for (uint32_t n = start_n; n < N && num_tiles_read < num_tiles; ++n, start_c = 0) { - for (uint32_t c = start_c; c < C && num_tiles_read < num_tiles; ++c, start_t = 0) { + for (uint32_t n = start_n; n < N && num_tiles_read < dst_num_tiles; ++n, start_c = 0) { + for (uint32_t c = start_c; c < C && num_tiles_read < dst_num_tiles; ++c, start_t = 0) { cb_reserve_back(cb_id_src, onetile); uint32_t l1_write_addr_src = get_write_ptr(cb_id_src); noc_async_read_tile(tile_offset, src, l1_write_addr_src); diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/writer_interleaved_col_bcast.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/writer_interleaved_col_bcast.cpp index 776c04b3574..adeb0aa60e6 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/writer_interleaved_col_bcast.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/writer_interleaved_col_bcast.cpp @@ -11,19 +11,29 @@ void kernel_main() { const uint32_t src_addr = get_arg_val(0); const uint32_t dst_addr = get_arg_val(1); const uint32_t start_tile_id = get_arg_val(2); - const uint32_t num_tiles = get_arg_val(3); - const uint32_t shard_width = get_arg_val(4); - const uint32_t n_stride = get_arg_val(5); - const uint32_t c_stride = get_arg_val(6); - const uint32_t N = get_arg_val(7); - const uint32_t C = get_arg_val(8); - const uint32_t Ht = get_arg_val(9); - const uint32_t Wt = get_arg_val(10); + const uint32_t src_num_tiles = get_arg_val(3); + const uint32_t dst_num_tiles = get_arg_val(4); + const uint32_t dst_shard_width = get_arg_val(5); + const uint32_t n_stride = get_arg_val(6); + const uint32_t c_stride = get_arg_val(7); + const uint32_t N = get_arg_val(8); + const uint32_t C = get_arg_val(9); + const uint32_t Ht = get_arg_val(10); + const uint32_t Wt = get_arg_val(11); const uint32_t HtWt = Ht * Wt; constexpr uint32_t onetile = 1; + uint32_t tiles_per_batch = 
HtWt * C; + uint32_t start_n = start_tile_id / tiles_per_batch; + uint32_t start_remaining = start_tile_id % tiles_per_batch; + uint32_t start_c = start_remaining / HtWt; + uint32_t start_t = start_remaining % HtWt; + uint32_t start_th = start_t / Wt; + uint32_t start_tw = start_t % Wt; + constexpr auto cb_id_src = tt::CBIndex::c_1; +#if !SRC_SHARDED constexpr bool src_is_dram = get_compile_time_arg_val(0) == 1; const uint32_t src_tile_bytes = get_tile_size(cb_id_src); const DataFormat src_data_format = get_dataformat(cb_id_src); @@ -31,49 +41,52 @@ void kernel_main() { const InterleavedAddrGenFast src = { .bank_base_address = src_addr, .page_size = src_tile_bytes, .data_format = src_data_format}; + // this is the INPUT tile offset + uint32_t tile_offset = start_n * n_stride + start_c * c_stride; + uint32_t next_batch_shift = n_stride - c_stride * C; +#endif + constexpr auto cb_id_dst = tt::CBIndex::c_2; +#if !DST_SHARDED constexpr bool dst_is_dram = get_compile_time_arg_val(1) == 1; const uint32_t dst_tile_bytes = get_tile_size(cb_id_dst); const DataFormat dst_data_format = get_dataformat(cb_id_dst); const InterleavedAddrGenFast dst = { .bank_base_address = dst_addr, .page_size = dst_tile_bytes, .data_format = dst_data_format}; - - uint32_t tiles_per_batch = HtWt * C; - uint32_t start_n = start_tile_id / tiles_per_batch; - uint32_t start_remaining = start_tile_id % tiles_per_batch; - uint32_t start_c = start_remaining / HtWt; - uint32_t start_t = start_remaining % HtWt; - uint32_t start_th = start_t / Wt; - uint32_t start_tw = start_t % Wt; - - // this is the INPUT tile offset - uint32_t tile_offset = start_n * n_stride + start_c * c_stride; - uint32_t next_batch_shift = n_stride - c_stride * C; +#endif uint32_t num_tiles_written = 0; - for (uint32_t n = start_n; n < N && num_tiles_written < num_tiles; ++n, start_c = 0) { - for (uint32_t c = start_c; c < C && num_tiles_written < num_tiles; ++c, start_th = 0) { - for (uint32_t th = start_th; th < Ht && num_tiles_written < num_tiles; ++th, start_tw = 0) { + for (uint32_t n = start_n; n < N && num_tiles_written < dst_num_tiles; ++n, start_c = 0) { + for (uint32_t c = start_c; c < C && num_tiles_written < dst_num_tiles; ++c, start_th = 0) { + for (uint32_t th = start_th; th < Ht && num_tiles_written < dst_num_tiles; ++th, start_tw = 0) { // read a tile from src cb_reserve_back(cb_id_src, onetile); +#if !SRC_SHARDED uint32_t l1_write_addr = get_write_ptr(cb_id_src); noc_async_read_tile(tile_offset + th, src, l1_write_addr); noc_async_read_barrier(); +#endif FILL_TILE_WITH_FIRST_COLUMN(cb_id_src); cb_push_back(cb_id_src, onetile); - for (uint32_t tw = start_tw; tw < Wt && num_tiles_written < num_tiles; ++tw, ++num_tiles_written) { + for (uint32_t tw = start_tw; tw < Wt && num_tiles_written < dst_num_tiles; ++tw, ++num_tiles_written) { // write a tile to dst, since the dst shape is full, the tile offset simply grows linearly cb_wait_front(cb_id_dst, onetile); +#if !DST_SHARDED uint32_t l1_read_addr = get_read_ptr(cb_id_dst); noc_async_write_tile(start_tile_id + num_tiles_written, dst, l1_read_addr); noc_async_write_barrier(); cb_pop_front(cb_id_dst, onetile); +#endif } } +#if !SRC_SHARDED tile_offset += c_stride; +#endif } +#if !SRC_SHARDED tile_offset += next_batch_shift; +#endif } } diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/writer_interleaved_no_bcast.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/writer_interleaved_no_bcast.cpp index 740d061b708..8408aa3cb25 100644 --- 
a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/writer_interleaved_no_bcast.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/writer_interleaved_no_bcast.cpp @@ -10,22 +10,23 @@ void kernel_main() { const uint32_t src_addr = get_arg_val(0); const uint32_t dst_addr = get_arg_val(1); - uint32_t start_tile_id = get_arg_val(2); - const uint32_t num_tiles = get_arg_val(3); - const uint32_t shard_width = get_arg_val(4); - const uint32_t n_stride = get_arg_val(5); - const uint32_t c_stride = get_arg_val(6); - const uint32_t N = get_arg_val(7); - const uint32_t C = get_arg_val(8); - const uint32_t Ht = get_arg_val(9); - const uint32_t Wt = get_arg_val(10); + const uint32_t start_tile_id = get_arg_val(2); + const uint32_t src_num_tiles = get_arg_val(3); + const uint32_t dst_num_tiles = get_arg_val(4); + const uint32_t dst_shard_width = get_arg_val(5); + const uint32_t n_stride = get_arg_val(6); + const uint32_t c_stride = get_arg_val(7); + const uint32_t N = get_arg_val(8); + const uint32_t C = get_arg_val(9); + const uint32_t Ht = get_arg_val(10); + const uint32_t Wt = get_arg_val(11); constexpr uint32_t onetile = 1; constexpr auto cb_id_src = tt::CBIndex::c_1; #if SRC_SHARDED - cb_reserve_back(cb_id_src, num_tiles); - cb_push_back(cb_id_src, num_tiles); + cb_reserve_back(cb_id_src, src_num_tiles); + cb_push_back(cb_id_src, src_num_tiles); #else constexpr bool src_is_dram = get_compile_time_arg_val(0) == 1; const uint32_t src_tile_bytes = get_tile_size(cb_id_src); @@ -55,7 +56,7 @@ void kernel_main() { uint32_t start_t = start_remaining % HtWt; uint32_t start_th = start_t / Wt; uint32_t start_tw = start_t % Wt; - uint32_t end_tw = has_sharding ? start_tw + shard_width : Wt; + uint32_t end_tw = has_sharding ? start_tw + dst_shard_width : Wt; // this is the INPUT tile offset uint32_t tile_offset = start_n * n_stride + start_c * c_stride + start_th * Wt; @@ -63,10 +64,12 @@ void kernel_main() { uint32_t next_batch_shift = n_stride - c_stride * C; uint32_t num_tiles_written = 0; - for (uint32_t n = start_n; n < N && num_tiles_written < num_tiles; ++n, start_c = 0) { - for (uint32_t c = start_c; c < C && num_tiles_written < num_tiles; ++c, start_th = 0) { - for (uint32_t th = start_th; th < Ht && num_tiles_written < num_tiles; ++th) { - for (uint32_t tw = start_tw; tw < end_tw && num_tiles_written < num_tiles; ++tw, ++num_tiles_written) { + uint32_t dst_tile_offset = start_tile_id; + for (uint32_t n = start_n; n < N && num_tiles_written < dst_num_tiles; ++n, start_c = 0) { + for (uint32_t c = start_c; c < C && num_tiles_written < dst_num_tiles; ++c, start_th = 0) { + for (uint32_t th = start_th; th < Ht && num_tiles_written < dst_num_tiles; ++th) { + for (uint32_t tw = start_tw; tw < end_tw && num_tiles_written < dst_num_tiles; + ++tw, ++num_tiles_written) { #if !SRC_SHARDED // read a tile from src cb_reserve_back(cb_id_src, onetile); @@ -80,7 +83,7 @@ void kernel_main() { // write a tile to dst, since the dst shape is full, the tile offset simply grows linearly cb_wait_front(cb_id_dst, onetile); uint32_t l1_read_addr = get_read_ptr(cb_id_dst); - noc_async_write_tile(start_tile_id + num_tiles_written, dst, l1_read_addr); + noc_async_write_tile(dst_tile_offset + num_tiles_written, dst, l1_read_addr); noc_async_write_barrier(); cb_pop_front(cb_id_dst, onetile); #endif @@ -88,7 +91,7 @@ void kernel_main() { tile_offset += Wt; if constexpr (has_sharding) { // adjust the output tile offset since we had to skip parts of the row - start_tile_id += (Wt - 
shard_width); + dst_tile_offset += (Wt - dst_shard_width); } else { // otherwise, next row of tiles should start at the first column start_tw = 0; diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/writer_interleaved_row_bcast.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/writer_interleaved_row_bcast.cpp index fce60f73ea7..95f4d1ffcc7 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/writer_interleaved_row_bcast.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/writer_interleaved_row_bcast.cpp @@ -11,14 +11,15 @@ void kernel_main() { const uint32_t src_addr = get_arg_val(0); const uint32_t dst_addr = get_arg_val(1); const uint32_t start_tile_id = get_arg_val(2); - const uint32_t num_tiles = get_arg_val(3); - const uint32_t shard_width = get_arg_val(4); - const uint32_t n_stride = get_arg_val(5); - const uint32_t c_stride = get_arg_val(6); - const uint32_t N = get_arg_val(7); - const uint32_t C = get_arg_val(8); - const uint32_t Ht = get_arg_val(9); - const uint32_t Wt = get_arg_val(10); + const uint32_t src_num_tiles = get_arg_val(3); + const uint32_t dst_num_tiles = get_arg_val(4); + const uint32_t dst_shard_width = get_arg_val(5); + const uint32_t n_stride = get_arg_val(6); + const uint32_t c_stride = get_arg_val(7); + const uint32_t N = get_arg_val(8); + const uint32_t C = get_arg_val(9); + const uint32_t Ht = get_arg_val(10); + const uint32_t Wt = get_arg_val(11); const uint32_t HtWt = Ht * Wt; constexpr uint32_t onetile = 1; @@ -52,10 +53,10 @@ void kernel_main() { uint32_t next_batch_shift = n_stride - c_stride * C; uint32_t num_tiles_written = 0; - for (uint32_t n = start_n; n < N && num_tiles_written < num_tiles; ++n, start_c = 0) { - for (uint32_t c = start_c; c < C && num_tiles_written < num_tiles; ++c, start_th = 0) { - for (uint32_t th = start_th; th < Ht && num_tiles_written < num_tiles; ++th, start_tw = 0) { - for (uint32_t tw = start_tw; tw < Wt && num_tiles_written < num_tiles; ++tw, ++num_tiles_written) { + for (uint32_t n = start_n; n < N && num_tiles_written < dst_num_tiles; ++n, start_c = 0) { + for (uint32_t c = start_c; c < C && num_tiles_written < dst_num_tiles; ++c, start_th = 0) { + for (uint32_t th = start_th; th < Ht && num_tiles_written < dst_num_tiles; ++th, start_tw = 0) { + for (uint32_t tw = start_tw; tw < Wt && num_tiles_written < dst_num_tiles; ++tw, ++num_tiles_written) { // read a tile from src cb_reserve_back(cb_id_src, onetile); uint32_t l1_write_addr = get_write_ptr(cb_id_src); diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/writer_interleaved_scalar.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/writer_interleaved_scalar.cpp index 311aa075de6..17a5ec998c1 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/writer_interleaved_scalar.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/writer_interleaved_scalar.cpp @@ -12,8 +12,8 @@ void kernel_main() { const uint32_t packed_scalar = get_arg_val(0); const uint32_t dst_addr = get_arg_val(1); const uint32_t start_tile_id = get_arg_val(2); - const uint32_t num_tiles = get_arg_val(3); - const uint32_t shard_width = get_arg_val(4); + const uint32_t dst_num_tiles = get_arg_val(3); + const uint32_t dst_shard_width = get_arg_val(4); const uint32_t N = get_arg_val(5); const uint32_t C = get_arg_val(6); const uint32_t Ht = get_arg_val(7); @@ -50,9 +50,9 @@ void kernel_main() { 
cb_push_back(cb_id_src, onetile); uint32_t num_tiles_written = 0; - for (uint32_t n = start_n; n < N && num_tiles_written < num_tiles; ++n, start_c = 0) { - for (uint32_t c = start_c; c < C && num_tiles_written < num_tiles; ++c, start_t = 0) { - for (uint32_t t = start_t; t < HtWt && num_tiles_written < num_tiles; ++t, ++num_tiles_written) { + for (uint32_t n = start_n; n < N && num_tiles_written < dst_num_tiles; ++n, start_c = 0) { + for (uint32_t c = start_c; c < C && num_tiles_written < dst_num_tiles; ++c, start_t = 0) { + for (uint32_t t = start_t; t < HtWt && num_tiles_written < dst_num_tiles; ++t, ++num_tiles_written) { // write a tile to dst, since the dst shape is full, the tile offset simply grows linearly cb_wait_front(cb_id_dst, onetile); uint32_t l1_read_addr = get_read_ptr(cb_id_dst); diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/writer_interleaved_scalar_bcast.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/writer_interleaved_scalar_bcast.cpp index 7ded85af914..d3c3baf9c13 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/writer_interleaved_scalar_bcast.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/writer_interleaved_scalar_bcast.cpp @@ -11,14 +11,15 @@ void kernel_main() { const uint32_t src_addr = get_arg_val(0); const uint32_t dst_addr = get_arg_val(1); const uint32_t start_tile_id = get_arg_val(2); - const uint32_t num_tiles = get_arg_val(3); - const uint32_t shard_width = get_arg_val(4); - const uint32_t n_stride = get_arg_val(5); - const uint32_t c_stride = get_arg_val(6); - const uint32_t N = get_arg_val(7); - const uint32_t C = get_arg_val(8); - const uint32_t Ht = get_arg_val(9); - const uint32_t Wt = get_arg_val(10); + const uint32_t src_num_tiles = get_arg_val(3); + const uint32_t dst_num_tiles = get_arg_val(4); + const uint32_t dst_shard_width = get_arg_val(5); + const uint32_t n_stride = get_arg_val(6); + const uint32_t c_stride = get_arg_val(7); + const uint32_t N = get_arg_val(8); + const uint32_t C = get_arg_val(9); + const uint32_t Ht = get_arg_val(10); + const uint32_t Wt = get_arg_val(11); const uint32_t HtWt = Ht * Wt; constexpr uint32_t onetile = 1; @@ -50,8 +51,8 @@ void kernel_main() { uint32_t next_batch_shift = n_stride - c_stride * C; uint32_t num_tiles_written = 0; - for (uint32_t n = start_n; n < N && num_tiles_written < num_tiles; ++n, start_c = 0) { - for (uint32_t c = start_c; c < C && num_tiles_written < num_tiles; ++c, start_t = 0) { + for (uint32_t n = start_n; n < N && num_tiles_written < dst_num_tiles; ++n, start_c = 0) { + for (uint32_t c = start_c; c < C && num_tiles_written < dst_num_tiles; ++c, start_t = 0) { // read a tile from src cb_reserve_back(cb_id_src, onetile); uint32_t l1_write_addr = get_write_ptr(cb_id_src); @@ -60,7 +61,7 @@ void kernel_main() { FILL_TILE_WITH_FIRST_ELEMENT(cb_id_src); cb_push_back(cb_id_src, onetile); - for (uint32_t t = start_t; t < HtWt && num_tiles_written < num_tiles; ++t, ++num_tiles_written) { + for (uint32_t t = start_t; t < HtWt && num_tiles_written < dst_num_tiles; ++t, ++num_tiles_written) { // write a tile to dst, since the dst shape is full, the tile offset simply grows linearly cb_wait_front(cb_id_dst, onetile); uint32_t l1_read_addr = get_read_ptr(cb_id_dst);
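
For reference, the shard-spec scaling performed by adjust_to_shape in binary_ng_program_factory.cpp can be sanity-checked in isolation. Below is a minimal Python sketch of the same arithmetic, assuming padded (tile-aligned) shapes as in the C++ above; the helper name mirrors the C++, everything else is illustrative:

def adjust_to_shape(shard_shape, from_shape, to_shape):
    # scale the shard height/width by the ratio of the last two dims
    return (
        (shard_shape[0] * to_shape[-2]) // from_shape[-2],
        (shard_shape[1] * to_shape[-1]) // from_shape[-1],
    )

# shapes from test_binary_sharded_bcast_w: a and c are [5, 7, 64, 128] with
# shard [10 * 32, 4 * 32]; b is [5, 7, 64, 1], which pads to [5, 7, 64, 32]
assert adjust_to_shape((320, 128), (5, 7, 64, 128), (5, 7, 64, 32)) == (320, 32)

This is how get_shard_specs infers a shard spec for the W-broadcast operand when only the other operand (or the output) carries one; the result matches b_sharded_config in the new test.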
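ShardShapeGenerator trims the per-core shard shape on the cores that hold the last, possibly partial, shard along each axis. The trimming arithmetic, sketched in Python with illustrative names (round_up here stands in for tt::round_up):

def round_up(value, multiple):
    # smallest multiple of `multiple` that is >= value
    return -(-value // multiple) * multiple

def last_shard_tiles(total_tiles, shard_tiles):
    # tile count of the final shard along one axis
    return shard_tiles - (round_up(total_tiles, shard_tiles) - total_tiles)

# the a tensor above unrolls to N * C * Ht = 5 * 7 * 2 = 70 tile rows, 10 per core
assert last_shard_tiles(70, 10) == 10  # divides evenly: the last shard is full
assert last_shard_tiles(70, 32) == 6   # two full shards of 32, the last holds 70 - 64 = 6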
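Each dataflow kernel decomposes its start tile id into (n, c, th, tw) coordinates before walking the N x C x Ht x Wt tile grid, and the host computes c_start_id so every core begins at the top-left tile of its output shard. The decomposition, restated in Python for reference (names follow the kernels; the function itself is illustrative):

def decompose_start_tile_id(start_tile_id, C, Ht, Wt):
    tiles_per_batch = Ht * Wt * C
    start_n = start_tile_id // tiles_per_batch
    remaining = start_tile_id % tiles_per_batch
    start_c = remaining // (Ht * Wt)
    start_t = remaining % (Ht * Wt)
    return start_n, start_c, start_t // Wt, start_t % Wt  # (n, c, th, tw)

# tile 150 of a C=3, Ht=4, Wt=8 output sits at batch 1, channel 1, tile row 2, tile col 6
assert decompose_start_tile_id(150, C=3, Ht=4, Wt=8) == (1, 1, 2, 6)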