
#0: Simplify MeshDevice construction by eradicating mesh_type #17416

Status: Open · wants to merge 1 commit into base: main
6 changes: 2 additions & 4 deletions conftest.py
@@ -254,10 +254,9 @@ def pcie_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, devic

updated_device_params = get_updated_device_params(device_params)
mesh_device = ttnn.open_mesh_device(
mesh_shape=ttnn.MeshShape(2, 2),
mesh_shape=ttnn.MeshShape(1, num_pcie_devices_requested),
**updated_device_params,
offset=ttnn.MeshOffset(0, 1),
mesh_type=ttnn.MeshType.Ring,
physical_device_ids=device_ids[:num_pcie_devices_requested],
)

logger.debug(f"multidevice with {mesh_device.get_num_devices()} devices is created")
@@ -305,7 +304,6 @@ def t3k_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, device
mesh_device = ttnn.open_mesh_device(
mesh_shape=ttnn.MeshShape(2, 4),
**updated_device_params,
mesh_type=ttnn.MeshType.Ring,
)

logger.debug(f"multidevice with {mesh_device.get_num_devices()} devices is created")
2 changes: 1 addition & 1 deletion models/MODEL_HYBRID_TP_DP.md
@@ -16,7 +16,7 @@ The main changes involve:

```python
# Work with submesh device as you would with a regular ttnn.MeshDevice
submesh_devices: List[ttnn.MeshDevice] = mesh_device.create_submeshes((2, 4), ttnn.MeshType.Ring)
submesh_devices: List[ttnn.MeshDevice] = mesh_device.create_submeshes((2, 4))
```
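For intuition, the tiling behavior behind `create_submeshes` can be sketched in plain Python. This is an illustrative mock (the function name `submesh_offsets` is hypothetical, not a ttnn API): it mirrors the nested loop in `MeshDevice::create_submeshes` from this PR, which steps across the parent mesh in submesh-shaped strides.

```python
def submesh_offsets(parent_shape, submesh_shape):
    """Compute the (row, col) offsets at which submeshes of `submesh_shape`
    are carved out of a parent mesh of `parent_shape`.
    Pure-Python sketch mirroring MeshDevice::create_submeshes."""
    parent_rows, parent_cols = parent_shape
    sub_rows, sub_cols = submesh_shape
    offsets = []
    for row in range(0, parent_rows, sub_rows):
        for col in range(0, parent_cols, sub_cols):
            offsets.append((row, col))
    return offsets

# A 2x4 T3000 mesh split into 2x2 submeshes yields two submeshes:
print(submesh_offsets((2, 4), (2, 2)))  # [(0, 0), (0, 2)]
```

Note that after this PR the topology argument is gone entirely: the offsets alone describe each submesh, and linearization is always row-major.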

### 2. Compile & Run the Model on Each Submesh
@@ -296,7 +296,7 @@ def run_test_LlamaModel_end_to_end_hybrid_data_tensor_parallel(
profiler.clear()

submesh_to_metadata = defaultdict(dict)
submeshes = mesh_device.create_submeshes((2, 4), ttnn.MeshType.Ring)
submeshes = mesh_device.create_submeshes((2, 4))
for submesh in submeshes:
# Set up model -----------------------------------------------------------------------
logger.info("Moving weights to devices; might take some time...")
2 changes: 1 addition & 1 deletion tech_reports/LLMs/llms.md
@@ -1195,7 +1195,7 @@ Below is a summary and example code of the most important concepts for mapping a
import ttnn

# 2x4 mesh_device, Topology Ring: devices are connected in a ring
mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4), mesh_type=ttnn.MeshType.Ring)
mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4))

# Construct initial torch tensor
torch_tensor = torch.rand((1,1,32,256), dtype=torch.bfloat16)
@@ -296,7 +296,7 @@ Let's see an example of how to use the Ring All-Gather operation:
```py
import ttnn

mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4), mesh_type=ttnn.MeshType.Ring)
mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4))

# Construct test tensor of data; 8 chunks of 32x32
torch_tensor = torch.rand((1,1,32,256), dtype=torch.bfloat16)
@@ -328,7 +328,7 @@ The result tensor for each device in the column is the concatenation in `dim=3`
```py
import ttnn

mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4), mesh_type=ttnn.MeshType.Ring)
mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4))

# Construct test tensor of data; 8 chunks of 32x32
torch_tensor = torch.rand((1,1,32,256), dtype=torch.bfloat16)
@@ -534,7 +534,7 @@ torch_hidden_states = (torch.rand(batch_size, 1, sequence_length, config.hidden_
torch_output = model.forward(torch_hidden_states)

# Device Initialization
mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2,4), mesh_type=ttnn.MeshType.Ring)
mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2,4))

# Initialize input activations on all devices in the mesh
# Alternatively, we can shard the input activations on the height dimension and
@@ -602,7 +602,7 @@ See `models/demos/t3000/llama2_70b/tests/test_llama_perf_decode.py::test_Llama_p
1. Submesh Creation

```py
submesh_devices: List[ttnn.MeshDevice] = mesh_device.create_submeshes((2, 4), ttnn.MeshType.Ring)
submesh_devices: List[ttnn.MeshDevice] = mesh_device.create_submeshes((2, 4))
```

2. Compile & Run the Model on Each Submesh
@@ -223,10 +223,6 @@ struct MeshConfig {

// Offset into Logical Device Coordinate Space
MeshOffset offset;

// TODO: consider whether this should be automatically inferred.
// Interpret as e.g. {Ring, Line}
MeshType type;
};

// Class exposing host and device dispatch state
@@ -986,8 +982,8 @@ Below, we include snippets from both the TT-Mesh and TT-Metal examples to illust
*Specify MeshConfig when creating two Virtual Meshes on a Physical Mesh.*

```cpp
MeshConfig mesh_config_0 = MeshConfig{.shape = virtual_mesh_shape, .offset = {0, 0}, .type=mesh_type};
MeshConfig mesh_config_1 = MeshConfig{.shape = virtual_mesh_shape, .offset = {0, 4}, .type=mesh_type};
MeshConfig mesh_config_0 = MeshConfig{.shape = virtual_mesh_shape, .offset = {0, 0}};
MeshConfig mesh_config_1 = MeshConfig{.shape = virtual_mesh_shape, .offset = {0, 4}};

DeviceHandle virtual_mesh_0 = CreateMeshDevice(mesh_config_0, 2 /* num_command_queues */, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE);
DeviceHandle virtual_mesh_1 = CreateMeshDevice(mesh_config_1, 2 /* num_command_queues */, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE);
2 changes: 1 addition & 1 deletion tests/sweep_framework/sweeps/ccl/line_all_gather.py
@@ -66,7 +66,7 @@ def mesh_device_fixture():
assert ttnn.get_num_devices() >= 8, "Not T3000!"
device_ids = ttnn.get_t3k_physical_device_ids_ring()
num_devices_requested = len(device_ids)
mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(1, num_devices_requested), mesh_type=ttnn.MeshType.Line)
mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(1, num_devices_requested))
print("ALL GATHER: Opened device mesh")

yield (mesh_device, "T3000 Mesh")
11 changes: 2 additions & 9 deletions tests/ttnn/distributed/test_distributed_reshape.cpp
@@ -27,7 +27,7 @@ void check_test_environment() {

std::vector<chip_id_t> get_physical_device_ids(const MeshDevice& mesh) {
std::vector<chip_id_t> device_ids;
for (auto* device : mesh.get_devices(ttnn::distributed::MeshType::RowMajor)) {
for (auto* device : mesh.get_devices()) {
device_ids.push_back(device->id());
}
return device_ids;
@@ -138,12 +138,7 @@ TEST_F(T3000ReshapeTest, From1x8To2x4) {

TEST_F(T3000ReshapeTest, OnRingTopology) {
auto mesh = ttnn::distributed::open_mesh_device(
{1, 8},
DEFAULT_L1_SMALL_SIZE,
DEFAULT_TRACE_REGION_SIZE,
1,
tt::tt_metal::DispatchCoreType::WORKER,
ttnn::distributed::MeshType::Ring);
{1, 8}, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, tt::tt_metal::DispatchCoreType::WORKER);

EXPECT_EQ(mesh->num_rows(), 1);
EXPECT_EQ(mesh->num_cols(), 8);
@@ -228,7 +223,6 @@ TEST_F(T3000ReshapeTest, From1x4To2x2Valid) {
// Fetch the device ids for a physically connected 2x2 mesh.
auto physical_device_ids = system_mesh.get_mapped_physical_device_ids(MeshDeviceConfig{
.mesh_shape = MeshShape{2, 2},
.mesh_type = ttnn::distributed::MeshType::Line,
});

// Supply the physical device ids, which we know form a physically connected 2x2 mesh, to the mesh constructor.
@@ -239,7 +233,6 @@
DEFAULT_TRACE_REGION_SIZE,
1,
tt::tt_metal::DispatchCoreType::WORKER,
ttnn::distributed::MeshType::Line,
MeshOffset{0, 0},
physical_device_ids);

@@ -36,7 +36,6 @@ def test_tensor_parallel_falcon_mlp():

mesh_device = ttnn.open_mesh_device(
ttnn.MeshShape(2, 4),
mesh_type=ttnn.MeshType.Ring,
)

# Set PyTorch seed for reproducibility
2 changes: 1 addition & 1 deletion tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp
@@ -67,7 +67,7 @@ class T3kMultiDeviceFixture : public ::testing::Test {
if (num_devices < 8 or arch != tt::ARCH::WORMHOLE_B0) {
GTEST_SKIP() << "Skipping T3K Multi-Device test suite on non T3K machine.";
}
mesh_device_ = MeshDevice::create(MeshDeviceConfig{.mesh_shape = MeshShape{2, 4}, .mesh_type = MeshType::Ring});
mesh_device_ = MeshDevice::create(MeshDeviceConfig{.mesh_shape = MeshShape{2, 4}});
}

void TearDown() override {
12 changes: 9 additions & 3 deletions tests/ttnn/unit_tests/test_multi_device.py
@@ -672,22 +672,28 @@ def test_visualize_mesh_device(t3k_mesh_device):
ttnn.visualize_mesh_device(t3k_mesh_device)


def test_all_gather_multiple_submeshes(t3k_mesh_device):
def test_all_gather_multiple_submeshes():
"""Test all_gather with multiple submeshes"""

mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4))
ttnn.visualize_mesh_device(mesh_device)

def model(submesh):
full_tensor = torch.ones((1, 1, 32, 32 * submesh.get_num_devices()), dtype=torch.bfloat16)
for i in range(submesh.get_num_devices()):
full_tensor[..., i * 32 : (i + 1) * 32] = i

for device in submesh.get_devices():
print(device.id())

ttnn_tensor = ttnn.from_torch(full_tensor, mesh_mapper=ShardTensorToMesh(submesh, dim=3))
ttnn_tensor = ttnn.to_device(ttnn_tensor, submesh)
ttnn_tensor = ttnn.all_gather(ttnn_tensor, dim=3, num_links=1)
ttnn_tensor = ttnn.all_gather(ttnn_tensor, dim=3, num_links=1, topology=ttnn.Topology.Ring)

for device_tensor in ttnn.get_device_tensors(ttnn_tensor):
device_tensor_torch = ttnn.to_torch(device_tensor)
assert torch.all(device_tensor_torch == full_tensor)

submesh_devices = t3k_mesh_device.create_submeshes(ttnn.MeshShape(2, 2), ttnn.MeshType.Ring)
submesh_devices = mesh_device.create_submeshes(ttnn.MeshShape(2, 2))
for submesh in submesh_devices:
model(submesh)
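The expectation this test checks can be sketched without hardware: after an all-gather, every device holds the concatenation of all devices' shards. The helper below is an illustrative mock (lists stand in for tensors; `all_gather` here is not the ttnn op):

```python
def all_gather(shards):
    """Sketch of all-gather semantics: every rank ends up with the
    concatenation of all ranks' shards, in rank order."""
    gathered = [x for shard in shards for x in shard]
    return [list(gathered) for _ in shards]

# 4 devices each holding one chunk; afterwards every device has all chunks.
shards = [[0], [1], [2], [3]]
print(all_gather(shards))  # [[0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]]
```

This is why the updated test passes `topology=ttnn.Topology.Ring` explicitly on the op: with `mesh_type` gone from the device, the communication pattern is now chosen per operation, not baked into the mesh.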
3 changes: 1 addition & 2 deletions tt-train/sources/ttml/core/mesh_device.cpp
@@ -12,8 +12,7 @@ MeshDevice::MeshDevice(tt::tt_metal::distributed::MeshShape shape) :
DEFAULT_L1_SMALL_SIZE,
DEFAULT_TRACE_REGION_SIZE,
/* num_command_queues*/ 1,
DispatchCoreConfig{},
ttnn::distributed::MeshType::RowMajor)) {
DispatchCoreConfig{})) {
assert(m_mesh_device);
}

2 changes: 0 additions & 2 deletions tt_metal/api/tt-metalium/mesh_config.hpp
@@ -36,13 +36,11 @@ struct MeshShape {
*
* - Line: Devices are arranged linearly in a single dimension.
*/
enum class MeshType { RowMajor, Ring, Line };

struct MeshDeviceConfig {
MeshShape mesh_shape{0, 0};
MeshOffset offset{0, 0};
std::vector<chip_id_t> physical_device_ids{};
MeshType mesh_type{MeshType::RowMajor};
};

} // namespace tt::tt_metal::distributed
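For readers following the API change, a Python analogue of the slimmed-down config may help (illustrative only; the real struct is the C++ aggregate above): after this PR only the shape, offset, and optional physical device ids remain.

```python
from dataclasses import dataclass, field
from typing import List, Tuple

@dataclass
class MeshDeviceConfig:
    """Python sketch of the simplified C++ MeshDeviceConfig:
    mesh_type is gone; only placement-related fields remain."""
    mesh_shape: Tuple[int, int] = (0, 0)
    offset: Tuple[int, int] = (0, 0)
    physical_device_ids: List[int] = field(default_factory=list)

config = MeshDeviceConfig(mesh_shape=(2, 4))
print(config)  # MeshDeviceConfig(mesh_shape=(2, 4), offset=(0, 0), physical_device_ids=[])
```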
13 changes: 4 additions & 9 deletions tt_metal/api/tt-metalium/mesh_device.hpp
@@ -55,7 +55,6 @@ class MeshDevice : public IDevice, public std::enable_shared_from_this<MeshDevic
std::shared_ptr<ScopedDevices> scoped_devices_;
MeshDeviceID mesh_id_;
MeshShape mesh_shape_;
MeshType type_;
std::unique_ptr<MeshDeviceView> view_;
std::vector<std::shared_ptr<MeshDevice>>
submeshes_; // Parent owns submeshes and is responsible for their destruction
@@ -71,7 +70,6 @@
MeshDevice(
std::shared_ptr<ScopedDevices> mesh_handle,
const MeshShape& mesh_shape,
MeshType type,
std::weak_ptr<MeshDevice> parent_mesh = {});
~MeshDevice() override;

@@ -200,9 +198,9 @@

// A MeshDevice is a collection of devices arranged in a 2D grid.
// The type parameter allows the caller to specify how to linearize the devices in the mesh.
// If type is not provided, the default behavior is to return the devices based on the MeshType of the MeshDevice.

std::vector<IDevice*> get_devices(const std::optional<MeshType>& type = std::nullopt) const;
// Returns the devices in the mesh in row-major order.
std::vector<IDevice*> get_devices() const;
IDevice* get_device_index(size_t logical_device_id) const;
IDevice* get_device(chip_id_t physical_device_id) const;
IDevice* get_device(size_t row_idx, size_t col_idx) const;
Expand Down Expand Up @@ -238,12 +236,9 @@ class MeshDevice : public IDevice, public std::enable_shared_from_this<MeshDevic
std::vector<std::shared_ptr<MeshDevice>> get_submeshes() const;

std::shared_ptr<MeshDevice> create_submesh(
const MeshShape& submesh_shape,
const MeshOffset& offset = MeshOffset{0, 0},
MeshType type = MeshType::RowMajor);
const MeshShape& submesh_shape, const MeshOffset& offset = MeshOffset{0, 0});

std::vector<std::shared_ptr<MeshDevice>> create_submeshes(
const MeshShape& submesh_shape, MeshType type = MeshType::RowMajor);
std::vector<std::shared_ptr<MeshDevice>> create_submeshes(const MeshShape& submesh_shape);

// These methods will get removed once in favour of the ones in IDevice* and TT-Mesh bringup
// These are prefixed with "mesh_" to avoid conflicts with the IDevice* methods
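With `mesh_type` removed, the header now documents a single contract: `get_devices()` returns the mesh's devices in row-major order. A pure-Python sketch of that linearization (the helper name `row_major_order` is hypothetical, for illustration):

```python
def row_major_order(num_rows, num_cols):
    """Row-major linearization of a 2D device grid: the single ordering
    get_devices() returns after this change."""
    return [(r, c) for r in range(num_rows) for c in range(num_cols)]

print(row_major_order(2, 4))
# [(0, 0), (0, 1), (0, 2), (0, 3), (1, 0), (1, 1), (1, 2), (1, 3)]
```

Callers that previously relied on a Ring or Line ordering from `get_devices(MeshType)` now obtain those orderings from the view's dedicated `get_ring_devices`/`get_line_devices` accessors instead.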
2 changes: 1 addition & 1 deletion tt_metal/api/tt-metalium/mesh_device_view.hpp
@@ -74,7 +74,7 @@ class MeshDeviceView {
// devices are returned in row-major order with start/end coordinates inclusive
[[nodiscard]] DeviceView get_devices(const Coordinate& start, const Coordinate& end) const;
[[nodiscard]] DeviceView get_devices(const MeshShape& submesh_shape) const;
[[nodiscard]] DeviceView get_devices(MeshType type = MeshType::RowMajor) const;
[[nodiscard]] DeviceView get_devices() const;

[[nodiscard]] DeviceView get_devices_on_row(size_t row) const;
[[nodiscard]] DeviceView get_devices_on_column(size_t col) const;
28 changes: 10 additions & 18 deletions tt_metal/distributed/mesh_device.cpp
@@ -119,16 +119,11 @@ uint32_t MeshDevice::dram_size_per_channel() const {
IDevice* MeshDevice::reference_device() const { return this->get_devices().at(0); }

MeshDevice::MeshDevice(
std::shared_ptr<ScopedDevices> mesh_handle,
const MeshShape& mesh_shape,
MeshType type,
std::weak_ptr<MeshDevice> parent_mesh) :
std::shared_ptr<ScopedDevices> mesh_handle, const MeshShape& mesh_shape, std::weak_ptr<MeshDevice> parent_mesh) :
scoped_devices_(std::move(mesh_handle)),
mesh_shape_(mesh_shape),
type_(type),
mesh_id_(generate_unique_mesh_id()),
parent_mesh_(std::move(parent_mesh))
{
parent_mesh_(std::move(parent_mesh)) {
work_executor_ = std::make_unique<WorkExecutor>(0 /* worker_core */, mesh_id_);
work_executor_->initialize();
work_executor_->set_worker_mode(WorkExecutorMode::SYNCHRONOUS);
@@ -142,16 +137,15 @@ std::shared_ptr<MeshDevice> MeshDevice::create(
const DispatchCoreConfig& dispatch_core_config,
tt::stl::Span<const std::uint32_t> l1_bank_remap) {
auto mesh_device = std::make_shared<MeshDevice>(
std::make_shared<ScopedDevices>(l1_small_size, trace_region_size, num_command_queues, dispatch_core_config, config),
config.mesh_shape,
config.mesh_type);
std::make_shared<ScopedDevices>(
l1_small_size, trace_region_size, num_command_queues, dispatch_core_config, config),
config.mesh_shape);

mesh_device->initialize(num_command_queues, l1_small_size, trace_region_size, l1_bank_remap);
return mesh_device;
}

std::shared_ptr<MeshDevice> MeshDevice::create_submesh(
const MeshShape& submesh_shape, const MeshOffset& offset, MeshType type) {
std::shared_ptr<MeshDevice> MeshDevice::create_submesh(const MeshShape& submesh_shape, const MeshOffset& offset) {
if (submesh_shape.num_rows <= 0 || submesh_shape.num_cols <= 0) {
TT_THROW(
"Invalid submesh shape: ({}, {}). Both dimensions must be positive.",
@@ -175,7 +169,7 @@
mesh_shape_.num_cols);
}

auto submesh = std::make_shared<MeshDevice>(scoped_devices_, submesh_shape, type, shared_from_this());
auto submesh = std::make_shared<MeshDevice>(scoped_devices_, submesh_shape, shared_from_this());
auto start_coordinate = Coordinate{offset.row, offset.col};
auto end_coordinate = Coordinate{offset.row + submesh_shape.num_rows - 1, offset.col + submesh_shape.num_cols - 1};

@@ -196,11 +190,11 @@
return submesh;
}

std::vector<std::shared_ptr<MeshDevice>> MeshDevice::create_submeshes(const MeshShape& submesh_shape, MeshType type) {
std::vector<std::shared_ptr<MeshDevice>> MeshDevice::create_submeshes(const MeshShape& submesh_shape) {
std::vector<std::shared_ptr<MeshDevice>> submeshes;
for (int row = 0; row < this->num_rows(); row += submesh_shape.num_rows) {
for (int col = 0; col < this->num_cols(); col += submesh_shape.num_cols) {
auto submesh = this->create_submesh(submesh_shape, MeshOffset{row, col}, type);
auto submesh = this->create_submesh(submesh_shape, MeshOffset{row, col});
submeshes.push_back(submesh);
}
}
@@ -224,9 +218,7 @@ IDevice* MeshDevice::get_device(chip_id_t physical_device_id) const {
TT_THROW("Physical Device ID: {} not found in assigned devices", physical_device_id);
}

std::vector<IDevice*> MeshDevice::get_devices(const std::optional<MeshType>& requested_type) const {
return view_->get_devices(requested_type.value_or(type_));
}
std::vector<IDevice*> MeshDevice::get_devices() const { return view_->get_devices(); }

// TODO: Remove this function once we have a proper view interface
IDevice* MeshDevice::get_device(size_t row_idx, size_t col_idx) const {
11 changes: 2 additions & 9 deletions tt_metal/distributed/mesh_device_view.cpp
@@ -39,7 +39,7 @@ MeshDeviceView::MeshDeviceView(const std::vector<IDevice*>& devices, Coordinate
}

MeshDeviceView::MeshDeviceView(const MeshDevice& mesh_device) :
MeshDeviceView(mesh_device.get_devices(MeshType::RowMajor), mesh_device.shape()) {}
MeshDeviceView(mesh_device.get_devices(), mesh_device.shape()) {}

MeshDeviceView::MeshDeviceView(const std::vector<IDevice*>& devices, const MeshShape& shape) :
MeshDeviceView(devices, Coordinate{0, 0}, Coordinate{shape.num_rows - 1, shape.num_cols - 1}) {}

@@ -261,13 +261,6 @@ std::vector<IDevice*> MeshDeviceView::get_ring_devices() const {
Expand Down Expand Up @@ -261,13 +261,6 @@ std::vector<IDevice*> MeshDeviceView::get_ring_devices() const {
return get_devices_from_coordinates(*this, boundary_coords);
}

MeshDeviceView::DeviceView MeshDeviceView::get_devices(MeshType type) const {
switch (type) {
case MeshType::RowMajor: return this->devices_;
case MeshType::Ring: return this->get_ring_devices();
case MeshType::Line: return this->get_line_devices();
default: TT_THROW("Unsupported Mesh type: {}", type);
}
}
MeshDeviceView::DeviceView MeshDeviceView::get_devices() const { return this->devices_; }

} // namespace tt::tt_metal::distributed
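The orderings the removed `switch` used to produce can be sketched in Python. These are illustrative assumptions about the semantics, not the C++ implementations: Ring is assumed to walk the grid boundary so the first and last devices are adjacent, and Line is assumed to snake through rows so consecutive devices are adjacent without wrapping around.

```python
def ring_order(rows, cols):
    """Sketch: walk the boundary of a rows x cols grid clockwise, so
    consecutive entries (including the two ends) are physically adjacent."""
    if rows == 1:
        return [(0, c) for c in range(cols)]
    if cols == 1:
        return [(r, 0) for r in range(rows)]
    top = [(0, c) for c in range(cols)]
    right = [(r, cols - 1) for r in range(1, rows)]
    bottom = [(rows - 1, c) for c in range(cols - 2, -1, -1)]
    left = [(r, 0) for r in range(rows - 2, 0, -1)]
    return top + right + bottom + left

def line_order(rows, cols):
    """Sketch: snake row by row, reversing every other row, so consecutive
    entries are adjacent but the ends do not wrap around."""
    order = []
    for r in range(rows):
        row = [(r, c) for c in range(cols)]
        order.extend(row if r % 2 == 0 else row[::-1])
    return order

# For a 2x4 T3000 mesh the two orderings happen to coincide:
print(ring_order(2, 4))  # [(0, 0), (0, 1), (0, 2), (0, 3), (1, 3), (1, 2), (1, 1), (1, 0)]
print(line_order(2, 4))  # [(0, 0), (0, 1), (0, 2), (0, 3), (1, 3), (1, 2), (1, 1), (1, 0)]
```

This also illustrates why the enum could be dropped: the view keeps `get_ring_devices`/`get_line_devices` for callers that need a specific traversal, while `get_devices()` no longer needs a mode parameter.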
6 changes: 3 additions & 3 deletions ttnn/cpp/ttnn/distributed/api.cpp
@@ -24,11 +24,10 @@ std::shared_ptr<MeshDevice> open_mesh_device(
size_t trace_region_size,
size_t num_command_queues,
const DispatchCoreConfig& dispatch_core_config,
MeshType mesh_type,
const MeshOffset& offset,
const std::vector<int>& physical_device_ids) {
auto config = MeshDeviceConfig{
.mesh_shape = mesh_shape, .offset = offset, .physical_device_ids = physical_device_ids, .mesh_type = mesh_type};
auto config =
MeshDeviceConfig{.mesh_shape = mesh_shape, .offset = offset, .physical_device_ids = physical_device_ids};
return MeshDevice::create(config, l1_small_size, trace_region_size, num_command_queues, dispatch_core_config);
}

@@ -152,6 +151,7 @@ std::vector<IDevice*> get_mapped_devices(const Tensor& tensor, MeshDevice& mesh_
[&](const ShardTensor2D& s) {
return mesh_device.get_view().get_devices(MeshShape{s.shard_mesh.y, s.shard_mesh.x});
},
[&](const ShardTensor& s) { return mesh_device.get_view().get_line_devices(); },
Contributor: Can you add a bit more explanation to `get_line_devices` and `get_ring_devices` functions as documentation? It is not clear right away what those are...

[&](const auto&) { return get_workers_for_tensor(); }},
host_storage.strategy);
} else if (std::holds_alternative<MultiDeviceStorage>(tensor.get_storage())) {