Skip to content

Commit

Permalink
#17477: Adopt ND coordinate system in MeshDeviceView and the related abstractions (#18073)
Browse files Browse the repository at this point in the history

### Ticket
#17477

### Problem description
Continuing to plumb the ND coordinate system through Metal / TTNN.

### What's changed
* Adopted `SimpleMeshShape` in `MeshDeviceView`.
* Removed `Coordinate`.
* Simplified `MeshDeviceView` construction (dropped the unused
`CoordinateMapper`) and simplified getting line/ring coordinates.
* Added support for ND rotation when requesting a specific mesh shape from
`SystemMesh`.
* Added more features to the ND `MeshContainer`, `MeshCoordinate`, and
`MeshCoordinateRange`.

### Checklist
- [x] [All post
commit](https://github.com/tenstorrent/tt-metal/actions/runs/13441887717)
- pending
- [Build failures in programming examples fixed and
verified](https://github.com/tenstorrent/tt-metal/actions/runs/13444193986)
- [X] New/Existing tests provide coverage for changes
- [X] Ran the affected T3K distributed tests locally
(`unit_tests_ttnn_cc`, `unit_tests_ttnn_tensor`, `test_distributed`,
`distributed_unit_tests_wormhole_b0`).
  • Loading branch information
omilyutin-tt authored Feb 20, 2025
1 parent 0df1047 commit f4719c7
Show file tree
Hide file tree
Showing 33 changed files with 542 additions and 449 deletions.
12 changes: 6 additions & 6 deletions tests/tt_metal/distributed/test_mesh_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,9 @@ TEST_F(MeshBufferTestT3000, GetDeviceBuffer) {
MeshBuffer::create(ReplicatedBufferConfig{.size = 16 << 10}, device_local_config, mesh_device_.get());

// Out of bounds coordinates.
EXPECT_ANY_THROW(replicated_buffer->get_device_buffer(Coordinate{2, 4}));
EXPECT_ANY_THROW(replicated_buffer->get_device_buffer(MeshCoordinate{2, 4}));

EXPECT_NO_THROW(replicated_buffer->get_device_buffer(Coordinate{1, 3}));
EXPECT_NO_THROW(replicated_buffer->get_device_buffer(MeshCoordinate{1, 3}));
}

class DeviceLocalMeshBufferShardingTest
Expand Down Expand Up @@ -174,14 +174,14 @@ TEST_P(DeviceLocalMeshBufferShardingTest, ShardingTest) {

for (std::size_t logical_x = 0; logical_x < buf->device()->num_cols(); logical_x++) {
for (std::size_t logical_y = 0; logical_y < buf->device()->num_rows(); logical_y++) {
WriteShard(mesh_device_->mesh_command_queue(), buf, src_vec, Coordinate(logical_y, logical_x));
WriteShard(mesh_device_->mesh_command_queue(), buf, src_vec, MeshCoordinate(logical_y, logical_x));
}
}

for (std::size_t logical_x = 0; logical_x < buf->device()->num_cols(); logical_x++) {
for (std::size_t logical_y = 0; logical_y < buf->device()->num_rows(); logical_y++) {
std::vector<uint32_t> dst_vec = {};
ReadShard(mesh_device_->mesh_command_queue(), dst_vec, buf, Coordinate(logical_y, logical_x));
ReadShard(mesh_device_->mesh_command_queue(), dst_vec, buf, MeshCoordinate(logical_y, logical_x));
EXPECT_EQ(dst_vec, src_vec);
}
}
Expand Down Expand Up @@ -304,14 +304,14 @@ TEST_F(MeshBufferTestSuite, InterleavedShardsReadWrite) {
std::iota(src_vec.begin(), src_vec.end(), i);
for (std::size_t logical_x = 0; logical_x < buf->device()->num_cols(); logical_x++) {
for (std::size_t logical_y = 0; logical_y < buf->device()->num_rows(); logical_y++) {
WriteShard(mesh_device_->mesh_command_queue(), buf, src_vec, Coordinate(logical_y, logical_x));
WriteShard(mesh_device_->mesh_command_queue(), buf, src_vec, MeshCoordinate(logical_y, logical_x));
}
}

for (std::size_t logical_x = 0; logical_x < buf->device()->num_cols(); logical_x++) {
for (std::size_t logical_y = 0; logical_y < buf->device()->num_rows(); logical_y++) {
std::vector<uint32_t> dst_vec = {};
ReadShard(mesh_device_->mesh_command_queue(), dst_vec, buf, Coordinate(logical_y, logical_x));
ReadShard(mesh_device_->mesh_command_queue(), dst_vec, buf, MeshCoordinate(logical_y, logical_x));
EXPECT_EQ(dst_vec, src_vec);
}
}
Expand Down
42 changes: 42 additions & 0 deletions tests/tt_metal/distributed/test_mesh_coord.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ namespace {

using ::testing::ElementsAre;
using ::testing::UnorderedElementsAre;

TEST(SimpleMeshShapeTest, Construction) {
SimpleMeshShape shape_1d(3);
EXPECT_EQ(shape_1d.dims(), 1);
Expand Down Expand Up @@ -172,6 +173,31 @@ TEST(MeshCoordinateRangeTest, SubrangeOneElement) {
EXPECT_THAT(coords, ElementsAre(MeshCoordinate(1, 1, 1)));
}

// Checks MeshCoordinateRange::contains() on a single-point 3D range and on a 2D range
// whose start is (0, 2) and whose end is (1, 2).
TEST(MeshCoordinateRangeTest, Contains) {
    // Start and end coincide, so that one coordinate is expected to be a member.
    const MeshCoordinate only_point(1, 1, 3);
    MeshCoordinateRange range(only_point, only_point);
    EXPECT_TRUE(range.contains(only_point));

    // Re-assign to a 2D range; both the start and the end coordinate are expected to be members.
    const MeshCoordinate first_corner(0, 2);
    const MeshCoordinate last_corner(1, 2);
    range = MeshCoordinateRange(first_corner, last_corner);
    EXPECT_TRUE(range.contains(first_corner));
    EXPECT_TRUE(range.contains(last_corner));

    // Out of bounds in the second index, in both indices, and in the first index respectively.
    const MeshCoordinate below_second_index(0, 1);
    const MeshCoordinate outside_both(2, 1);
    const MeshCoordinate above_first_index(2, 2);
    EXPECT_FALSE(range.contains(below_second_index));
    EXPECT_FALSE(range.contains(outside_both));
    EXPECT_FALSE(range.contains(above_first_index));
}

// Checks that dims() of a MeshCoordinateRange equals the number of components in the
// coordinates it was built from, for 1, 2 and 3 components.
TEST(MeshCoordinateRangeTest, Dimensionality) {
    {
        MeshCoordinateRange range_1d(MeshCoordinate(0), MeshCoordinate(5));
        EXPECT_EQ(range_1d.dims(), 1);
    }
    {
        MeshCoordinateRange range_2d(MeshCoordinate(0, 1), MeshCoordinate(5, 1));
        EXPECT_EQ(range_2d.dims(), 2);
    }
    {
        MeshCoordinateRange range_3d(MeshCoordinate(0, 1, 2), MeshCoordinate(5, 1, 2));
        EXPECT_EQ(range_3d.dims(), 3);
    }
}

// Checks that querying a 3D range with a 2D coordinate is reported by throwing.
TEST(MeshCoordinateRangeTest, ContainsMismatchedDimensions) {
    const MeshCoordinate point_3d(1, 1, 3);
    MeshCoordinateRange range_3d(point_3d, point_3d);
    EXPECT_EQ(range_3d.dims(), 3);

    // The query has one component fewer than the range; the test expects contains() to throw.
    const MeshCoordinate point_2d(1, 1);
    EXPECT_ANY_THROW(range_3d.contains(point_2d));
}

TEST(MeshCoordinateRangeTest, MismatchedDimensions) {
MeshCoordinate start(1, 0);
MeshCoordinate end(2, 3, 1);
Expand Down Expand Up @@ -221,6 +247,22 @@ TEST(MeshContainerTest, InitialValues) {
EXPECT_THAT(initial_values, ElementsAre(3, 3, 3, 3, 3, 3));
}

// Checks that a MeshContainer built from a shape plus a vector of values yields those
// values, in the same order, when iterated.
TEST(MeshContainerTest, FromVector) {
    SimpleMeshShape two_by_three(2, 3);
    MeshContainer<int> filled(two_by_three, std::vector<int>{0, 1, 2, 3, 4, 5});

    // Collect only the value half of each entry; the leading half of the binding is ignored.
    std::vector<int> visited;
    for (const auto& [_, element] : filled) {
        visited.push_back(element);
    }

    EXPECT_THAT(visited, ElementsAre(0, 1, 2, 3, 4, 5));
}

// Checks that constructing a MeshContainer with 5 values for a (2, 3) shape throws.
TEST(MeshContainerTest, FromVectorInvalidSize) {
    SimpleMeshShape two_by_three(2, 3);

    // One value short of the 2 * 3 = 6 entries the shape describes.
    const std::vector<int> too_few_values{0, 1, 2, 3, 4};
    EXPECT_ANY_THROW(MeshContainer<int>(two_by_three, too_few_values));
}

TEST(MeshContainerTest, ElementAccessRowMajor) {
SimpleMeshShape shape(2, 3);
MeshContainer<int> container(shape, 0);
Expand Down
23 changes: 16 additions & 7 deletions tests/tt_metal/distributed/test_mesh_events.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,12 @@ TEST_F(MeshEventsTestSuite, ReplicatedAsyncIO) {
for (std::size_t logical_x = 0; logical_x < buf->device()->num_cols(); logical_x++) {
for (std::size_t logical_y = 0; logical_y < buf->device()->num_rows(); logical_y++) {
readback_vecs.push_back({});
auto shard = buf->get_device_buffer(Coordinate(logical_y, logical_x));
auto shard = buf->get_device_buffer(MeshCoordinate(logical_y, logical_x));
ReadShard(
mesh_device_->mesh_command_queue(1), readback_vecs.back(), buf, Coordinate(logical_y, logical_x));
mesh_device_->mesh_command_queue(1),
readback_vecs.back(),
buf,
MeshCoordinate(logical_y, logical_x));
}
}

Expand Down Expand Up @@ -173,7 +176,7 @@ TEST_F(MeshEventsTestSuite, AsyncWorkloadAndIO) {
mesh_device_->mesh_command_queue(1),
dst_vec,
output_bufs[col_idx * worker_grid_size.y + row_idx],
Coordinate(logical_y, logical_x));
MeshCoordinate(logical_y, logical_x));
if (logical_y == 0) {
for (int i = 0; i < dst_vec.size(); i++) {
EXPECT_EQ(dst_vec[i].to_float(), (2 * iter + 5));
Expand Down Expand Up @@ -224,9 +227,12 @@ TEST_F(MeshEventsTestSuite, CustomDeviceRanges) {
for (std::size_t logical_x = devices_0.start_coord.x; logical_x < devices_0.end_coord.x; logical_x++) {
for (std::size_t logical_y = devices_0.start_coord.y; logical_y < devices_0.end_coord.y; logical_y++) {
readback_vecs.push_back({});
auto shard = buf->get_device_buffer(Coordinate(logical_y, logical_x));
auto shard = buf->get_device_buffer(MeshCoordinate(logical_y, logical_x));
ReadShard(
mesh_device_->mesh_command_queue(0), readback_vecs.back(), buf, Coordinate(logical_y, logical_x));
mesh_device_->mesh_command_queue(0),
readback_vecs.back(),
buf,
MeshCoordinate(logical_y, logical_x));
}
}

Expand All @@ -237,9 +243,12 @@ TEST_F(MeshEventsTestSuite, CustomDeviceRanges) {
for (std::size_t logical_x = devices_1.start_coord.x; logical_x < devices_1.end_coord.x; logical_x++) {
for (std::size_t logical_y = devices_1.start_coord.y; logical_y < devices_1.end_coord.y; logical_y++) {
readback_vecs.push_back({});
auto shard = buf->get_device_buffer(Coordinate(logical_y, logical_x));
auto shard = buf->get_device_buffer(MeshCoordinate(logical_y, logical_x));
ReadShard(
mesh_device_->mesh_command_queue(0), readback_vecs.back(), buf, Coordinate(logical_y, logical_x));
mesh_device_->mesh_command_queue(0),
readback_vecs.back(),
buf,
MeshCoordinate(logical_y, logical_x));
}
}
for (auto& vec : readback_vecs) {
Expand Down
3 changes: 2 additions & 1 deletion tests/tt_metal/distributed/test_mesh_sub_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ TEST_F(MeshSubDeviceTestSuite, DataCopyOnSubDevices) {
for (std::size_t logical_x = 0; logical_x < output_buf->device()->num_cols(); logical_x++) {
for (std::size_t logical_y = 0; logical_y < output_buf->device()->num_rows(); logical_y++) {
std::vector<uint32_t> dst_vec;
ReadShard(mesh_device_->mesh_command_queue(), dst_vec, output_buf, Coordinate(logical_y, logical_x));
ReadShard(
mesh_device_->mesh_command_queue(), dst_vec, output_buf, MeshCoordinate(logical_y, logical_x));
EXPECT_EQ(dst_vec, src_vec);
}
}
Expand Down
4 changes: 2 additions & 2 deletions tests/tt_metal/distributed/test_mesh_workload.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,7 @@ TEST_F(MeshWorkloadTestSuite, EltwiseBinaryMeshWorkload) {
mesh_device_->mesh_command_queue(),
dst_vec,
output_bufs[col_idx * worker_grid_size.y + row_idx],
Coordinate(logical_y, logical_x));
MeshCoordinate(logical_y, logical_x));
if (logical_y == 0) {
for (int i = 0; i < dst_vec.size(); i++) {
EXPECT_EQ(dst_vec[i].to_float(), 5);
Expand Down Expand Up @@ -687,7 +687,7 @@ TEST_F(MeshWorkloadTestSuite, MeshWorkloadSanity) {
mesh_device_->mesh_command_queue(),
dst_vec,
output_buffers[col_idx * worker_grid_size.y + row_idx],
Coordinate(logical_y, logical_x));
MeshCoordinate(logical_y, logical_x));
for (int i = 0; i < dst_vec.size(); i++) {
float ref_val = std::pow(2, (iter % 2) + 1);
if (i >= 512) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "eth_l1_address_map.h"

using tt::tt_metal::IDevice;
using tt::tt_metal::distributed::MeshCoordinate;
using tt::tt_metal::distributed::MeshDevice;
using tt::tt_metal::distributed::MeshDeviceConfig;
using tt::tt_metal::distributed::MeshDeviceView;
Expand Down Expand Up @@ -453,44 +454,44 @@ int main(int argc, char** argv) {
switch (n_hops) {
case 2:
return std::vector<IDevice*>{
view.get_device(0, 0),
view.get_device(0, 1),
view.get_device(MeshCoordinate(0, 0)),
view.get_device(MeshCoordinate(0, 1)),
};

case 4:
return std::vector<IDevice*>{
view.get_device(1, 1),
view.get_device(0, 1),
view.get_device(0, 2),
view.get_device(1, 2),
view.get_device(MeshCoordinate(1, 1)),
view.get_device(MeshCoordinate(0, 1)),
view.get_device(MeshCoordinate(0, 2)),
view.get_device(MeshCoordinate(1, 2)),
};

case 8:
return std::vector<IDevice*>{
view.get_device(1, 1),
view.get_device(1, 0),
view.get_device(0, 0),
view.get_device(0, 1),
view.get_device(0, 2),
view.get_device(0, 3),
view.get_device(1, 3),
view.get_device(1, 2),
view.get_device(MeshCoordinate(1, 1)),
view.get_device(MeshCoordinate(1, 0)),
view.get_device(MeshCoordinate(0, 0)),
view.get_device(MeshCoordinate(0, 1)),
view.get_device(MeshCoordinate(0, 2)),
view.get_device(MeshCoordinate(0, 3)),
view.get_device(MeshCoordinate(1, 3)),
view.get_device(MeshCoordinate(1, 2)),
};

case 12: // Does an extra loop through the inner ring
return std::vector<IDevice*>{
view.get_device(1, 1),
view.get_device(1, 0),
view.get_device(0, 0),
view.get_device(0, 1),
view.get_device(0, 2),
view.get_device(1, 2),
view.get_device(1, 1),
view.get_device(0, 1),
view.get_device(0, 2),
view.get_device(0, 3),
view.get_device(1, 3),
view.get_device(1, 2),
view.get_device(MeshCoordinate(1, 1)),
view.get_device(MeshCoordinate(1, 0)),
view.get_device(MeshCoordinate(0, 0)),
view.get_device(MeshCoordinate(0, 1)),
view.get_device(MeshCoordinate(0, 2)),
view.get_device(MeshCoordinate(1, 2)),
view.get_device(MeshCoordinate(1, 1)),
view.get_device(MeshCoordinate(0, 1)),
view.get_device(MeshCoordinate(0, 2)),
view.get_device(MeshCoordinate(0, 3)),
view.get_device(MeshCoordinate(1, 3)),
view.get_device(MeshCoordinate(1, 2)),
};

default: TT_THROW("Unsupported hop_count"); return std::vector<IDevice*>{};
Expand Down
22 changes: 22 additions & 0 deletions tests/ttnn/distributed/test_distributed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,15 @@

#include <gtest/gtest.h>

#include <tt-metalium/mesh_coord.hpp>

#include <ttnn/core.hpp>
#include <ttnn/distributed/api.hpp>

namespace ttnn::distributed::test {

using ::tt::tt_metal::distributed::MeshContainer;

class DistributedTest : public ::testing::Test {
protected:
void SetUp() override {}
Expand Down Expand Up @@ -46,4 +50,22 @@ TEST_F(DistributedTest, TestNumDramChannels) {
EXPECT_EQ(mesh->num_dram_channels(), 96); // 8 devices * 12 channels
}

// Arranges the devices of a {2, 4} mesh into 1D, 2D and 3D containers, wraps each in a
// MeshDeviceView, and expects is_mesh_2d() to be true only for the 2D arrangement.
TEST_F(DistributedTest, ViewIs2D) {
    auto mesh_device = ttnn::distributed::open_mesh_device(
        {2, 4}, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, tt::tt_metal::DispatchCoreType::WORKER);
    std::vector<IDevice*> all_devices = mesh_device->get_devices();

    // Shape (8): a single dimension.
    MeshContainer<IDevice*> line_container(SimpleMeshShape(8), all_devices);
    MeshDeviceView line_view(line_container);
    EXPECT_FALSE(line_view.is_mesh_2d());

    // Shape (2, 4): two dimensions.
    MeshContainer<IDevice*> grid_container(SimpleMeshShape(2, 4), all_devices);
    MeshDeviceView grid_view(grid_container);
    EXPECT_TRUE(grid_view.is_mesh_2d());

    // Shape (2, 2, 2): three dimensions.
    MeshContainer<IDevice*> cube_container(SimpleMeshShape(2, 2, 2), all_devices);
    MeshDeviceView cube_view(cube_container);
    EXPECT_FALSE(cube_view.is_mesh_2d());
}

} // namespace ttnn::distributed::test
Loading

0 comments on commit f4719c7

Please sign in to comment.