Commit deba0ad

#0: Fix failing Llama TG tests by preserving old behavior for ShardTensorToMesh

Previously, when we had an MxN MeshDevice, a mesh_mapper of
ShardTensorToMesh behaved differently depending on whether the `mesh_type`
passed into the MeshDevice was MeshType::RowMajor or MeshType::Ring.

With the removal of `MeshType` from the MeshDevice specification, this
changed the default behavior for users constructing a MeshDevice
with the default mesh_type=MeshType::RowMajor. This change preserves the
old behavior, so shards are distributed in row-major order instead of
along a line.
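For illustration, here is a minimal sketch (not part of the commit) contrasting the two device orderings a ShardTensorToMesh mapper could follow on a 2x4 mesh. The exact ring traversal shown is an assumption; the real MeshDeviceView defines the actual order.

```python
# Minimal sketch (not from this commit): contrasts row-major vs. line/ring
# shard-to-device ordering on a 2x4 mesh. The ring traversal below (top row
# left-to-right, bottom row right-to-left) is an assumption for illustration.

ROWS, COLS = 2, 4

# Row-major: walk each row left to right, top to bottom.
row_major = [(r, c) for r in range(ROWS) for c in range(COLS)]

# Line/ring: snake through the mesh so consecutive devices stay adjacent.
ring = [(0, c) for c in range(COLS)] + [(1, c) for c in reversed(range(COLS))]

# Shard i of the mapped tensor lands on the i-th device in the active order.
for i, (rm, rg) in enumerate(zip(row_major, ring)):
    print(f"shard {i}: row-major -> {rm}, ring -> {rg}")
```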
cfjchu committed Feb 7, 2025
1 parent e7e86d7 commit deba0ad
Showing 2 changed files with 1 addition and 1 deletion.
1 change: 1 addition & 0 deletions conftest.py
@@ -258,6 +258,7 @@ def pcie_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, devic
         **updated_device_params,
         offset=ttnn.MeshOffset(0, 1),
     )
+    mesh_device.reshape(ttnn.MeshShape(1, 4))

     logger.debug(f"multidevice with {mesh_device.get_num_devices()} devices is created")
     yield mesh_device
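As a hedged usage sketch (not part of the commit; API names mirror those in the fixture above, but exact signatures may vary by version), a test can open a small mesh, reshape it to an explicit 1x4 line as the fixture now does, and shard a tensor across it:

```python
# Hedged usage sketch: open a 2x2 mesh, reshape it to an explicit 1x4 line
# (same trick as the fixture above), then shard a tensor across the devices.
# Shapes and the shard dim are illustrative.
import torch
import ttnn

mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 2))
mesh_device.reshape(ttnn.MeshShape(1, 4))

# One shard per device along dim 0 (4 shards for 4 devices).
sharded = ttnn.from_torch(
    torch.randn(4, 1, 32, 32),
    dtype=ttnn.bfloat16,
    mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=0),
    device=mesh_device,
)

ttnn.close_mesh_device(mesh_device)
```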
1 change: 0 additions & 1 deletion ttnn/cpp/ttnn/distributed/api.cpp
@@ -153,7 +153,6 @@ std::vector<IDevice*> get_mapped_devices(const Tensor& tensor, MeshDevice& mesh_
             [&](const ShardTensor2D& s) {
                 return mesh_device.get_view().get_devices(MeshShape{s.shard_mesh.y, s.shard_mesh.x});
             },
-            [&](const ShardTensor& s) { return get_workers_for_tensor(mesh_device.get_view().get_line_devices()); },
             [&](const auto&) { return get_workers_for_tensor(mesh_device.get_devices()); }},
         host_storage.strategy);
     } else if (std::holds_alternative<MultiDeviceStorage>(tensor.get_storage())) {
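With the `ShardTensor` overload removed, sharded host tensors fall through to the generic `[&](const auto&)` handler, which selects workers via `mesh_device.get_devices()` (row-major order) rather than `get_line_devices()`, restoring the pre-`MeshType` default described in the commit message.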
