Skip to content

Commit

Permalink
#0: Flip TT_ASSERT to TT_FATAL for sharding
Browse files Browse the repository at this point in the history
  • Loading branch information
TT-BrianLiu committed Feb 4, 2025
1 parent 799a11b commit 30e7c5c
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 12 deletions.
2 changes: 0 additions & 2 deletions tests/ttnn/unit_tests/gtests/tensor/test_tensor_sharding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -967,7 +967,6 @@ struct IllegalShardSpecParams {
class IllegalTensorLayoutCreationTests : public ::testing::TestWithParam<IllegalShardSpecParams> {};

TEST_P(IllegalTensorLayoutCreationTests, ExpectFailAndCheckErrMsg) {
GTEST_SKIP() << "Enable tests after flipping asserts to TT_FATAL (issue #17060)";
const auto& params = GetParam();

EXPECT_THAT(
Expand Down Expand Up @@ -1042,7 +1041,6 @@ INSTANTIATE_TEST_SUITE_P(
class IllegalTensorSpecCreationTests : public ::testing::TestWithParam<IllegalShardSpecParams> {};

TEST_P(IllegalTensorSpecCreationTests, ExpectFailAndCheckErrMsg) {
GTEST_SKIP() << "Enable tests after flipping asserts to TT_FATAL (issue #17060)";
const auto& params = GetParam();

auto tensor_layout = TensorLayout(DataType::BFLOAT16, params.page_config, params.memory_config);
Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/ttnn/tensor/layout/tensor_layout.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ void validate_shard_spec(const TensorLayout& tensor_layout) {
const auto& physical_shard_shape = tensor_layout.get_physical_shard_shape();
const auto& tile_shape = tensor_layout.get_tile().get_tile_shape();
// TODO (issue #17060): Flip to TT_FATAL
TT_ASSERT(
TT_FATAL(
(physical_shard_shape.height() % tile_shape[0] == 0 && physical_shard_shape.width() % tile_shape[1] == 0),
"Physical shard shape {} must be tile {} sized!",
physical_shard_shape,
Expand Down
18 changes: 9 additions & 9 deletions ttnn/cpp/ttnn/tensor/tensor_spec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,56 +29,56 @@ void validate_shard_spec_with_tensor_shape(const TensorSpec& tensor_spec) {

// TODO (issue #17060): Flip to TT_FATAL
if (memory_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) {
TT_ASSERT(
TT_FATAL(
physical_width == physical_shard_width,
"Shard width {} must match physical width {} for height sharded",
physical_shard_width,
physical_width);
uint32_t num_shards = div_up(physical_height, physical_shard_height);
TT_ASSERT(
TT_FATAL(
num_shards <= num_cores,
"Number of shards along height {} must not exceed number of cores {}",
num_shards,
num_cores);
} else if (memory_config.memory_layout == TensorMemoryLayout::WIDTH_SHARDED) {
TT_ASSERT(
TT_FATAL(
physical_height == physical_shard_height,
"Shard height {} must match physical height {} for width sharded",
physical_shard_height,
physical_height);
uint32_t num_shards = div_up(physical_width, physical_shard_width);
TT_ASSERT(
TT_FATAL(
num_shards <= num_cores,
"Number of shards along width {} must not exceed number of cores {}",
num_shards,
num_cores);
} else if (memory_config.memory_layout == TensorMemoryLayout::BLOCK_SHARDED) {
TT_ASSERT(
TT_FATAL(
shard_spec.grid.ranges().size() == 1, "Shard grid must be one full rectangular grid for block sharded!");
uint32_t num_shards_along_height = div_up(physical_height, physical_shard_height);
uint32_t num_shards_along_width = div_up(physical_width, physical_shard_width);

// Additionally check that number of cores along height and width matches shard grid
const CoreCoord shard_grid = shard_spec.grid.bounding_box().grid_size();
if (shard_spec.orientation == ShardOrientation::ROW_MAJOR) {
TT_ASSERT(
TT_FATAL(
num_shards_along_height <= shard_grid.y,
"Number of shards along height {} must not exceed number of rows {} for row major orientation!",
num_shards_along_height,
shard_grid.y);
TT_ASSERT(
TT_FATAL(
num_shards_along_width <= shard_grid.x,
"Number of shards along width {} must not exceed number of columns {} for row major orientation!",
num_shards_along_width,
shard_grid.x);
} else {
TT_ASSERT(
TT_FATAL(
num_shards_along_height <= shard_grid.x,
"Number of shards along height {} must not exceed number of columns {} for column major "
"orientation!",
num_shards_along_height,
shard_grid.x);
TT_ASSERT(
TT_FATAL(
num_shards_along_width <= shard_grid.y,
"Number of shards along width {} must not exceed number of rows {} for column major orientation!",
num_shards_along_width,
Expand Down

0 comments on commit 30e7c5c

Please sign in to comment.