#17215: Add write/read APIs for TTNN tensors allocated on mesh buffer #17513

Merged: 6 commits, Feb 7, 2025
tests/ttnn/unit_tests/gtests/tensor/test_mesh_tensor.cpp (99 additions, 0 deletions)
@@ -5,14 +5,21 @@
#include <gtest/gtest.h>
#include <gmock/gmock.h>

#include "ttnn/distributed/api.hpp"
#include "ttnn/distributed/distributed_tensor_config.hpp"
#include "ttnn/tensor/tensor.hpp"
#include "ttnn/tensor/tensor_impl.hpp"
#include "ttnn/tensor/tensor_impl_wrapper.hpp"
#include "ttnn_test_fixtures.hpp"
#include <ttnn/distributed/types.hpp>
#include <ttnn/distributed/distributed_tensor.hpp>

namespace ttnn::distributed::test {
namespace {

using ::testing::FloatEq;
using ::testing::Pointwise;

using MeshTensorTest = T3kMultiDeviceFixture;

TEST_F(MeshTensorTest, Lifecycle) {
@@ -43,5 +50,97 @@ TEST_F(MeshTensorTest, Lifecycle) {
EXPECT_FALSE(input_tensor.is_allocated());
}

using MeshTensorDeviceTest = T3kMultiDeviceFixture;

TEST_F(MeshTensorDeviceTest, ToHostNonMeshTensor) {
const ttnn::Shape shape{1, 1, 32, 32};
const TensorSpec tensor_spec =
TensorSpec(shape, TensorLayout(DataType::FLOAT32, Layout::ROW_MAJOR, MemoryConfig{}));
Tensor input_host_tensor = Tensor::from_vector(std::vector<float>(shape.volume()), tensor_spec);
EXPECT_TRUE(input_host_tensor.storage_type() == StorageType::OWNED);

EXPECT_ANY_THROW(tensor_impl::to_host_mesh_tensor_wrapper(input_host_tensor));
}

TEST_F(MeshTensorDeviceTest, ReplicateHostTensor) {
const ttnn::Shape shape{1, 1, 32, 32};
const TensorSpec tensor_spec =
TensorSpec(shape, TensorLayout(DataType::FLOAT32, Layout::ROW_MAJOR, MemoryConfig{}));

std::vector<float> host_data(shape.volume());
std::iota(host_data.begin(), host_data.end(), 0);

// Prepare a host tensor to offload to the device.
Tensor input_host_tensor = Tensor::from_vector(host_data, tensor_spec);
EXPECT_TRUE(input_host_tensor.storage_type() == StorageType::OWNED);
EXPECT_EQ(input_host_tensor.get_tensor_spec().logical_shape(), shape);

// Write host tensor to device.
Tensor device_tensor =
tensor_impl::to_device_mesh_tensor_wrapper(input_host_tensor, mesh_device_.get(), MemoryConfig{});
EXPECT_TRUE(distributed::is_mesh_buffer_tensor(device_tensor));
EXPECT_EQ(device_tensor.get_tensor_spec().logical_shape(), shape);

auto* multi_device_storage = std::get_if<tt::tt_metal::MultiDeviceStorage>(&device_tensor.get_storage());
ASSERT_NE(multi_device_storage, nullptr);
for (const auto& [_, shard_spec] : multi_device_storage->specs) {
EXPECT_EQ(shard_spec.logical_shape(), shape);
}
EXPECT_TRUE(std::holds_alternative<tt::tt_metal::ReplicateTensor>(multi_device_storage->strategy));

// Read the tensor back, and compare it with input data.
Tensor output_host_tensor = tensor_impl::to_host_mesh_tensor_wrapper(device_tensor);
EXPECT_TRUE(output_host_tensor.storage_type() == StorageType::MULTI_DEVICE_HOST);
EXPECT_EQ(output_host_tensor.get_tensor_spec().logical_shape(), shape);

for (const auto& tensor : get_tensors_from_multi_device_storage(output_host_tensor)) {
EXPECT_EQ(tensor.get_tensor_spec().logical_shape(), shape);
EXPECT_THAT(tensor.to_vector<float>(), Pointwise(FloatEq(), host_data));
}
}

TEST_F(MeshTensorDeviceTest, WriteMultiDeviceHostTensor) {
const int num_devices = mesh_device_->num_devices();
ASSERT_EQ(num_devices, 8);
// Test uneven shard shapes.
const ttnn::Shape shape{1, 9, 32, 32};
const TensorSpec tensor_spec =
TensorSpec(shape, TensorLayout(DataType::FLOAT32, Layout::ROW_MAJOR, MemoryConfig{}));

std::vector<float> host_data(shape.volume());
std::iota(host_data.begin(), host_data.end(), 0);

// Prepare a multi-device host tensor to offload to the device.
Tensor input_host_tensor_sharded = distribute_tensor(
Tensor::from_vector(host_data, tensor_spec), *shard_tensor_to_mesh_mapper(*mesh_device_, /*dim=*/1));
EXPECT_TRUE(input_host_tensor_sharded.storage_type() == StorageType::MULTI_DEVICE_HOST);

auto* multi_device_host_storage =
std::get_if<tt::tt_metal::MultiDeviceHostStorage>(&input_host_tensor_sharded.get_storage());
ASSERT_NE(multi_device_host_storage, nullptr);
const auto* strategy = std::get_if<tt::tt_metal::ShardTensor>(&multi_device_host_storage->strategy);
ASSERT_NE(strategy, nullptr);
EXPECT_EQ(strategy->shard_dimension, 1);

// Write host tensor to device.
Tensor device_tensor =
tensor_impl::to_device_mesh_tensor_wrapper(input_host_tensor_sharded, mesh_device_.get(), MemoryConfig{});
EXPECT_TRUE(distributed::is_mesh_buffer_tensor(device_tensor));

auto* multi_device_storage = std::get_if<tt::tt_metal::MultiDeviceStorage>(&device_tensor.get_storage());
ASSERT_NE(multi_device_storage, nullptr);
const auto* device_tensor_strategy = std::get_if<tt::tt_metal::ShardTensor>(&multi_device_storage->strategy);
ASSERT_NE(device_tensor_strategy, nullptr);
EXPECT_EQ(device_tensor_strategy->shard_dimension, 1);

// Read the tensor back, and compare it with input data.
Tensor output_host_tensor = aggregate_tensor(
tensor_impl::to_host_mesh_tensor_wrapper(device_tensor), *concat_mesh_to_tensor_composer(/*dim=*/1));
EXPECT_TRUE(output_host_tensor.storage_type() == StorageType::OWNED);
EXPECT_EQ(output_host_tensor.get_tensor_spec().logical_shape(), shape);

EXPECT_THAT(output_host_tensor.to_vector<float>(), Pointwise(FloatEq(), host_data));
}

} // namespace
} // namespace ttnn::distributed::test
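
For orientation, here is a condensed sketch of the shard/aggregate round trip the last test exercises. It reuses only names from the test itself, with `mesh_device` standing in for the fixture's `mesh_device_`:

```cpp
// Condensed sketch of the shard -> device -> host round trip from the test
// above; `mesh_device`, `host_data`, and `tensor_spec` stand in for the
// fixture members and locals used there.
Tensor host_tensor = Tensor::from_vector(host_data, tensor_spec);
Tensor sharded = distribute_tensor(
    host_tensor, *shard_tensor_to_mesh_mapper(*mesh_device, /*dim=*/1));
Tensor device_tensor =
    tensor_impl::to_device_mesh_tensor_wrapper(sharded, mesh_device.get(), MemoryConfig{});
Tensor round_trip = aggregate_tensor(
    tensor_impl::to_host_mesh_tensor_wrapper(device_tensor),
    *concat_mesh_to_tensor_composer(/*dim=*/1));
// round_trip.to_vector<float>() is expected to match host_data element-wise.
```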
tt_metal/api/tt-metalium/distributed.hpp (12 additions, 2 deletions)
@@ -31,7 +31,12 @@ void WriteShard(
std::vector<DType>& src,
const Coordinate& coord,
bool blocking = false) {
mesh_cq.enqueue_write_shard(mesh_buffer, src.data(), coord, blocking);
std::vector<MeshCommandQueue::ShardDataTransfer> shard_data_transfers = {{
.shard_coord = coord,
.host_data = src.data(),
.region = std::nullopt,
}};
mesh_cq.enqueue_write_shards(mesh_buffer, shard_data_transfers, blocking);
}

template <typename DType>
@@ -43,7 +48,12 @@ void ReadShard(
bool blocking = true) {
auto shard = mesh_buffer->get_device_buffer(coord);
dst.resize(shard->page_size() * shard->num_pages() / sizeof(DType));
mesh_cq.enqueue_read_shard(dst.data(), mesh_buffer, coord, blocking);
std::vector<MeshCommandQueue::ShardDataTransfer> shard_data_transfers = {{
.shard_coord = coord,
.host_data = dst.data(),
.region = std::nullopt,
}};
mesh_cq.enqueue_read_shards(shard_data_transfers, mesh_buffer, blocking);
}

template <typename DType>
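Usage note: the helpers keep their original call shape while routing through the batched `enqueue_*_shards` path underneath. A minimal sketch, assuming an existing `MeshCommandQueue mesh_cq` and `std::shared_ptr<MeshBuffer> mesh_buffer`; the element type and fill value are illustrative:

```cpp
// Minimal round-trip sketch through the updated helpers; `mesh_cq` and
// `mesh_buffer` are assumed to exist and the uint32_t payload is illustrative.
const Coordinate coord(0, 0);
auto shard = mesh_buffer->get_device_buffer(coord);
std::vector<uint32_t> src(shard->page_size() * shard->num_pages() / sizeof(uint32_t), 7u);
WriteShard(mesh_cq, mesh_buffer, src, coord);  // non-blocking by default
std::vector<uint32_t> dst;                     // ReadShard resizes dst to fit the shard
ReadShard(mesh_cq, dst, mesh_buffer, coord);   // blocking by default
```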
tt_metal/api/tt-metalium/mesh_command_queue.hpp (33 additions, 9 deletions)
@@ -4,6 +4,8 @@

#pragma once

#include <optional>
#include "buffer.hpp"
#include "command_queue_interface.hpp"
#include "mesh_buffer.hpp"
#include "mesh_device.hpp"
@@ -21,37 +23,59 @@ class MeshCommandQueue {
void populate_dispatch_core_type();
CoreCoord virtual_program_dispatch_core() const;
CoreType dispatch_core_type() const;

// Helper functions for reading and writing individual shards
void write_shard_to_device(
std::shared_ptr<Buffer>& shard_view, const void* src, tt::stl::Span<const SubDeviceId> sub_device_ids = {});
std::shared_ptr<Buffer>& shard_view,
const void* src,
const BufferRegion& region,
tt::stl::Span<const SubDeviceId> sub_device_ids = {});
void read_shard_from_device(
std::shared_ptr<Buffer>& shard_view, void* dst, tt::stl::Span<const SubDeviceId> sub_device_ids = {});
std::shared_ptr<Buffer>& shard_view,
void* dst,
const BufferRegion& region,
tt::stl::Span<const SubDeviceId> sub_device_ids = {});

// Helper functions for reading and writing entire sharded MeshBuffers
void write_sharded_buffer(const MeshBuffer& buffer, const void* src);
void read_sharded_buffer(MeshBuffer& buffer, void* dst);
std::array<tt::tt_metal::WorkerConfigBufferMgr, DispatchSettings::DISPATCH_MESSAGE_ENTRIES> config_buffer_mgr_;
std::array<uint32_t, DispatchSettings::DISPATCH_MESSAGE_ENTRIES> expected_num_workers_completed_;
MeshDevice* mesh_device_;
uint32_t id_;
MeshDevice* mesh_device_ = nullptr;
uint32_t id_ = 0;
CoreCoord dispatch_core_;
CoreType dispatch_core_type_;
CoreType dispatch_core_type_ = CoreType::WORKER;

public:
MeshCommandQueue(MeshDevice* mesh_device, uint32_t id);
MeshDevice* device() const { return mesh_device_; }
uint32_t id() const { return id_; }
WorkerConfigBufferMgr& get_config_buffer_mgr(uint32_t index) { return config_buffer_mgr_[index]; };
void enqueue_mesh_workload(MeshWorkload& mesh_workload, bool blocking);

// Specifies host data to be written to or read from a MeshBuffer shard.
struct ShardDataTransfer {
Coordinate shard_coord;
void* host_data = nullptr;
std::optional<BufferRegion> region;
};

// MeshBuffer Write APIs
void enqueue_write_shard(
std::shared_ptr<MeshBuffer>& mesh_buffer, const void* host_data, const Coordinate& coord, bool blocking);
void enqueue_write_shard_to_sub_grid(
const MeshBuffer& buffer, const void* host_data, const LogicalDeviceRange& device_range, bool blocking);
void enqueue_write_mesh_buffer(const std::shared_ptr<MeshBuffer>& buffer, const void* host_data, bool blocking);
void enqueue_write_shards(
const std::shared_ptr<MeshBuffer>& mesh_buffer,
const std::vector<ShardDataTransfer>& shard_data_transfers,
bool blocking);

// MeshBuffer Read APIs
void enqueue_read_shard(
void* host_data, const std::shared_ptr<MeshBuffer>& mesh_buffer, const Coordinate& coord, bool blocking);
void enqueue_read_mesh_buffer(void* host_data, const std::shared_ptr<MeshBuffer>& buffer, bool blocking);
void enqueue_read_shards(
const std::vector<ShardDataTransfer>& shard_data_transfers,
const std::shared_ptr<MeshBuffer>& mesh_buffer,
bool blocking);

void finish();
void reset_worker_state(
bool reset_launch_msg_state,
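To make the new surface concrete, here is a hedged sketch of batched writes and reads through `ShardDataTransfer`; `mesh_cq`, `mesh_buffer`, `num_rows`, `num_cols`, and the per-shard `host_shards` vectors are assumptions for illustration:

```cpp
// Hedged sketch of the batched APIs: one ShardDataTransfer per shard, with
// `region` left empty so each transfer defaults to the full shard.
std::vector<MeshCommandQueue::ShardDataTransfer> transfers;
for (std::size_t row = 0; row < num_rows; ++row) {
    for (std::size_t col = 0; col < num_cols; ++col) {
        transfers.push_back(MeshCommandQueue::ShardDataTransfer{
            .shard_coord = Coordinate(row, col),
            .host_data = host_shards[row * num_cols + col].data(),
            .region = std::nullopt,  // no region => whole shard
        });
    }
}
mesh_cq.enqueue_write_shards(mesh_buffer, transfers, /*blocking=*/false);
mesh_cq.enqueue_read_shards(transfers, mesh_buffer, /*blocking=*/true);
```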
tt_metal/distributed/mesh_command_queue.cpp (57 additions, 27 deletions)
@@ -4,8 +4,10 @@

#include <mesh_command_queue.hpp>
#include <mesh_device.hpp>
#include <optional>
#include <tt-metalium/dispatch_settings.hpp>

#include "buffer.hpp"
#include "tt_metal/distributed/mesh_workload_utils.hpp"
#include "tt_metal/impl/buffers/dispatch.hpp"
#include "tt_metal/impl/program/dispatch.hpp"
@@ -164,24 +166,28 @@ void MeshCommandQueue::finish() {
}

void MeshCommandQueue::write_shard_to_device(
std::shared_ptr<Buffer>& shard_view, const void* src, tt::stl::Span<const SubDeviceId> sub_device_ids) {
std::shared_ptr<Buffer>& shard_view,
const void* src,
const BufferRegion& region,
tt::stl::Span<const SubDeviceId> sub_device_ids) {
auto device = shard_view->device();
BufferRegion region(0, shard_view->size());
sub_device_ids = buffer_dispatch::select_sub_device_ids(mesh_device_, sub_device_ids);
buffer_dispatch::write_to_device_buffer(
src, *shard_view, region, id_, expected_num_workers_completed_, this->dispatch_core_type(), sub_device_ids);
}

void MeshCommandQueue::read_shard_from_device(
std::shared_ptr<Buffer>& shard_view, void* dst, tt::stl::Span<const SubDeviceId> sub_device_ids) {
std::shared_ptr<Buffer>& shard_view,
void* dst,
const BufferRegion& region,
tt::stl::Span<const SubDeviceId> sub_device_ids) {
auto device = shard_view->device();
chip_id_t mmio_device_id = tt::Cluster::instance().get_associated_mmio_device(device->id());
uint16_t channel = tt::Cluster::instance().get_assigned_channel_for_device(device->id());
sub_device_ids = buffer_dispatch::select_sub_device_ids(mesh_device_, sub_device_ids);

bool exit_condition = false;

BufferRegion region(0, shard_view->size());
if (is_sharded(shard_view->buffer_layout())) {
auto dispatch_params = buffer_dispatch::initialize_sharded_buf_read_dispatch_params(
*shard_view, id_, expected_num_workers_completed_, region);
@@ -211,23 +217,6 @@
}
}

void MeshCommandQueue::enqueue_write_shard(
std::shared_ptr<MeshBuffer>& mesh_buffer, const void* host_data, const Coordinate& coord, bool blocking) {
auto shard = mesh_buffer->get_device_buffer(coord);
this->write_shard_to_device(shard, host_data);

if (blocking) {
this->finish();
}
}

void MeshCommandQueue::enqueue_read_shard(
void* host_data, const std::shared_ptr<MeshBuffer>& mesh_buffer, const Coordinate& coord, bool blocking) {
TT_FATAL(blocking, "Only blocking reads are currently supported from MeshBuffer shards.");
auto shard = mesh_buffer->get_device_buffer(coord);
this->read_shard_from_device(shard, host_data);
}

void MeshCommandQueue::write_sharded_buffer(const MeshBuffer& buffer, const void* src) {
auto global_buffer_shape = buffer.global_shard_spec().global_buffer_shape;
auto global_buffer_size = buffer.global_shard_spec().global_size;
@@ -269,26 +258,30 @@ void MeshCommandQueue::write_sharded_buffer(const MeshBuffer& buffer, const void
replicated_device_y++) {
auto device_shard_view =
buffer.get_device_buffer(Coordinate(replicated_device_y, replicated_device_x));
this->write_shard_to_device(device_shard_view, shard_data.data());
const BufferRegion region(0, device_shard_view->size());
this->write_shard_to_device(device_shard_view, shard_data.data(), region);
}
}
} else if (height_replicated or width_replicated) {
if (buffer.global_shard_spec().shard_orientation == ShardOrientation::ROW_MAJOR) {
for (auto replicated_device_y = 0; replicated_device_y < num_devices_y; replicated_device_y++) {
auto device_shard_view = buffer.get_device_buffer(Coordinate(replicated_device_y, device_x));
this->write_shard_to_device(device_shard_view, shard_data.data());
const BufferRegion region(0, device_shard_view->size());
this->write_shard_to_device(device_shard_view, shard_data.data(), region);
}
device_x++;
} else {
for (auto replicated_device_x = 0; replicated_device_x < num_devices_x; replicated_device_x++) {
auto device_shard_view = buffer.get_device_buffer(Coordinate(device_y, replicated_device_x));
this->write_shard_to_device(device_shard_view, shard_data.data());
const BufferRegion region(0, device_shard_view->size());
this->write_shard_to_device(device_shard_view, shard_data.data(), region);
}
device_y++;
}
} else {
auto device_shard_view = buffer.get_device_buffer(Coordinate(device_y, device_x));
this->write_shard_to_device(device_shard_view, shard_data.data());
const BufferRegion region(0, device_shard_view->size());
this->write_shard_to_device(device_shard_view, shard_data.data(), region);
if (buffer.global_shard_spec().shard_orientation == ShardOrientation::ROW_MAJOR) {
if (++device_x == num_devices_x) {
device_x = 0;
@@ -328,7 +321,9 @@ void MeshCommandQueue::read_sharded_buffer(MeshBuffer& buffer, void* dst) {
for (std::size_t shard_y = 0; shard_y < num_shards_y; shard_y++) {
for (std::size_t shard_x = 0; shard_x < num_shards_x; shard_x++) {
auto device_shard_view = buffer.get_device_buffer(Coordinate(device_y, device_x));
this->read_shard_from_device(device_shard_view, shard_data.data());
const BufferRegion region(0, device_shard_view->size());
this->read_shard_from_device(device_shard_view, shard_data.data(), region);

uint32_t write_offset = shard_x * single_write_size + shard_y * stride_size_bytes * shard_shape.height();
uint32_t size_to_write = total_write_size_per_shard;
uint32_t local_offset = 0;
@@ -363,7 +358,8 @@ void MeshCommandQueue::enqueue_write_shard_to_sub_grid(
for (std::size_t logical_y = device_range.start_coord.y; logical_y < device_range.end_coord.y + 1;
logical_y++) {
auto device_shard_view = buffer.get_device_buffer(Coordinate(logical_y, logical_x));
this->write_shard_to_device(device_shard_view, host_data);
const BufferRegion region(0, device_shard_view->size());
this->write_shard_to_device(device_shard_view, host_data, region);
}
}
} else {
@@ -387,6 +383,40 @@
this->read_sharded_buffer(*buffer, host_data);
}

void MeshCommandQueue::enqueue_write_shards(
const std::shared_ptr<MeshBuffer>& buffer,
const std::vector<ShardDataTransfer>& shard_data_transfers,
bool blocking) {
// TODO: #17215 - this API is used by TTNN, which currently implements a rich ND sharding API for multi-device tensors.
// In the long run, the multi-device sharding API in Metal will change, and this will most likely be replaced.
for (const auto& shard_data_transfer : shard_data_transfers) {
auto device_shard_view = buffer->get_device_buffer(shard_data_transfer.shard_coord);
write_shard_to_device(
device_shard_view,
shard_data_transfer.host_data,
shard_data_transfer.region.value_or(BufferRegion(0, device_shard_view->size())));
}
if (blocking) {
this->finish();
}
}

void MeshCommandQueue::enqueue_read_shards(
const std::vector<ShardDataTransfer>& shard_data_transfers,
const std::shared_ptr<MeshBuffer>& buffer,
bool blocking) {
// TODO: #17215 - this API is used by TTNN, which currently implements a rich ND sharding API for multi-device tensors.
// In the long run, the multi-device sharding API in Metal will change, and this will most likely be replaced.
const auto [num_rows, num_cols] = buffer->device()->shape();
for (const auto& shard_data_transfer : shard_data_transfers) {
auto device_shard_view = buffer->get_device_buffer(shard_data_transfer.shard_coord);
read_shard_from_device(
device_shard_view,
shard_data_transfer.host_data,
shard_data_transfer.region.value_or(BufferRegion(0, device_shard_view->size())));
}
}
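
When a caller needs only part of a shard, the optional region flows straight into these read/write paths. A sketch, with the offset and size made up for illustration:

```cpp
// Sketch of a partial-shard read via the optional BufferRegion; the 4 KiB
// size and zero offset are made-up values, not API requirements.
std::vector<std::byte> scratch(4096);
MeshCommandQueue::ShardDataTransfer partial{
    .shard_coord = Coordinate(0, 0),
    .host_data = scratch.data(),
    .region = BufferRegion(/*offset=*/0, /*size=*/scratch.size()),
};
mesh_cq.enqueue_read_shards({partial}, mesh_buffer, /*blocking=*/true);
```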

void MeshCommandQueue::reset_worker_state(
bool reset_launch_msg_state, uint32_t num_sub_devices, const vector_memcpy_aligned<uint32_t>& go_signal_noc_data) {
for (auto device : mesh_device_->get_devices()) {