Revert "#0: Revert "#16138: W-broadcasting for sharded tensors"" #17456

Open · wants to merge 1 commit into main

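This change re-lands #16138, W-dimension broadcasting for sharded tensors in the binary_ng element-wise ops. Per the new validate_sharding helper in the C++ diff below, when broadcasting on W every sharded operand (and sharded output) must use the same memory layout, be height sharded, and have the same shard height. The following is a minimal usage sketch adapted from the test_binary_sharded_bcast_w test added in this diff; the device open/close calls and the random inputs are illustrative assumptions and are not part of the test itself.

import torch
import ttnn

# Sketch only: shapes, shard shapes, and core grid follow the new
# test_binary_sharded_bcast_w test; device handling here is an assumption.
device = ttnn.open_device(device_id=0)

a_shape = torch.Size([5, 7, 2 * 32, 4 * 32])  # full-width operand
b_shape = torch.Size([5, 7, 2 * 32, 1])       # W == 1, broadcast along W

# Both operands are HEIGHT sharded over the same cores with the same shard height.
a_sharded_config = ttnn.create_sharded_memory_config(
    [10 * 32, 4 * 32],
    core_grid=ttnn.CoreRangeSet({ttnn.CoreRange((0, 0), (0, 6))}),
    strategy=ttnn.ShardStrategy.HEIGHT,
    orientation=ttnn.ShardOrientation.ROW_MAJOR,
    use_height_and_width_as_shard_shape=True,
)
b_sharded_config = ttnn.create_sharded_memory_config(
    [10 * 32, 32],  # W of 1 is tile-padded to 32
    core_grid=ttnn.CoreRangeSet({ttnn.CoreRange((0, 0), (0, 6))}),
    strategy=ttnn.ShardStrategy.HEIGHT,
    orientation=ttnn.ShardOrientation.ROW_MAJOR,
    use_height_and_width_as_shard_shape=True,
)

a_pt = torch.rand(a_shape).bfloat16()
b_pt = torch.rand(b_shape).bfloat16()
a_tt = ttnn.from_torch(a_pt, device=device, layout=ttnn.TILE_LAYOUT, memory_config=a_sharded_config)
b_tt = ttnn.from_torch(b_pt, device=device, layout=ttnn.TILE_LAYOUT, memory_config=b_sharded_config)

# W-broadcast add on sharded inputs; the output may be sharded or interleaved.
out_tt = ttnn.experimental.add(a_tt, b_tt, memory_config=a_sharded_config)
torch.testing.assert_close(ttnn.to_torch(out_tt), torch.add(a_pt, b_pt))

ttnn.close_device(device)

The test in the diff also exercises mixed placements (one operand in DRAM, the other sharded) and an interleaved DRAM output via memory_config=ttnn.DRAM_MEMORY_CONFIG.
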
147 changes: 123 additions & 24 deletions tests/ttnn/unit_tests/operations/eltwise/test_binary_bcast.py
@@ -339,9 +339,14 @@ def test_01_volume_tensors(device, a, b, c_golden, memory_config):
((torch.Size([5, 7, 64, 128]), torch.Size([5, 7, 64, 128])),),
)
@pytest.mark.parametrize(
"sharded_config", [height_sharded_memory_config, width_sharded_memory_config, block_sharded_memory_config]
"sharded_config",
[
height_sharded_memory_config,
width_sharded_memory_config,
block_sharded_memory_config,
],
)
def test_binary_bcast_sharded(a_shape, b_shape, sharded_config, device):
def test_binary_sharded(a_shape, b_shape, sharded_config, device):
input_combinations = (
(ttnn.DRAM_MEMORY_CONFIG, sharded_config),
(sharded_config, ttnn.DRAM_MEMORY_CONFIG),
@@ -407,10 +412,18 @@ def test_binary_sfpu_ops(input_shapes, dtype, ttnn_fn, device):
b_pt = gen_func_with_cast_tt(partial(torch_random, low=-50, high=50, dtype=torch.float32), dtype)(b_shape)

a_tt = ttnn.from_torch(
a_pt, dtype=dtype, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG
a_pt,
dtype=dtype,
device=device,
layout=ttnn.TILE_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
)
b_tt = ttnn.from_torch(
b_pt, dtype=dtype, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG
b_pt,
dtype=dtype,
device=device,
layout=ttnn.TILE_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
)
cq_id = 0
out_tt = ttnn_fn(a_tt, b_tt, queue_id=cq_id)
@@ -468,13 +481,25 @@ def test_binary_sfpu_opt_out(input_shapes, dtype, ttnn_fn, device):
out = gen_func_with_cast_tt(partial(torch_random, low=0, high=1, dtype=torch.float32), dtype)(out_shape)

a_tt = ttnn.from_torch(
a_pt, dtype=dtype, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG
a_pt,
dtype=dtype,
device=device,
layout=ttnn.TILE_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
)
b_tt = ttnn.from_torch(
b_pt, dtype=dtype, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG
b_pt,
dtype=dtype,
device=device,
layout=ttnn.TILE_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
)
out_tt = ttnn.from_torch(
out, dtype=dtype, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG
out,
dtype=dtype,
device=device,
layout=ttnn.TILE_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
)
cq_id = 0
ttnn_fn(a_tt, b_tt, queue_id=cq_id, output_tensor=out_tt)
@@ -517,10 +542,18 @@ def test_binary_sfpu_bitwise_ops(input_shapes, dtype, ttnn_fn, device):
b_pt = gen_func_with_cast_tt(partial(torch_random, low=0, high=31, dtype=torch.int32), dtype)(b_shape)

a_tt = ttnn.from_torch(
a_pt, dtype=dtype, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG
a_pt,
dtype=dtype,
device=device,
layout=ttnn.TILE_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
)
b_tt = ttnn.from_torch(
b_pt, dtype=dtype, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG
b_pt,
dtype=dtype,
device=device,
layout=ttnn.TILE_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
)
cq_id = 0
out_tt = ttnn_fn(a_tt, b_tt, queue_id=cq_id)
@@ -565,13 +598,25 @@ def test_bitwise_opt_output(input_shapes, dtype, ttnn_fn, device):
out = gen_func_with_cast_tt(partial(torch_random, low=0, high=1, dtype=torch.int32), dtype)(out_shape)

a_tt = ttnn.from_torch(
a_pt, dtype=dtype, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG
a_pt,
dtype=dtype,
device=device,
layout=ttnn.TILE_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
)
b_tt = ttnn.from_torch(
b_pt, dtype=dtype, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG
b_pt,
dtype=dtype,
device=device,
layout=ttnn.TILE_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
)
out_tt = ttnn.from_torch(
out, dtype=dtype, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG
out,
dtype=dtype,
device=device,
layout=ttnn.TILE_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
)
cq_id = 0
ttnn_fn(a_tt, b_tt, queue_id=cq_id, output_tensor=out_tt)
@@ -650,10 +695,16 @@ def test_inplace_binary_ops_with_tensor(a_shape, b_shape, ttnn_fn, activations,
torch_input_tensor_b, input_tensor_b = rand_bf16_gen(b_shape, device, min=min, max=max)

input_tensor_a = ttnn.from_torch(
torch_input_tensor_a, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG
torch_input_tensor_a,
device=device,
layout=ttnn.TILE_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
)
input_tensor_b = ttnn.from_torch(
torch_input_tensor_b, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG
torch_input_tensor_b,
device=device,
layout=ttnn.TILE_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
)

for golden_activation in golden_lhs:
@@ -668,14 +719,21 @@ def test_inplace_binary_ops_with_tensor(a_shape, b_shape, ttnn_fn, activations,
for golden_activation in golden_post:
torch_output_tensor = golden_activation(torch_output_tensor).bfloat16()

ttnn_op(input_tensor_a, input_tensor_b, lhs_activations=lhs, rhs_activations=rhs, post_activations=post)
ttnn_op(
input_tensor_a,
input_tensor_b,
lhs_activations=lhs,
rhs_activations=rhs,
post_activations=post,
)
output_tensor = ttnn.to_torch(input_tensor_a)
assert output_tensor.shape == torch_output_tensor.shape

def compare(output_tensor, torch_output_tensor):
imprecise_cases = {
*parameters(
{"logaddexp2_"}, {exp_floor_lhs_exp_rhs, no_activations, sin_rhs, log_lhs_sqrt_abs_post, square_lhs}
{"logaddexp2_"},
{exp_floor_lhs_exp_rhs, no_activations, sin_rhs, log_lhs_sqrt_abs_post, square_lhs},
),
*parameters({"bias_gelu_"}, {no_activations, sin_rhs, square_lhs}),
*parameters({"gt_", "lte_", "gte_", "lt_"}, {sin_rhs, square_lhs}),
@@ -736,7 +794,11 @@ def test_inplace_bf4b_bf8b(a_shape, b_shape, input_dtype, ttnn_fn, device):
assert output_tensor.shape == torch_output_tensor.shape

def compare(output_tensor, torch_output_tensor, ttnn_fn, input_dtype):
imprecise_cases = {"add_": {ttnn.bfloat4_b}, "sub_": {ttnn.bfloat4_b}, "mul_": {ttnn.bfloat4_b}}
imprecise_cases = {
"add_": {ttnn.bfloat4_b},
"sub_": {ttnn.bfloat4_b},
"mul_": {ttnn.bfloat4_b},
}
if ttnn_fn in imprecise_cases and input_dtype in imprecise_cases[ttnn_fn]:
return ttnn.pearson_correlation_coefficient(torch_output_tensor, output_tensor) >= 0.97
else:
@@ -797,10 +859,9 @@ def test_inplace_binary_ops_fp32(input_shapes, ttnn_fn, device):
@pytest.mark.parametrize(
"a_shape, b_shape",
(
(torch.Size([1, 3, 128, 1]), torch.Size([5, 3, 128, 64])),
(torch.Size([1, 1, 1, 1]), torch.Size([5, 3, 32, 32])),
(torch.Size([5, 1, 1, 64]), torch.Size([1, 3, 128, 1])),
(torch.Size([16, 1]), torch.Size([1, 1, 32])),
(torch.Size([1, 1, 31, 32]), torch.Size([5, 3, 32, 32])),
(torch.Size([5, 2, 64, 1]), torch.Size([1, 3, 1, 128])),
(torch.Size([5, 1, 1, 64]), torch.Size([2, 3, 128, 1])),
),
)
@pytest.mark.parametrize(
@@ -814,9 +875,7 @@ def test_inplace_binary_ops_invalid_bcast(a_shape, b_shape, ttnn_fn, device):
_, input_tensor_a = rand_bf16_gen(a_shape, device)
_, input_tensor_b = rand_bf16_gen(b_shape, device)

with pytest.raises(
RuntimeError, match=r"Shape of Output tensor.+ provided does not match the broadcasted output shape .+"
):
with pytest.raises(RuntimeError):
cq_id = 0
ttnn_op(input_tensor_a, input_tensor_b, queue_id=cq_id)

@@ -898,3 +957,43 @@ def test_binary_opt_output_invalid_bcast(a_shape, b_shape, out_shape, ttnn_fn, d
):
cq_id = 0
ttnn_op(input_tensor_a, input_tensor_b, queue_id=cq_id, output_tensor=out_tt)


def test_binary_sharded_bcast_w(device):
a_shape = torch.Size([5, 7, 2 * 32, 4 * 32])
b_shape = torch.Size([5, 7, 2 * 32, 1])

a_sharded_config = ttnn.create_sharded_memory_config(
[10 * 32, 4 * 32],
core_grid=ttnn.CoreRangeSet({ttnn.CoreRange((0, 0), (0, 6))}),
strategy=ttnn.ShardStrategy.HEIGHT,
orientation=ttnn.ShardOrientation.ROW_MAJOR,
use_height_and_width_as_shard_shape=True,
)

b_sharded_config = ttnn.create_sharded_memory_config(
[10 * 32, 32],
core_grid=ttnn.CoreRangeSet({ttnn.CoreRange((0, 0), (0, 6))}),
strategy=ttnn.ShardStrategy.HEIGHT,
orientation=ttnn.ShardOrientation.ROW_MAJOR,
use_height_and_width_as_shard_shape=True,
)

input_combinations = (
(ttnn.DRAM_MEMORY_CONFIG, b_sharded_config),
(a_sharded_config, ttnn.DRAM_MEMORY_CONFIG),
(a_sharded_config, b_sharded_config),
)

for src_config, dst_config in input_combinations:
a_pt, a_tt = rand_bf16_gen(a_shape, device, memory_config=src_config)
b_pt, b_tt = rand_bf16_gen(b_shape, device, memory_config=dst_config)

out_pt = torch.add(a_pt, b_pt)
out_tt_sharded = ttnn.experimental.add(a_tt, b_tt, memory_config=ttnn.DRAM_MEMORY_CONFIG)
out_tt_sharded = ttnn.to_torch(out_tt_sharded)
torch.testing.assert_close(out_tt_sharded, out_pt)

out_tt_sharded = ttnn.experimental.add(a_tt, b_tt, memory_config=a_sharded_config)
out_tt_sharded = ttnn.to_torch(out_tt_sharded)
torch.testing.assert_close(out_tt_sharded, out_pt)
@@ -119,6 +119,35 @@ DataType BinaryNgDeviceOperation::operation_attributes_t::get_dtype() const {
return this->dtype.value_or(this->input_dtype);
}

void validate_sharding(
TensorMemoryLayout memory_layout_x,
const ShardSpec& shard_spec_x,
TensorMemoryLayout memory_layout_y,
const ShardSpec& shard_spec_y,
SubtileBroadcastType subtile_broadcast_type) {
TT_FATAL(memory_layout_x == memory_layout_y, "Operands to eltwise binary need to have the same memory layout");

switch (subtile_broadcast_type) {
case SubtileBroadcastType::NONE:
TT_FATAL(shard_spec_x == shard_spec_y, "Operands to eltwise binary need to have the same shard spec");
break;
case SubtileBroadcastType::COL_A:
case SubtileBroadcastType::COL_B:
TT_FATAL(
memory_layout_x == TensorMemoryLayout::HEIGHT_SHARDED,
"Operands to eltwise binary must be height sharded when broadcasting on W");
TT_FATAL(
memory_layout_y == TensorMemoryLayout::HEIGHT_SHARDED,
"Operands to eltwise binary must be height sharded when broadcasting on W");
TT_FATAL(
shard_spec_x.shape[0] == shard_spec_y.shape[0],
"Operands to eltwise binary need to have the same"
"shard height when broadcasting on W");
break;
default: TT_THROW("Invalid subtile broadcast type for sharding validation");
}
}

void BinaryNgDeviceOperation::validate_on_program_cache_miss(
const operation_attributes_t& attributes, const tensor_args_t& tensor_args) {
// We don't support sharding for now
@@ -172,28 +201,28 @@ void BinaryNgDeviceOperation::validate_on_program_cache_miss(
// Validate that all shard specs match
if (tensor_a_sharded) {
if (tensor_b_sharded) {
TT_FATAL(
input_tensor_a.memory_config().memory_layout == input_tensor_b->memory_config().memory_layout,
"Operands to eltwise binary need to have the same memory layout");
TT_FATAL(
input_tensor_a.shard_spec().value() == input_tensor_b->shard_spec().value(),
"Operands to eltwise binary need to have the same shard spec");
validate_sharding(
input_tensor_a.memory_config().memory_layout,
*input_tensor_a.shard_spec(),
input_tensor_b->memory_config().memory_layout,
*input_tensor_b->shard_spec(),
attributes.subtile_broadcast_type);
}
if (output_sharded) {
TT_FATAL(
input_tensor_a.memory_config().memory_layout == attributes.memory_config.memory_layout,
"LHS operand and output to eltwise binary need to have the same memory layout");
TT_FATAL(
input_tensor_a.shard_spec().value() == attributes.memory_config.shard_spec.value(),
"LHS operand and output to eltwise binary need to have the same shard spec");
validate_sharding(
input_tensor_a.memory_config().memory_layout,
*input_tensor_a.shard_spec(),
attributes.memory_config.memory_layout,
*attributes.memory_config.shard_spec,
attributes.subtile_broadcast_type);
}
} else if (tensor_b_sharded and output_sharded) {
TT_FATAL(
input_tensor_b->memory_config().memory_layout == attributes.memory_config.memory_layout,
"RHS operand and output to eltwise binary need to have the same memory layout");
TT_FATAL(
input_tensor_b->shard_spec().value() == attributes.memory_config.shard_spec.value(),
"RHS operand and output to eltwise binary need to have the same shard spec");
validate_sharding(
input_tensor_b->memory_config().memory_layout,
*input_tensor_b->shard_spec(),
attributes.memory_config.memory_layout,
*attributes.memory_config.shard_spec,
attributes.subtile_broadcast_type);
}
}

@@ -227,10 +256,10 @@ void BinaryNgDeviceOperation::validate_on_program_cache_hit(
a_dim,
b_dim);

if (has_shard_spec) {
if (has_shard_spec and i != -1) {
TT_FATAL(
a_dim == b_dim,
"Cannot broadcast sharded tensors, violation for rank {}, dim a: {}, dim b: {}",
"Cannot broadcast sharded tensors on dims other than W, violation for rank {}, dim a: {}, dim b: {}",
i,
a_dim,
b_dim);
@@ -284,17 +313,8 @@ BinaryNgDeviceOperation::spec_return_value_t BinaryNgDeviceOperation::compute_ou
}

if (attributes.memory_config.is_sharded()) {
ShardSpec shard_spec{CoreRangeSet(), {0, 0}};
if (input_tensor_a.memory_config().is_sharded()) {
shard_spec = input_tensor_a.shard_spec().value();
} else if (tensor_b.has_value() and tensor_b->memory_config().is_sharded()) {
shard_spec = tensor_b->shard_spec().value();
} else {
shard_spec = attributes.memory_config.shard_spec.value();
}
auto memory_config = attributes.memory_config;
memory_config.shard_spec = shard_spec;
return TensorSpec(output_shape, TensorLayout(attributes.get_dtype(), PageConfig(Layout::TILE), memory_config));
return TensorSpec(
output_shape, TensorLayout(attributes.get_dtype(), PageConfig(Layout::TILE), attributes.memory_config));
}

return TensorSpec(
@@ -381,8 +401,7 @@ BinaryNgDeviceOperation::invoke(
{rhs_activations.begin(), rhs_activations.end()},
{post_activations.begin(), post_activations.end()},
std::nullopt,
memory_config.value_or(
output_tensor.has_value() ? output_tensor->memory_config() : input_tensor_a.memory_config()),
memory_config.value_or(output_tensor.has_value() ? output_tensor->memory_config() : MemoryConfig{}),
input_tensor_a.get_dtype(),
output_dtype,
get_worker_grid(input_tensor_a, &input_tensor_b, output_tensor),
@@ -413,8 +432,7 @@ BinaryNgDeviceOperation::invoke(
{rhs_activations.begin(), rhs_activations.end()},
{post_activations.begin(), post_activations.end()},
scalar,
memory_config.value_or(
output_tensor.has_value() ? output_tensor->memory_config() : input_tensor_a.memory_config()),
memory_config.value_or(output_tensor.has_value() ? output_tensor->memory_config() : MemoryConfig{}),
input_tensor_a.get_dtype(),
output_dtype,
get_worker_grid(input_tensor_a, nullptr, output_tensor),