diff --git a/tests/ttnn/unit_tests/gtests/tensor/test_mesh_tensor.cpp b/tests/ttnn/unit_tests/gtests/tensor/test_mesh_tensor.cpp
index e9243c91a17..4e667b33727 100644
--- a/tests/ttnn/unit_tests/gtests/tensor/test_mesh_tensor.cpp
+++ b/tests/ttnn/unit_tests/gtests/tensor/test_mesh_tensor.cpp
@@ -5,7 +5,11 @@
 #include
 #include
+#include "ttnn/distributed/api.hpp"
+#include "ttnn/distributed/distributed_tensor_config.hpp"
 #include "ttnn/tensor/tensor.hpp"
+#include "ttnn/tensor/tensor_impl.hpp"
+#include "ttnn/tensor/tensor_impl_wrapper.hpp"
 #include "ttnn_test_fixtures.hpp"
 #include
 #include
@@ -13,6 +17,9 @@ namespace ttnn::distributed::test {
 namespace {
 
+using ::testing::FloatEq;
+using ::testing::Pointwise;
+
 using MeshTensorTest = T3kMultiDeviceFixture;
 
 TEST_F(MeshTensorTest, Lifecycle) {
@@ -43,5 +50,97 @@ TEST_F(MeshTensorTest, Lifecycle) {
     EXPECT_FALSE(input_tensor.is_allocated());
 }
 
+using MeshTensorDeviceTest = T3kMultiDeviceFixture;
+
+TEST_F(MeshTensorDeviceTest, ToHostNonMeshTensor) {
+    const ttnn::Shape shape{1, 1, 32, 32};
+    const TensorSpec tensor_spec =
+        TensorSpec(shape, TensorLayout(DataType::FLOAT32, Layout::ROW_MAJOR, MemoryConfig{}));
+    Tensor input_host_tensor = Tensor::from_vector(std::vector(shape.volume()), tensor_spec);
+    EXPECT_TRUE(input_host_tensor.storage_type() == StorageType::OWNED);
+
+    EXPECT_ANY_THROW(tensor_impl::to_host_mesh_tensor_wrapper(input_host_tensor));
+}
+
+TEST_F(MeshTensorDeviceTest, ReplicateHostTensor) {
+    const ttnn::Shape shape{1, 1, 32, 32};
+    const TensorSpec tensor_spec =
+        TensorSpec(shape, TensorLayout(DataType::FLOAT32, Layout::ROW_MAJOR, MemoryConfig{}));
+
+    std::vector host_data(shape.volume());
+    std::iota(host_data.begin(), host_data.end(), 0);
+
+    // Prepare host tensor to offload on device.
+    Tensor input_host_tensor = Tensor::from_vector(host_data, tensor_spec);
+    EXPECT_TRUE(input_host_tensor.storage_type() == StorageType::OWNED);
+    EXPECT_EQ(input_host_tensor.get_tensor_spec().logical_shape(), shape);
+
+    // Write host tensor to device.
+    Tensor device_tensor =
+        tensor_impl::to_device_mesh_tensor_wrapper(input_host_tensor, mesh_device_.get(), MemoryConfig{});
+    EXPECT_TRUE(distributed::is_mesh_buffer_tensor(device_tensor));
+    EXPECT_EQ(device_tensor.get_tensor_spec().logical_shape(), shape);
+
+    auto* multi_device_storage = std::get_if(&device_tensor.get_storage());
+    ASSERT_NE(multi_device_storage, nullptr);
+    for (const auto& [_, shard_spec] : multi_device_storage->specs) {
+        EXPECT_EQ(shard_spec.logical_shape(), shape);
+    }
+    EXPECT_TRUE(std::holds_alternative(multi_device_storage->strategy));
+
+    // Read the tensor back, and compare it with input data.
+    Tensor output_host_tensor = tensor_impl::to_host_mesh_tensor_wrapper(device_tensor);
+    EXPECT_TRUE(output_host_tensor.storage_type() == StorageType::MULTI_DEVICE_HOST);
+    EXPECT_EQ(output_host_tensor.get_tensor_spec().logical_shape(), shape);
+
+    for (const auto& tensor : get_tensors_from_multi_device_storage(output_host_tensor)) {
+        EXPECT_EQ(tensor.get_tensor_spec().logical_shape(), shape);
+        EXPECT_THAT(tensor.to_vector(), Pointwise(FloatEq(), host_data));
+    }
+}
+
+TEST_F(MeshTensorDeviceTest, WriteMultiDeviceHostTensor) {
+    const int num_devices = mesh_device_->num_devices();
+    ASSERT_EQ(num_devices, 8);
+    // Test uneven shard shapes.
+    const ttnn::Shape shape{1, 9, 32, 32};
+    const TensorSpec tensor_spec =
+        TensorSpec(shape, TensorLayout(DataType::FLOAT32, Layout::ROW_MAJOR, MemoryConfig{}));
+
+    std::vector host_data(shape.volume());
+    std::iota(host_data.begin(), host_data.end(), 0);
+
+    // Prepare multi-device host tensor to offload on device.
+    Tensor input_host_tensor_sharded = distribute_tensor(
+        Tensor::from_vector(host_data, tensor_spec), *shard_tensor_to_mesh_mapper(*mesh_device_, /*dim=*/1));
+    EXPECT_TRUE(input_host_tensor_sharded.storage_type() == StorageType::MULTI_DEVICE_HOST);
+
+    auto* multi_device_host_storage =
+        std::get_if(&input_host_tensor_sharded.get_storage());
+    ASSERT_NE(multi_device_host_storage, nullptr);
+    const auto* strategy = std::get_if(&multi_device_host_storage->strategy);
+    ASSERT_NE(strategy, nullptr);
+    EXPECT_EQ(strategy->shard_dimension, 1);
+
+    // Write host tensor to device.
+    Tensor device_tensor =
+        tensor_impl::to_device_mesh_tensor_wrapper(input_host_tensor_sharded, mesh_device_.get(), MemoryConfig{});
+    EXPECT_TRUE(distributed::is_mesh_buffer_tensor(device_tensor));
+
+    auto* multi_device_storage = std::get_if(&device_tensor.get_storage());
+    ASSERT_NE(multi_device_storage, nullptr);
+    const auto* device_tensor_strategy = std::get_if(&multi_device_storage->strategy);
+    ASSERT_NE(device_tensor_strategy, nullptr);
+    EXPECT_EQ(device_tensor_strategy->shard_dimension, 1);
+
+    // Read the tensor back, and compare it with input data.
+    Tensor output_host_tensor = aggregate_tensor(
+        tensor_impl::to_host_mesh_tensor_wrapper(device_tensor), *concat_mesh_to_tensor_composer(/*dim=*/1));
+    EXPECT_TRUE(output_host_tensor.storage_type() == StorageType::OWNED);
+    EXPECT_EQ(output_host_tensor.get_tensor_spec().logical_shape(), shape);
+
+    EXPECT_THAT(output_host_tensor.to_vector(), Pointwise(FloatEq(), host_data));
+}
+
 }  // namespace
 }  // namespace ttnn::distributed::test
diff --git a/tt_metal/api/tt-metalium/distributed.hpp b/tt_metal/api/tt-metalium/distributed.hpp
index a94cbaa9ecc..96b3a23ed10 100644
--- a/tt_metal/api/tt-metalium/distributed.hpp
+++ b/tt_metal/api/tt-metalium/distributed.hpp
@@ -31,7 +31,12 @@ void WriteShard(
     std::vector& src,
     const Coordinate& coord,
     bool blocking = false) {
-    mesh_cq.enqueue_write_shard(mesh_buffer, src.data(), coord, blocking);
+    std::vector shard_data_transfers = {{
+        .shard_coord = coord,
+        .host_data = src.data(),
+        .region = std::nullopt,
+    }};
+    mesh_cq.enqueue_write_shards(mesh_buffer, shard_data_transfers, blocking);
 }
 
 template
@@ -43,7 +48,12 @@ void ReadShard(
     bool blocking = true) {
     auto shard = mesh_buffer->get_device_buffer(coord);
     dst.resize(shard->page_size() * shard->num_pages() / sizeof(DType));
-    mesh_cq.enqueue_read_shard(dst.data(), mesh_buffer, coord, blocking);
+    std::vector shard_data_transfers = {{
+        .shard_coord = coord,
+        .host_data = dst.data(),
+        .region = std::nullopt,
+    }};
+    mesh_cq.enqueue_read_shards(shard_data_transfers, mesh_buffer, blocking);
 }
 
 template
diff --git a/tt_metal/api/tt-metalium/mesh_command_queue.hpp b/tt_metal/api/tt-metalium/mesh_command_queue.hpp
index 38d13891095..61263207b9c 100644
--- a/tt_metal/api/tt-metalium/mesh_command_queue.hpp
+++ b/tt_metal/api/tt-metalium/mesh_command_queue.hpp
@@ -4,6 +4,8 @@
 
 #pragma once
 
+#include
+#include "buffer.hpp"
 #include "command_queue_interface.hpp"
 #include "mesh_buffer.hpp"
 #include "mesh_device.hpp"
@@ -21,20 +23,28 @@ class MeshCommandQueue {
     void populate_dispatch_core_type();
     CoreCoord virtual_program_dispatch_core() const;
     CoreType dispatch_core_type() const;
+    // Helper functions for reading and writing individual shards
     void write_shard_to_device(
-        std::shared_ptr& shard_view, const void* src, tt::stl::Span sub_device_ids = {});
+        std::shared_ptr& shard_view,
+        const void* src,
+        const BufferRegion& region,
+        tt::stl::Span sub_device_ids = {});
     void read_shard_from_device(
-        std::shared_ptr& shard_view, void* dst, tt::stl::Span sub_device_ids = {});
+        std::shared_ptr& shard_view,
+        void* dst,
+        const BufferRegion& region,
+        tt::stl::Span sub_device_ids = {});
+    // Helper functions for reading and writing entire sharded MeshBuffers
     void write_sharded_buffer(const MeshBuffer& buffer, const void* src);
     void read_sharded_buffer(MeshBuffer& buffer, void* dst);
     std::array config_buffer_mgr_;
    std::array expected_num_workers_completed_;
-    MeshDevice* mesh_device_;
-    uint32_t id_;
+    MeshDevice* mesh_device_ = nullptr;
+    uint32_t id_ = 0;
     CoreCoord dispatch_core_;
-    CoreType dispatch_core_type_;
+    CoreType dispatch_core_type_ = CoreType::WORKER;
 
 public:
     MeshCommandQueue(MeshDevice* mesh_device, uint32_t id);
@@ -42,16 +52,30 @@ class MeshCommandQueue {
     uint32_t id() const { return id_; }
     WorkerConfigBufferMgr& get_config_buffer_mgr(uint32_t index) { return config_buffer_mgr_[index]; };
     void enqueue_mesh_workload(MeshWorkload& mesh_workload, bool blocking);
+
+    // Specifies host data to be written to or read from a MeshBuffer shard.
+    struct ShardDataTransfer {
+        Coordinate shard_coord;
+        void* host_data = nullptr;
+        std::optional region;
+    };
+
     // MeshBuffer Write APIs
-    void enqueue_write_shard(
-        std::shared_ptr& mesh_buffer, const void* host_data, const Coordinate& coord, bool blocking);
     void enqueue_write_shard_to_sub_grid(
         const MeshBuffer& buffer, const void* host_data, const LogicalDeviceRange& device_range, bool blocking);
     void enqueue_write_mesh_buffer(const std::shared_ptr& buffer, const void* host_data, bool blocking);
+    void enqueue_write_shards(
+        const std::shared_ptr& mesh_buffer,
+        const std::vector& shard_data_transfers,
+        bool blocking);
+
     // MeshBuffer Read APIs
-    void enqueue_read_shard(
-        void* host_data, const std::shared_ptr& mesh_buffer, const Coordinate& coord, bool blocking);
     void enqueue_read_mesh_buffer(void* host_data, const std::shared_ptr& buffer, bool blocking);
+    void enqueue_read_shards(
+        const std::vector& shard_data_transfers,
+        const std::shared_ptr& mesh_buffer,
+        bool blocking);
+
     void finish();
     void reset_worker_state(
         bool reset_launch_msg_state,
diff --git a/tt_metal/distributed/mesh_command_queue.cpp b/tt_metal/distributed/mesh_command_queue.cpp
index cb409bdb4eb..89eaaff1b03 100644
--- a/tt_metal/distributed/mesh_command_queue.cpp
+++ b/tt_metal/distributed/mesh_command_queue.cpp
@@ -4,8 +4,10 @@
 
 #include
 #include
+#include
 #include
 
+#include "buffer.hpp"
 #include "tt_metal/distributed/mesh_workload_utils.hpp"
 #include "tt_metal/impl/buffers/dispatch.hpp"
 #include "tt_metal/impl/program/dispatch.hpp"
@@ -164,16 +166,21 @@ void MeshCommandQueue::finish() {
 }
 
 void MeshCommandQueue::write_shard_to_device(
-    std::shared_ptr& shard_view, const void* src, tt::stl::Span sub_device_ids) {
+    std::shared_ptr& shard_view,
+    const void* src,
+    const BufferRegion& region,
+    tt::stl::Span sub_device_ids) {
     auto device = shard_view->device();
-    BufferRegion region(0, shard_view->size());
     sub_device_ids = buffer_dispatch::select_sub_device_ids(mesh_device_, sub_device_ids);
     buffer_dispatch::write_to_device_buffer(
         src, *shard_view, region, id_, expected_num_workers_completed_, this->dispatch_core_type(), sub_device_ids);
 }
 
 void MeshCommandQueue::read_shard_from_device(
-    std::shared_ptr& shard_view, void* dst, tt::stl::Span sub_device_ids) {
+    std::shared_ptr& shard_view,
+    void* dst,
+    const BufferRegion& region,
+    tt::stl::Span sub_device_ids) {
     auto device = shard_view->device();
     chip_id_t mmio_device_id = tt::Cluster::instance().get_associated_mmio_device(device->id());
     uint16_t channel = tt::Cluster::instance().get_assigned_channel_for_device(device->id());
@@ -181,7 +188,6 @@ void MeshCommandQueue::read_shard_from_device(
 
     bool exit_condition = false;
 
-    BufferRegion region(0, shard_view->size());
     if (is_sharded(shard_view->buffer_layout())) {
         auto dispatch_params = buffer_dispatch::initialize_sharded_buf_read_dispatch_params(
             *shard_view, id_, expected_num_workers_completed_, region);
@@ -211,23 +217,6 @@ void MeshCommandQueue::read_shard_from_device(
     }
 }
 
-void MeshCommandQueue::enqueue_write_shard(
-    std::shared_ptr& mesh_buffer, const void* host_data, const Coordinate& coord, bool blocking) {
-    auto shard = mesh_buffer->get_device_buffer(coord);
-    this->write_shard_to_device(shard, host_data);
-
-    if (blocking) {
-        this->finish();
-    }
-}
-
-void MeshCommandQueue::enqueue_read_shard(
-    void* host_data, const std::shared_ptr& mesh_buffer, const Coordinate& coord, bool blocking) {
-    TT_FATAL(blocking, "Only blocking reads are currently supported from MeshBuffer shards.");
-    auto shard = mesh_buffer->get_device_buffer(coord);
-    this->read_shard_from_device(shard, host_data);
-}
-
 void MeshCommandQueue::write_sharded_buffer(const MeshBuffer& buffer, const void* src) {
     auto global_buffer_shape = buffer.global_shard_spec().global_buffer_shape;
     auto global_buffer_size = buffer.global_shard_spec().global_size;
@@ -269,26 +258,30 @@ void MeshCommandQueue::write_sharded_buffer(const MeshBuffer& buffer, const void
                  replicated_device_y++) {
                 auto device_shard_view = buffer.get_device_buffer(Coordinate(replicated_device_y, replicated_device_x));
-                this->write_shard_to_device(device_shard_view, shard_data.data());
+                const BufferRegion region(0, device_shard_view->size());
+                this->write_shard_to_device(device_shard_view, shard_data.data(), region);
             }
         }
     } else if (height_replicated or width_replicated) {
         if (buffer.global_shard_spec().shard_orientation == ShardOrientation::ROW_MAJOR) {
             for (auto replicated_device_y = 0; replicated_device_y < num_devices_y; replicated_device_y++) {
                 auto device_shard_view = buffer.get_device_buffer(Coordinate(replicated_device_y, device_x));
-                this->write_shard_to_device(device_shard_view, shard_data.data());
+                const BufferRegion region(0, device_shard_view->size());
+                this->write_shard_to_device(device_shard_view, shard_data.data(), region);
             }
             device_x++;
         } else {
             for (auto replicated_device_x = 0; replicated_device_x < num_devices_x; replicated_device_x++) {
                 auto device_shard_view = buffer.get_device_buffer(Coordinate(device_y, replicated_device_x));
-                this->write_shard_to_device(device_shard_view, shard_data.data());
+                const BufferRegion region(0, device_shard_view->size());
+                this->write_shard_to_device(device_shard_view, shard_data.data(), region);
             }
             device_y++;
         }
     } else {
         auto device_shard_view = buffer.get_device_buffer(Coordinate(device_y, device_x));
-        this->write_shard_to_device(device_shard_view, shard_data.data());
+        const BufferRegion region(0, device_shard_view->size());
+        this->write_shard_to_device(device_shard_view, shard_data.data(), region);
         if (buffer.global_shard_spec().shard_orientation == ShardOrientation::ROW_MAJOR) {
             if (++device_x == num_devices_x) {
                 device_x = 0;
@@ -328,7 +321,9 @@ void MeshCommandQueue::read_sharded_buffer(MeshBuffer& buffer, void* dst) {
     for (std::size_t shard_y = 0; shard_y < num_shards_y; shard_y++) {
         for (std::size_t shard_x = 0; shard_x < num_shards_x; shard_x++) {
             auto device_shard_view = buffer.get_device_buffer(Coordinate(device_y, device_x));
-            this->read_shard_from_device(device_shard_view, shard_data.data());
+            const BufferRegion region(0, device_shard_view->size());
+            this->read_shard_from_device(device_shard_view, shard_data.data(), region);
+
             uint32_t write_offset = shard_x * single_write_size + shard_y * stride_size_bytes * shard_shape.height();
             uint32_t size_to_write = total_write_size_per_shard;
             uint32_t local_offset = 0;
@@ -363,7 +358,8 @@ void MeshCommandQueue::enqueue_write_shard_to_sub_grid(
             for (std::size_t logical_y = device_range.start_coord.y; logical_y < device_range.end_coord.y + 1;
                  logical_y++) {
                 auto device_shard_view = buffer.get_device_buffer(Coordinate(logical_y, logical_x));
-                this->write_shard_to_device(device_shard_view, host_data);
+                const BufferRegion region(0, device_shard_view->size());
+                this->write_shard_to_device(device_shard_view, host_data, region);
             }
         }
     } else {
@@ -387,6 +383,40 @@ void MeshCommandQueue::enqueue_read_mesh_buffer(
     this->read_sharded_buffer(*buffer, host_data);
 }
 
+void MeshCommandQueue::enqueue_write_shards(
+    const std::shared_ptr& buffer,
+    const std::vector& shard_data_transfers,
+    bool blocking) {
+    // TODO: #17215 - this API is used by TTNN, as it currently implements a rich ND sharding API for multi-devices.
+    // In the long run, the multi-device sharding API in Metal will change, and this will most likely be replaced.
+    for (const auto& shard_data_transfer : shard_data_transfers) {
+        auto device_shard_view = buffer->get_device_buffer(shard_data_transfer.shard_coord);
+        write_shard_to_device(
+            device_shard_view,
+            shard_data_transfer.host_data,
+            shard_data_transfer.region.value_or(BufferRegion(0, device_shard_view->size())));
+    }
+
+    if (blocking) {
+        this->finish();
+    }
+}
+
+void MeshCommandQueue::enqueue_read_shards(
+    const std::vector& shard_data_transfers,
+    const std::shared_ptr& buffer,
+    bool blocking) {
+    // TODO: #17215 - this API is used by TTNN, as it currently implements a rich ND sharding API for multi-devices.
+    // In the long run, the multi-device sharding API in Metal will change, and this will most likely be replaced.
+    const auto [num_rows, num_cols] = buffer->device()->shape();
+    for (const auto& shard_data_transfer : shard_data_transfers) {
+        auto device_shard_view = buffer->get_device_buffer(shard_data_transfer.shard_coord);
+        read_shard_from_device(
+            device_shard_view,
+            shard_data_transfer.host_data,
+            shard_data_transfer.region.value_or(BufferRegion(0, device_shard_view->size())));
+    }
+}
+
 void MeshCommandQueue::reset_worker_state(
     bool reset_launch_msg_state, uint32_t num_sub_devices, const vector_memcpy_aligned& go_signal_noc_data) {
     for (auto device : mesh_device_->get_devices()) {
diff --git a/ttnn/cpp/ttnn/distributed/api.cpp b/ttnn/cpp/ttnn/distributed/api.cpp
index 34f77e9276e..831c1f4cbd5 100644
--- a/ttnn/cpp/ttnn/distributed/api.cpp
+++ b/ttnn/cpp/ttnn/distributed/api.cpp
@@ -116,7 +116,7 @@ Tensor aggregate_as_tensor(
             }
         }
         auto storage =
-            MultiDeviceStorage{config, ordered_device_ids, std::move(device_buffers), specs, /*mesh_buffer_=*/nullptr};
+            MultiDeviceStorage{config, ordered_device_ids, std::move(device_buffers), specs, /*mesh_buffer=*/nullptr};
         return Tensor(std::move(storage), reference_shard.get_tensor_spec());
     }
 }
@@ -211,6 +211,11 @@ bool is_multi_device_tensor(const Tensor& tensor) {
            tensor.storage_type() == StorageType::MULTI_DEVICE_HOST;
 }
 
+bool is_mesh_buffer_tensor(const Tensor& tensor) {
+    auto* multi_device_storage = std::get_if(&tensor.get_storage());
+    return multi_device_storage != nullptr && multi_device_storage->mesh_buffer != nullptr;
+}
+
 std::vector get_tensors_from_multi_device_storage(const Tensor& multi_device_tensor) {
     std::vector tensors;
     if (multi_device_tensor.storage_type() == StorageType::MULTI_DEVICE) {
@@ -263,7 +268,7 @@ Tensor create_multi_device_tensor(
         specs.insert({device_id, tensor.get_tensor_spec()});
     }
     return Tensor{
-        MultiDeviceStorage{strategy, ordered_device_ids, device_buffers, specs, /*mesh_buffer_=*/nullptr},
+        MultiDeviceStorage{strategy, ordered_device_ids, device_buffers, specs, /*mesh_buffer=*/nullptr},
         TensorSpec(
             tensors.at(0).get_logical_shape(),
             TensorLayout::fromPaddedShape(
diff --git a/ttnn/cpp/ttnn/distributed/api.hpp b/ttnn/cpp/ttnn/distributed/api.hpp
index 868aa553d73..da1758a16e2 100644
--- a/ttnn/cpp/ttnn/distributed/api.hpp
+++ b/ttnn/cpp/ttnn/distributed/api.hpp
@@ -45,6 +45,10 @@ Tensor get_device_tensor(const Tensor& multi_device_tensor, const int device_id)
 // Returns true has MultiDeviceHost/MultiDevice Storage
 bool is_multi_device_tensor(const Tensor& tensor);
 
+// Returns true if tensor has MultiDevice storage type and is allocated on a mesh buffer.
+// TODO: remove when the infrastructure uniformly works with mesh buffer backed tensors.
+bool is_mesh_buffer_tensor(const Tensor& tensor);
+
 // Given a multi-device tensor and a device, returns a list of per-device tensors.
 std::vector get_tensors_from_multi_device_storage(const Tensor& multi_device_tensor);
diff --git a/ttnn/cpp/ttnn/tensor/storage.cpp b/ttnn/cpp/ttnn/tensor/storage.cpp
index ad385113ed8..e86cc45a2d5 100644
--- a/ttnn/cpp/ttnn/tensor/storage.cpp
+++ b/ttnn/cpp/ttnn/tensor/storage.cpp
@@ -16,4 +16,31 @@ std::vector> MultiDeviceStorage::get_buffers() const {
     return buf_vec;
 }
 
+MultiDeviceStorage::MultiDeviceStorage(
+    const std::shared_ptr& mesh_buffer_, const TensorSpec& tensor_spec) :
+    strategy(ReplicateTensor{}),
+    mesh_buffer(mesh_buffer_)  //
+{
+    // TODO: #17215 - In the long term, this code won't exist: no interactions will be made with individual Buffers, and
+    // instead the APIs will use MeshBuffer directly. MeshBuffer will also guarantee that all shards have the same
+    // tensor spec.
+    //
+    // For now, this code ensures MeshBuffer backed tensors are compatible with the rest of the ops infra.
+    const auto [num_rows, num_cols] = mesh_buffer->device()->shape();
+
+    ordered_device_ids.reserve(num_rows * num_cols);
+    buffers.reserve(num_rows * num_cols);
+    specs.reserve(num_rows * num_cols);
+
+    for (int row = 0; row < num_rows; ++row) {
+        for (int col = 0; col < num_cols; ++col) {
+            auto buffer = mesh_buffer->get_device_buffer(distributed::Coordinate{row, col});
+            const int device_id = buffer->device()->id();
+            ordered_device_ids.push_back(device_id);
+            buffers.emplace(device_id, std::move(buffer));
+            specs.emplace(device_id, tensor_spec);
+        }
+    }
+}
+
 }  // namespace tt::tt_metal
diff --git a/ttnn/cpp/ttnn/tensor/storage.hpp b/ttnn/cpp/ttnn/tensor/storage.hpp
index 16f3143edae..ebb7ced0226 100644
--- a/ttnn/cpp/ttnn/tensor/storage.hpp
+++ b/ttnn/cpp/ttnn/tensor/storage.hpp
@@ -4,6 +4,7 @@
 
 #pragma once
 
+#include
 #include "ttnn/tensor/types.hpp"
 #include "ttnn/tensor/tensor_spec.hpp"
 
@@ -243,6 +244,7 @@ struct MultiDeviceStorage {
         swap(first.mesh_buffer, second.mesh_buffer);
     }
 
+    // Constructs a multi-device tensor backed by a collection of heterogeneous single-device buffers.
     MultiDeviceStorage(
         DistributedTensorConfig strategy_,
         std::vector ordered_device_ids_,
@@ -255,6 +257,9 @@ struct MultiDeviceStorage {
         specs(std::move(specs_)),
         mesh_buffer(std::move(mesh_buffer_)) {}
 
+    // Constructs a replicated multi-device tensor backed by a mesh buffer.
+    MultiDeviceStorage(const std::shared_ptr& mesh_buffer_, const TensorSpec& tensor_spec);
+
     MultiDeviceStorage(MultiDeviceStorage&& other) { swap(*this, other); }
 
     MultiDeviceStorage(const MultiDeviceStorage& other) {
@@ -378,6 +383,9 @@ struct MultiDeviceStorage {
 
 using Storage = std::variant;
 
+template
+concept OwnedOrBorrowedStorage = std::is_same_v || std::is_same_v;
+
 template
 constexpr void raise_unsupported_storage() {
     static_assert(tt::stl::concepts::always_false_v, "Unsupported Storage");
diff --git a/ttnn/cpp/ttnn/tensor/tensor.cpp b/ttnn/cpp/ttnn/tensor/tensor.cpp
index ed6e2da465c..14379d07267 100644
--- a/ttnn/cpp/ttnn/tensor/tensor.cpp
+++ b/ttnn/cpp/ttnn/tensor/tensor.cpp
@@ -1016,29 +1016,7 @@ Tensor allocate_tensor_on_mesh(const TensorSpec& tensor_spec, distributed::MeshD
     TT_FATAL(
         tt::tt_metal::detail::InMainThread(), "Allocation of a tensor on mesh must be called from the main thread");
     auto mesh_buffer = tensor_impl::allocate_mesh_buffer_on_device(mesh_device, tensor_spec);
-
-    const auto [num_rows, num_cols] = mesh_device->shape();
-    std::vector ordered_device_ids;
-    std::unordered_map> buffers;
-    std::unordered_map specs;
-
-    ordered_device_ids.reserve(num_rows * num_cols);
-    buffers.reserve(num_rows * num_cols);
-    specs.reserve(num_rows * num_cols);
-
-    for (int row = 0; row < num_rows; ++row) {
-        for (int col = 0; col < num_cols; ++col) {
-            auto buffer = mesh_buffer->get_device_buffer(distributed::Coordinate{row, col});
-            const int device_id = buffer->device()->id();
-            ordered_device_ids.push_back(device_id);
-            buffers.emplace(device_id, std::move(buffer));
-            specs.emplace(device_id, tensor_spec);
-        }
-    }
-
-    MultiDeviceStorage multi_device_storage(
-        ReplicateTensor{}, std::move(ordered_device_ids), std::move(buffers), std::move(specs), std::move(mesh_buffer));
-
+    MultiDeviceStorage multi_device_storage(std::move(mesh_buffer), tensor_spec);
     return Tensor(std::move(multi_device_storage), tensor_spec);
 }
 
diff --git a/ttnn/cpp/ttnn/tensor/tensor_impl.cpp b/ttnn/cpp/ttnn/tensor/tensor_impl.cpp
index 1a45fc43960..da7d5e20e28 100644
--- a/ttnn/cpp/ttnn/tensor/tensor_impl.cpp
+++ b/ttnn/cpp/ttnn/tensor/tensor_impl.cpp
@@ -6,6 +6,11 @@
 #include
 
 #include "tt-metalium/mesh_buffer.hpp"
+#include "tt-metalium/mesh_device.hpp"
+#include "tt-metalium/mesh_command_queue.hpp"
+#include "tt-metalium/overloaded.hpp"
+#include "ttnn/distributed/distributed_tensor_config.hpp"
+#include "ttnn/tensor/storage.hpp"
 #include "ttnn/tensor/tensor_impl_wrapper.hpp"
 #include "ttnn/tensor/layout/tensor_layout.hpp"
 #include "ttnn/tensor/types.hpp"
@@ -75,6 +80,9 @@ std::shared_ptr allocate_mesh_buffer_on_device(
         .buffer_layout = memory_config.memory_layout,
         .shard_parameters = tensor_spec.compute_shard_spec_buffer(),
     };
+
+    // Use replicated buffer, which supports both working with individual shards and replicating data across all shards.
+    // This is required for the time being, as TTNN has a rich multi-device sharding implementation.
     const distributed::ReplicatedBufferConfig replicated_buffer_config{
         .size = tensor_spec.compute_packed_buffer_size_bytes(),
     };
@@ -567,6 +575,66 @@ Tensor to_host(const Tensor& tensor, bool blocking, uint8_t cq_id) {
     return to_host(tensor, blocking, cq_id);
 }
 
+template
+Tensor to_host_mesh_tensor(const Tensor& tensor, bool blocking) {
+    TT_FATAL(ttnn::distributed::is_mesh_buffer_tensor(tensor), "Tensor is not a mesh buffer tensor!");
+    TT_FATAL(tt::tt_metal::detail::InMainThread(), "to_host_mesh_tensor must be called from the main thread");
+    const auto& storage = std::get(tensor.get_storage());
+    const auto& mesh_buffer = storage.mesh_buffer;
+    ttnn::MeshDevice* device = mesh_buffer->device();
+    distributed::MeshCommandQueue& mesh_cq = device->mesh_command_queue();
+    const auto [num_rows, num_cols] = device->shape();
+    const auto num_buffers = storage.buffers.size();
+
+    std::vector shard_data_transfers;
+    std::vector specs;
+    std::vector buffers;
+    specs.reserve(num_buffers);
+    buffers.reserve(num_buffers);
+    shard_data_transfers.reserve(num_buffers);
+    distributed::Coordinate shard_coord = {0, 0};
+    for (int id : storage.ordered_device_ids) {
+        std::vector host_buffer;
+        const auto& shard_tensor_spec = storage.specs.at(id);
+        const auto tensor_size_bytes = shard_tensor_spec.compute_packed_buffer_size_bytes();
+        host_buffer.resize(tensor_size_bytes / sizeof(T));
+        specs.push_back(shard_tensor_spec);
+        buffers.push_back(owned_buffer::create(std::move(host_buffer)));
+
+        shard_data_transfers.push_back(distributed::MeshCommandQueue::ShardDataTransfer{
+            .shard_coord = shard_coord,
+            .host_data = std::visit([](auto& b) { return b.data(); }, buffers.back()),
+            .region = BufferRegion(0, tensor_size_bytes)});
+
+        if (++shard_coord.col == num_cols) {
+            shard_coord.col = 0;
+            ++shard_coord.row;
+        }
+    }
+
+    mesh_cq.enqueue_read_shards(shard_data_transfers, mesh_buffer, /*blocking=*/true);
+
+    MultiDeviceHostStorage host_storage(storage.strategy, std::move(buffers), std::move(specs));
+    return Tensor(std::move(host_storage), tensor.get_tensor_spec());
+}
+
+template Tensor to_host_mesh_tensor(const Tensor& tensor, bool blocking);
+template Tensor to_host_mesh_tensor(const Tensor& tensor, bool blocking);
+template Tensor to_host_mesh_tensor(const Tensor& tensor, bool blocking);
+template Tensor to_host_mesh_tensor(const Tensor& tensor, bool blocking);
+template Tensor to_host_mesh_tensor(const Tensor& tensor, bool blocking);
+template Tensor to_host_mesh_tensor(const Tensor& tensor, bool blocking);
+
+template <>
+Tensor to_host_mesh_tensor(const Tensor& tensor, bool blocking) {
+    return to_host_mesh_tensor(tensor, blocking);
+}
+
+template <>
+Tensor to_host_mesh_tensor(const Tensor& tensor, bool blocking) {
+    return to_host_mesh_tensor(tensor, blocking);
+}
+
 // ======================================================================================
 //                                  .to_device() details
 // ======================================================================================
@@ -613,9 +681,8 @@ template
 std::shared_ptr to_device_buffer(
     const Storage& storage, IDevice* device, const TensorSpec& tensor_spec, uint8_t cq_id) {
     return std::visit(
-        [&device, &tensor_spec, cq_id](auto&& storage) -> std::shared_ptr {
-            using StorageType = std::decay_t;
-            if constexpr (std::is_same_v or std::is_same_v) {
+        tt::stl::overloaded{
+            [&device, &tensor_spec, cq_id](const StorageType& storage) {
                 auto data_to_write = host_buffer::get_as(storage.buffer);
                 auto expected_packed_buffer_size_bytes = tensor_spec.compute_packed_buffer_size_bytes();
                 auto input_size_bytes = data_to_write.size() * sizeof(T);
@@ -625,16 +692,11 @@ std::shared_ptr to_device_buffer(
                     input_size_bytes,
                     expected_packed_buffer_size_bytes);
                 return initialize_data_on_device(data_to_write, device, tensor_spec, cq_id);
-            } else if constexpr (std::is_same_v) {
-                TT_THROW("Device storage doesn't support to_device_buffer");
-            } else if constexpr (std::is_same_v) {
-                TT_THROW("MultiHostStorage storage doesn't support to_device_buffer");
-            } else if constexpr (std::is_same_v) {
-                TT_THROW("MultiDeviceStorage doesn't support to_device_buffer");
-            } else {
-                raise_unsupported_storage();
-            }
-        },
+            },
+            [](const auto& s) {
+                TT_THROW("Unexpected storage type {}", tt::stl::get_type_name(s));
+                return std::shared_ptr();
+            }},
         storage);
 }
 
@@ -645,9 +707,6 @@ std::shared_ptr to_device_buffer(
 template
 Tensor to_device(const Tensor& tensor, IDevice* target_device, const MemoryConfig& memory_config, uint8_t cq_id) {
     TT_FATAL(tensor.storage_type() != StorageType::DEVICE, "Tensor is already on device!");
-    if (tensor.storage_type() == StorageType::OWNED) {
-        TT_FATAL(tensor.is_allocated(), "Need host buffer on device to exist to copy data to device!");
-    }
     TT_FATAL(target_device != nullptr, "Need target device in order to move tensor to device!");
     TT_FATAL(tensor.is_allocated(), "Need data to exist in order to move it to device");
 
@@ -682,6 +741,141 @@ Tensor to_device(
     return to_device(tensor, target_device, memory_config, cq_id);
 }
 
+template
+MultiDeviceStorage replicate_to_mesh_buffer(
+    const StorageType& storage,
+    distributed::MeshDevice* mesh_device,
+    const std::shared_ptr& mesh_buffer,
+    const TensorSpec& tensor_spec) {
+    auto data_to_write = host_buffer::get_as(storage.buffer);
+    const auto expected_packed_buffer_size_bytes = tensor_spec.compute_packed_buffer_size_bytes();
+    const auto input_size_bytes = data_to_write.size() * sizeof(T);
+    TT_FATAL(
+        input_size_bytes == expected_packed_buffer_size_bytes,
+        "Host data with total size {}B does not match expected size {}B of device buffer!",
+        input_size_bytes,
+        expected_packed_buffer_size_bytes);
+
+    mesh_device->mesh_command_queue().enqueue_write_mesh_buffer(mesh_buffer, data_to_write.data(), /*blocking=*/false);
+    return MultiDeviceStorage(mesh_buffer, tensor_spec);
+}
+
+template
+MultiDeviceStorage shard_to_mesh_buffer(
+    const MultiDeviceHostStorage& storage,
+    distributed::MeshDevice* mesh_device,
+    const std::shared_ptr& mesh_buffer,
+    const TensorSpec& tensor_spec) {
+    std::vector ordered_device_ids;
+    std::unordered_map> buffers;
+    std::unordered_map specs;
+    ordered_device_ids.reserve(storage.buffers.size());
+    buffers.reserve(storage.buffers.size());
+    specs.reserve(storage.buffers.size());
+
+    const auto [num_rows, num_cols] = mesh_device->shape();
+    TT_FATAL(
+        storage.buffers.size() <= mesh_device->num_devices(),
+        "Number of host buffers {} exceeds the number of shards {}",
+        storage.buffers.size(),
+        mesh_device->num_devices());
+
+    std::vector shard_data_transfers;
+    shard_data_transfers.reserve(storage.buffers.size());
+    distributed::Coordinate shard_coord = {0, 0};
+    for (int i = 0; i < storage.buffers.size(); i++) {
+        TensorSpec shard_tensor_spec(
+            storage.specs[i].logical_shape(),
+            storage.specs[i].tensor_layout().with_memory_config(tensor_spec.memory_config()));
+        const auto& shard_host_buffer = storage.buffers[i];
+
+        const auto& shard_buffer = mesh_buffer->get_device_buffer(shard_coord);
+        ordered_device_ids.push_back(shard_buffer->device()->id());
+        buffers.insert({shard_buffer->device()->id(), shard_buffer});
+        specs.insert({shard_buffer->device()->id(), shard_tensor_spec});
+
+        auto data_to_write = host_buffer::get_as(shard_host_buffer);
+        const auto expected_packed_buffer_size_bytes = shard_tensor_spec.compute_packed_buffer_size_bytes();
+        const auto input_size_bytes = data_to_write.size() * sizeof(T);
+        TT_FATAL(
+            input_size_bytes == expected_packed_buffer_size_bytes,
+            "Host data with total size {}B does not match expected size {}B of device buffer!",
+            input_size_bytes,
+            expected_packed_buffer_size_bytes);
+        TT_FATAL(
+            expected_packed_buffer_size_bytes <= tensor_spec.compute_packed_buffer_size_bytes(),
+            "Shard tensor size exceeds the global tensor size!");
+        shard_data_transfers.push_back(distributed::MeshCommandQueue::ShardDataTransfer{
+            .shard_coord = shard_coord,
+            .host_data = data_to_write.data(),
+            .region = BufferRegion(0, input_size_bytes)});
+        if (++shard_coord.col == num_cols) {
+            shard_coord.col = 0;
+            ++shard_coord.row;
+        }
+    }
+
+    mesh_device->mesh_command_queue().enqueue_write_shards(mesh_buffer, shard_data_transfers, /*blocking=*/false);
+
+    return MultiDeviceStorage(
+        storage.strategy, std::move(ordered_device_ids), std::move(buffers), std::move(specs), mesh_buffer);
+}
+
+template
+Tensor to_device_mesh_tensor(
+    const Tensor& tensor, distributed::MeshDevice* mesh_device, const MemoryConfig& memory_config) {
+    TT_FATAL(tt::tt_metal::detail::InMainThread(), "to_device_mesh_tensor must be called from the main thread");
+    TT_FATAL(tensor.storage_type() != StorageType::MULTI_DEVICE, "Tensor is already on device!");
+    TT_FATAL(mesh_device != nullptr, "Need target device in order to move tensor to device!");
+    TT_FATAL(tensor.is_allocated(), "Need data to exist in order to move it to device");
+
+    TensorSpec tensor_spec(
+        tensor.get_logical_shape(), tensor.get_tensor_spec().tensor_layout().with_memory_config(memory_config));
+
+    auto mesh_buffer = allocate_mesh_buffer_on_device(mesh_device, tensor_spec);
+    MultiDeviceStorage mesh_storage = std::visit(
+        tt::stl::overloaded{
+            [&mesh_device, &mesh_buffer, &tensor_spec](const StorageType& storage) {
+                // Replicate data across devices in a mesh.
+                return replicate_to_mesh_buffer(storage, mesh_device, mesh_buffer, tensor_spec);
+            },
+            [&mesh_device, &mesh_buffer, &tensor_spec](const MultiDeviceHostStorage& storage) {
+                // Shard multi-device host shards across devices in a mesh.
+                return shard_to_mesh_buffer(storage, mesh_device, mesh_buffer, tensor_spec);
+            },
+            [](const auto& s) -> MultiDeviceStorage {
+                TT_THROW("Unexpected storage type {}", tt::stl::get_type_name(s));
+            }},
+        tensor.get_storage());
+
+    return Tensor(std::move(mesh_storage), tensor_spec);
+}
+
+template Tensor to_device_mesh_tensor(
+    const Tensor& tensor, distributed::MeshDevice* target_device, const MemoryConfig& memory_config);
+template Tensor to_device_mesh_tensor(
+    const Tensor& tensor, distributed::MeshDevice* target_device, const MemoryConfig& memory_config);
+template Tensor to_device_mesh_tensor(
+    const Tensor& tensor, distributed::MeshDevice* target_device, const MemoryConfig& memory_config);
+template Tensor to_device_mesh_tensor(
+    const Tensor& tensor, distributed::MeshDevice* target_device, const MemoryConfig& memory_config);
+template Tensor to_device_mesh_tensor(
+    const Tensor& tensor, distributed::MeshDevice* target_device, const MemoryConfig& memory_config);
+template Tensor to_device_mesh_tensor(
+    const Tensor& tensor, distributed::MeshDevice* target_device, const MemoryConfig& memory_config);
+
+template <>
+Tensor to_device_mesh_tensor(
+    const Tensor& tensor, distributed::MeshDevice* target_device, const MemoryConfig& memory_config) {
+    return to_device_mesh_tensor(tensor, target_device, memory_config);
+}
+
+template <>
+Tensor to_device_mesh_tensor(
+    const Tensor& tensor, distributed::MeshDevice* target_device, const MemoryConfig& memory_config) {
+    return to_device_mesh_tensor(tensor, target_device, memory_config);
+}
+
 // ======================================================================================
 //     Helpers for converting between logical <-> physical data with full tensor spec
 // ======================================================================================
@@ -909,18 +1103,20 @@ Tensor to_layout(const Tensor& tensor, Layout target_layout) {
         }
     };
 
+    using RetType = std::variant;
     auto output_storage = std::visit(
-        [&convert, target_layout](auto&& storage) -> std::variant {
-            using StorageType = std::decay_t;
-            if constexpr (std::is_same_v) {
+        tt::stl::overloaded{
+            [&convert, target_layout](const OwnedStorage& storage) -> RetType {
                 const auto input_data = owned_buffer::get_as(storage.buffer);
                 auto output_buffer = owned_buffer::create(std::move(convert(input_data)));
                 return OwnedStorage{output_buffer};
-            } else if constexpr (std::is_same_v) {
+            },
+            [&convert, target_layout](const BorrowedStorage& storage) -> RetType {
                 const auto input_data = borrowed_buffer::get_as(storage.buffer);
                 auto output_buffer = owned_buffer::create(std::move(convert(input_data)));
                 return OwnedStorage{output_buffer};
-            } else if constexpr (std::is_same_v) {
+            },
+            [&convert, target_layout](const MultiDeviceHostStorage& storage) -> RetType {
                 std::vector output_buffers;
                 std::vector output_specs;
                 for (int i = 0; i < storage.num_buffers(); i++) {
@@ -938,14 +1134,8 @@ Tensor to_layout(const Tensor& tensor, Layout target_layout) {
                         prev_spec.padded_shape())));
                 }
                 return MultiDeviceHostStorage{storage.strategy, output_buffers, output_specs};
-            } else if constexpr (std::is_same_v) {
-                TT_THROW("Device storage isn't supported");
-            } else if constexpr (std::is_same_v) {
-                TT_THROW("On-device layout conversion for tensor with MultiDeviceStorage is not supported.");
-            } else {
-                raise_unsupported_storage();
-            }
-        },
+            },
+            [](const auto& s) -> RetType { TT_THROW("Unsupported storage type {}", tt::stl::get_type_name(s)); }},
         tensor.get_storage());
 
     return std::visit(
@@ -1078,24 +1268,14 @@ Tensor pad(
     };
 
     auto output_buffer = std::visit(
-        [&pad](auto&& storage) -> owned_buffer::Buffer {
-            using StorageType = std::decay_t;
-            if constexpr (std::is_same_v) {
-                const auto input_data = owned_buffer::get_as(storage.buffer);
+        tt::stl::overloaded{
+            [&pad](const StorageType& storage) {
+                const auto input_data = host_buffer::get_as(storage.buffer);
                 return pad(input_data);
-            } else if constexpr (std::is_same_v) {
-                const auto input_data = borrowed_buffer::get_as(storage.buffer);
-                return pad(input_data);
-            } else if constexpr (std::is_same_v) {
-                TT_THROW("Device storage isn't supported");
-            } else if constexpr (std::is_same_v) {
-                TT_THROW("Device storage isn't supported");
-            } else if constexpr (std::is_same_v) {
-                TT_THROW("Device storage isn't supported");
-            } else {
-                raise_unsupported_storage();
-            }
-        },
+            },
+            [](const auto& s) -> owned_buffer::Buffer {
+                TT_THROW("Unsupported storage type {}", tt::stl::get_type_name(s));
+            }},
         tensor.get_storage());
     return Tensor(
         OwnedStorage{output_buffer},
@@ -1196,24 +1376,14 @@ Tensor unpad(const Tensor& tensor, const ttnn::Shape& output_tensor_start, const
     };
 
     auto output_buffer = std::visit(
-        [&unpad](auto&& storage) -> owned_buffer::Buffer {
-            using StorageType = std::decay_t;
-            if constexpr (std::is_same_v) {
-                const auto input_data = owned_buffer::get_as(storage.buffer);
-                return unpad(input_data);
-            } else if constexpr (std::is_same_v) {
-                const auto input_data = borrowed_buffer::get_as(storage.buffer);
+        tt::stl::overloaded{
+            [&unpad](const StorageType& storage) {
+                const auto input_data = host_buffer::get_as(storage.buffer);
                 return unpad(input_data);
-            } else if constexpr (std::is_same_v) {
-                TT_THROW("Device storage isn't supported");
-            } else if constexpr (std::is_same_v) {
-                TT_THROW("Device storage isn't supported");
-            } else if constexpr (std::is_same_v) {
-                TT_THROW("Device storage isn't supported");
-            } else {
-                raise_unsupported_storage();
-            }
-        },
+            },
+            [](const auto& s) -> owned_buffer::Buffer {
+                TT_THROW("Unsupported storage type {}", tt::stl::get_type_name(s));
+            }},
         tensor.get_storage());
     return Tensor(
         OwnedStorage{output_buffer},
diff --git a/ttnn/cpp/ttnn/tensor/tensor_impl.hpp b/ttnn/cpp/ttnn/tensor/tensor_impl.hpp
index 2602e0e4b2c..2a4654b8aac 100644
--- a/ttnn/cpp/ttnn/tensor/tensor_impl.hpp
+++ b/ttnn/cpp/ttnn/tensor/tensor_impl.hpp
@@ -8,6 +8,7 @@
 #include
 #include
 
+#include "tt-metalium/mesh_device.hpp"
 #include "ttnn/tensor/host_buffer/functions.hpp"
 #include "ttnn/tensor/tensor.hpp"
 #include "ttnn/tensor/tensor_utils.hpp"
@@ -173,23 +174,27 @@ std::shared_ptr allocate_mesh_buffer_on_device(
     distributed::MeshDevice* mesh_device, const TensorSpec& tensor_spec);
 
 template
-inline void read_data_from_device_buffer(
+void read_data_from_device_buffer(
     CommandQueue& cq, std::shared_ptr device_buffer, void* host_buffer_data, bool blocking) {
     EnqueueReadBuffer(cq, device_buffer, host_buffer_data, blocking);
 }
 
 template
-inline void read_data_from_device_buffer(std::shared_ptr device_buffer, std::vector& host_buffer) {
+void read_data_from_device_buffer(std::shared_ptr device_buffer, std::vector& host_buffer) {
     ::tt::tt_metal::detail::ReadFromBuffer(device_buffer, host_buffer);
 }
 
 // ======================================================================================
-//                                         .to()
+//                           .to_host() and .to_device()
 // ======================================================================================
 
 template
 Tensor to_host(const Tensor& tensor, bool blocking = true, uint8_t cq_id = ttnn::DefaultQueueId);
 
+// TODO: #17215 - This will eventually subsume `to_host`, when "mesh buffer" backed tensors become the default.
+template
+Tensor to_host_mesh_tensor(const Tensor& tensor, bool blocking = true);
+
 template
 Tensor to_device(
     const Tensor& tensor,
@@ -197,6 +202,15 @@ Tensor to_device(
     const MemoryConfig& memory_config,
     uint8_t cq_id = ttnn::DefaultQueueId);
 
+// TODO: #17215 - This will eventually subsume `to_device`, when "mesh buffer" backed tensors become the default.
+template
+Tensor to_device_mesh_tensor(
+    const Tensor& tensor, distributed::MeshDevice* mesh_device, const MemoryConfig& memory_config);
+
+// ======================================================================================
+//                                     .to_layout()
+// ======================================================================================
+
 template
 Tensor to_layout(const Tensor& tensor, Layout target_layout);
 
diff --git a/ttnn/cpp/ttnn/tensor/tensor_impl_wrapper.hpp b/ttnn/cpp/ttnn/tensor/tensor_impl_wrapper.hpp
index 9cf4c810591..7bf2d8690d3 100644
--- a/ttnn/cpp/ttnn/tensor/tensor_impl_wrapper.hpp
+++ b/ttnn/cpp/ttnn/tensor/tensor_impl_wrapper.hpp
@@ -38,8 +38,10 @@ inline size_t packed_buffer_size_bytes_wrapper(DataType dtype, size_t volume_unp
 }
 
 WRAP_FUNCTION(to_host)
+WRAP_FUNCTION(to_host_mesh_tensor)
 WRAP_FUNCTION(extract_shard)
 WRAP_FUNCTION(to_device)
+WRAP_FUNCTION(to_device_mesh_tensor)
 WRAP_FUNCTION(to_layout)
 WRAP_FUNCTION(pad)
 WRAP_FUNCTION(unpad)
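
Usage sketch (not part of the patch): the snippet below shows one way the batched ShardDataTransfer read/write path added in this change might be driven directly from Metal code. It assumes an already-opened MeshDevice with a replicated MeshBuffer allocated on it; the function name `write_then_read_all_shards` and the `host_shards` argument are illustrative, while ShardDataTransfer, enqueue_write_shards, enqueue_read_shards, Coordinate, and the include paths come from the diff above. Each host vector is expected to hold a full shard; leaving `region` unset falls back to the whole shard, matching the `value_or(BufferRegion(0, shard->size()))` behavior in MeshCommandQueue.

#include <memory>
#include <optional>
#include <vector>

#include "tt-metalium/mesh_buffer.hpp"
#include "tt-metalium/mesh_command_queue.hpp"
#include "tt-metalium/mesh_device.hpp"

namespace sketch {

using tt::tt_metal::distributed::Coordinate;
using tt::tt_metal::distributed::MeshBuffer;
using tt::tt_metal::distributed::MeshCommandQueue;
using tt::tt_metal::distributed::MeshDevice;

// Writes one host buffer per device of the mesh, then reads all shards back
// using the same transfer descriptors.
inline void write_then_read_all_shards(
    MeshDevice& mesh_device,
    const std::shared_ptr<MeshBuffer>& mesh_buffer,
    std::vector<std::vector<uint32_t>>& host_shards) {
    MeshCommandQueue& cq = mesh_device.mesh_command_queue();
    const auto [num_rows, num_cols] = mesh_device.shape();

    // One ShardDataTransfer per shard; omitting `region` targets the full shard.
    std::vector<MeshCommandQueue::ShardDataTransfer> transfers;
    std::size_t i = 0;
    for (int row = 0; row < num_rows; ++row) {
        for (int col = 0; col < num_cols; ++col, ++i) {
            transfers.push_back(MeshCommandQueue::ShardDataTransfer{
                .shard_coord = Coordinate{row, col},
                .host_data = host_shards[i].data(),
                .region = std::nullopt,
            });
        }
    }

    // Non-blocking write followed by a blocking read on the same command queue;
    // ordering on the queue guarantees the read observes the written data.
    cq.enqueue_write_shards(mesh_buffer, transfers, /*blocking=*/false);
    cq.enqueue_read_shards(transfers, mesh_buffer, /*blocking=*/true);
}

}  // namespace sketch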