#0: Simplify MeshDevice construction by eradicating mesh_type
cfjchu committed Feb 4, 2025
1 parent 5e5bb17 commit d421f07
Showing 29 changed files with 94 additions and 147 deletions.
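
For quick reference, a minimal before/after sketch of the user-facing change, assembled from the calls that appear in the diffs below (the 1x4 submesh shape and the close call are illustrative additions, not taken from this commit):

```python
import ttnn

# Before this commit: a topology had to be chosen at construction time.
# mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4), mesh_type=ttnn.MeshType.Ring)
# submeshes = mesh_device.create_submeshes((2, 4), ttnn.MeshType.Ring)

# After this commit: only the shape is specified; mesh_type is gone everywhere.
mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4))

# Submeshes are now carved out by shape alone, e.g. two 1x4 row submeshes.
submeshes = mesh_device.create_submeshes((1, 4))

ttnn.close_mesh_device(mesh_device)
```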
4 changes: 1 addition & 3 deletions conftest.py
@@ -257,7 +257,6 @@ def pcie_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, devic
mesh_shape=ttnn.MeshShape(2, 2),
**updated_device_params,
offset=ttnn.MeshOffset(0, 1),
mesh_type=ttnn.MeshType.Ring,
)

logger.debug(f"multidevice with {mesh_device.get_num_devices()} devices is created")
@@ -303,9 +302,8 @@ def t3k_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, device
request.node.pci_ids = ttnn.get_pcie_device_ids()
updated_device_params = get_updated_device_params(device_params)
mesh_device = ttnn.open_mesh_device(
mesh_shape=ttnn.MeshShape(2, 4),
mesh_shape=ttnn.MeshShape(1, 8),
**updated_device_params,
mesh_type=ttnn.MeshType.Ring,
)

logger.debug(f"multidevice with {mesh_device.get_num_devices()} devices is created")
2 changes: 1 addition & 1 deletion models/MODEL_HYBRID_TP_DP.md
@@ -16,7 +16,7 @@ The main changes involve:

```python
# Work with submesh device as you would with a regular ttnn.MeshDevice
submesh_devices: List[ttnn.MeshDevice] = mesh_device.create_submeshes((2, 4), ttnn.MeshType.Ring)
submesh_devices: List[ttnn.MeshDevice] = mesh_device.create_submeshes((2, 4))
```

### 2. Compile & Run the Model on Each Submesh

@@ -296,7 +296,7 @@ def run_test_LlamaModel_end_to_end_hybrid_data_tensor_parallel(
profiler.clear()

submesh_to_metadata = defaultdict(dict)
submeshes = mesh_device.create_submeshes((2, 4), ttnn.MeshType.Ring)
submeshes = mesh_device.create_submeshes((2, 4))
for submesh in submeshes:
# Set up model -----------------------------------------------------------------------
logger.info("Moving weights to devices; might take some time...")
2 changes: 1 addition & 1 deletion tech_reports/LLMs/llms.md
@@ -1195,7 +1195,7 @@ Below is a summary and example code of the most important concepts for mapping a
import ttnn

# 2x4 mesh_device, Topology Ring: devices are connected in a ring
mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4), mesh_type=ttnn.MeshType.Ring)
mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4))

# Construct initial torch tensor
torch_tensor = torch.rand((1,1,32,256), dtype=torch.bfloat16)

@@ -296,7 +296,7 @@ Let's see an example of how to use the Ring All-Gather operation:
```py
import ttnn

mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4), mesh_type=ttnn.MeshType.Ring)
mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4))

# Construct test tensor of data; 8 chunks of 32x32
torch_tensor = torch.rand((1,1,32,256), dtype=torch.bfloat16)
@@ -328,7 +328,7 @@ The result tensor for each device in the column is the concatenation in `dim=3`
```py
import ttnn

mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4), mesh_type=ttnn.MeshType.Ring)
mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4))

# Construct test tensor of data; 8 chunks of 32x32
torch_tensor = torch.rand((1,1,32,256), dtype=torch.bfloat16)
@@ -534,7 +534,7 @@ torch_hidden_states = (torch.rand(batch_size, 1, sequence_length, config.hidden_
torch_output = model.forward(torch_hidden_states)

# Device Initialization
mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2,4), mesh_type=ttnn.MeshType.Ring)
mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2,4))

# Initialize input activations on all devices in the mesh
# Alternatively, we can shard the input activations on the height dimension and
@@ -602,7 +602,7 @@ See `models/demos/t3000/llama2_70b/tests/test_llama_perf_decode.py::test_Llama_p
1. Submesh Creation

```py
submesh_devices: List[ttnn.MeshDevice] = mesh_device.create_submeshes((2, 4), ttnn.MeshType.Ring)
submesh_devices: List[ttnn.MeshDevice] = mesh_device.create_submeshes((2, 4))
```

2. Compile & Run the Model on Each Submesh

@@ -223,10 +223,6 @@ struct MeshConfig {

// Offset into Logical Device Coordinate Space
MeshOffset offset;

// TODO: consider whether this should be automatically inferred.
// Interpret as e.g. {Ring, Line}
MeshType type;
};

// Class exposing host and device dispatch state
@@ -986,8 +982,8 @@ Below, we include snippets from both the TT-Mesh and TT-Metal examples to illust
*Specify MeshConfig when creating two Virtual Meshes on a Physical Mesh.*

```cpp
MeshConfig mesh_config_0 = MeshConfig{.shape = virtual_mesh_shape, .offset = {0, 0}, .type=mesh_type};
MeshConfig mesh_config_1 = MeshConfig{.shape = virtual_mesh_shape, .offset = {0, 4}, .type=mesh_type};
MeshConfig mesh_config_0 = MeshConfig{.shape = virtual_mesh_shape, .offset = {0, 0}};
MeshConfig mesh_config_1 = MeshConfig{.shape = virtual_mesh_shape, .offset = {0, 4}};

DeviceHandle virtual_mesh_0 = CreateMeshDevice(mesh_config_0, 2 /* num_command_queues */, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE);
DeviceHandle virtual_mesh_1 = CreateMeshDevice(mesh_config_1, 2 /* num_command_queues */, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE);
2 changes: 1 addition & 1 deletion tests/sweep_framework/sweeps/ccl/line_all_gather.py
@@ -66,7 +66,7 @@ def mesh_device_fixture():
assert ttnn.get_num_devices() >= 8, "Not T3000!"
device_ids = ttnn.get_t3k_physical_device_ids_ring()
num_devices_requested = len(device_ids)
mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(1, num_devices_requested), mesh_type=ttnn.MeshType.Line)
mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(1, num_devices_requested))
print("ALL GATHER: Opened device mesh")

yield (mesh_device, "T3000 Mesh")
11 changes: 2 additions & 9 deletions tests/ttnn/distributed/test_distributed_reshape.cpp
@@ -27,7 +27,7 @@ void check_test_environment() {

std::vector<chip_id_t> get_physical_device_ids(const MeshDevice& mesh) {
std::vector<chip_id_t> device_ids;
for (auto* device : mesh.get_devices(ttnn::distributed::MeshType::RowMajor)) {
for (auto* device : mesh.get_devices()) {
device_ids.push_back(device->id());
}
return device_ids;
@@ -138,12 +138,7 @@ TEST_F(T3000ReshapeTest, From1x8To2x4) {

TEST_F(T3000ReshapeTest, OnRingTopology) {
auto mesh = ttnn::distributed::open_mesh_device(
{1, 8},
DEFAULT_L1_SMALL_SIZE,
DEFAULT_TRACE_REGION_SIZE,
1,
tt::tt_metal::DispatchCoreType::WORKER,
ttnn::distributed::MeshType::Ring);
{1, 8}, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, tt::tt_metal::DispatchCoreType::WORKER);

EXPECT_EQ(mesh->num_rows(), 1);
EXPECT_EQ(mesh->num_cols(), 8);
@@ -228,7 +223,6 @@ TEST_F(T3000ReshapeTest, From1x4To2x2Valid) {
// Fetch the device ids for a physically connected 2x2 mesh.
auto physical_device_ids = system_mesh.get_mapped_physical_device_ids(MeshDeviceConfig{
.mesh_shape = MeshShape{2, 2},
.mesh_type = ttnn::distributed::MeshType::Line,
});

// Supply the physical device ids to the mesh constructor that we know is 2x2 physically connected.
@@ -239,7 +233,6 @@
DEFAULT_TRACE_REGION_SIZE,
1,
tt::tt_metal::DispatchCoreType::WORKER,
ttnn::distributed::MeshType::Line,
MeshOffset{0, 0},
physical_device_ids);


@@ -36,7 +36,6 @@ def test_tensor_parallel_falcon_mlp():

mesh_device = ttnn.open_mesh_device(
ttnn.MeshShape(2, 4),
mesh_type=ttnn.MeshType.Ring,
)

# Set PyTorch seed for reproducibility
2 changes: 1 addition & 1 deletion tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp
@@ -67,7 +67,7 @@ class T3kMultiDeviceFixture : public ::testing::Test {
if (num_devices < 8 or arch != tt::ARCH::WORMHOLE_B0) {
GTEST_SKIP() << "Skipping T3K Multi-Device test suite on non T3K machine.";
}
mesh_device_ = MeshDevice::create(MeshDeviceConfig{.mesh_shape = MeshShape{2, 4}, .mesh_type = MeshType::Ring});
mesh_device_ = MeshDevice::create(MeshDeviceConfig{.mesh_shape = MeshShape{2, 4}});
}

void TearDown() override {
16 changes: 8 additions & 8 deletions tests/ttnn/unit_tests/operations/ccl/test_all_gather.py
@@ -149,14 +149,14 @@ def run_all_gather_impl(
logger.info(f"dim: {dim}")

input_tensor = torch.rand(input_shape).bfloat16()

input_tensors = torch.chunk(input_tensor, num_devices, dim)
tt_input_tensors = []
for i, t in enumerate(input_tensors):
t = ttnn.from_torch(t, input_dtype, layout=layout, tile=ttnn.Tile(tile))
tt_input_tensors.append(t.to(mesh_device.get_devices()[i], mem_config))

input_tensor_mesh = ttnn.aggregate_as_tensor(tt_input_tensors)
input_tensor_mesh = ttnn.from_torch(
input_tensor,
dtype=input_dtype,
layout=layout,
tile=ttnn.Tile(tile),
mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim),
device=mesh_device,
)
if trace_mode:
tt_out_tensor = run_with_trace(
mesh_device,
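For context on the test refactor above, here is a minimal round-trip sketch of the mesh_mapper pattern the test now uses; the shapes, the shard dim, and the `ConcatMeshToTensor` read-back are illustrative and not part of this commit:

```python
import torch
import ttnn

mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4))

# One host tensor, sharded across all 8 devices along the last dimension.
torch_tensor = torch.rand((1, 1, 32, 256), dtype=torch.bfloat16)
mesh_tensor = ttnn.from_torch(
    torch_tensor,
    dtype=ttnn.bfloat16,
    layout=ttnn.TILE_LAYOUT,
    mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=3),
    device=mesh_device,
)

# Read the per-device shards back and re-concatenate them on host.
round_tripped = ttnn.to_torch(
    mesh_tensor, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=3)
)
assert round_tripped.shape == torch_tensor.shape

ttnn.close_mesh_device(mesh_device)
```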

@@ -175,7 +175,7 @@ def run_line_all_gather_instances(

for device in t3k_mesh_device.get_devices():
t3k_device.append(device)

t3k_device[4:] = t3k_device[::-1][:4]
t3000_device_rows = [
[t3k_device[4], t3k_device[0], t3k_device[3], t3k_device[7]],
[t3k_device[5], t3k_device[1], t3k_device[2], t3k_device[6]],
23 changes: 13 additions & 10 deletions tests/ttnn/unit_tests/operations/ccl/test_new_all_gather.py
@@ -495,8 +495,9 @@ def test_all_gather_sharded(
)
@pytest.mark.parametrize("enable_async", [True])
@pytest.mark.parametrize("replication_factor", [4])
@pytest.mark.parametrize("mesh_device", [pytest.param((2, 4), id="2x4_grid")], indirect=True)
def test_line_all_gather_async_on_T3K_cols_persistent_fabric_post_commit(
t3k_mesh_device,
mesh_device,
num_devices,
per_chip_output_shape,
dim,
@@ -510,10 +511,10 @@ def test_line_all_gather_async_on_T3K_cols_persistent_fabric_post_commit(
replication_factor,
num_iters=1,
):
if len(t3k_mesh_device.get_devices()) < 8:
if len(mesh_device.get_devices()) < 8:
pytest.skip("Not T3K!")
run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
t3k_mesh_device,
mesh_device,
num_devices,
per_chip_output_shape,
ttnn.TensorMemoryLayout.INTERLEAVED,
@@ -563,8 +564,8 @@ def test_line_all_gather_async_on_T3K_cols_persistent_fabric_post_commit(
)
@pytest.mark.parametrize("replication_factor", [2])
@pytest.mark.parametrize("enable_async", [True])
@pytest.mark.parametrize("mesh_device", [pytest.param((2, 4), id="2x4_grid")], indirect=True)
def test_line_all_gather_async_on_T3K_rows_persistent_fabric_post_commit(
t3k_mesh_device,
num_devices,
per_chip_output_shape,
dim,
@@ -574,14 +575,15 @@ def test_line_all_gather_async_on_T3K_rows_persistent_fabric_post_commit(
buffer_type,
use_program_cache,
function_level_defaults,
mesh_device,
enable_async,
replication_factor,
num_iters=1,
):
if len(t3k_mesh_device.get_devices()) < 8:
if len(mesh_device.get_devices()) < 8:
pytest.skip("Not T3K!")
run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
t3k_mesh_device,
mesh_device,
num_devices,
per_chip_output_shape,
ttnn.TensorMemoryLayout.INTERLEAVED,
@@ -639,8 +641,9 @@ def test_line_all_gather_async_on_T3K_rows_persistent_fabric_post_commit(
],
)
@pytest.mark.parametrize("replication_factor2", [2])
@pytest.mark.parametrize("mesh_device", [pytest.param((2, 4), id="2x4_grid")], indirect=True)
def test_line_all_gather_async_on_T3K_back_to_back_cols_and_rows_persistent_fabric_post_commit(
t3k_mesh_device,
mesh_device,
num_devices1,
per_chip_output_shape1,
dim1,
@@ -660,10 +663,10 @@ def test_line_all_gather_async_on_T3K_back_to_back_cols_and_rows_persistent_fabr
replication_factor2,
num_iters=1,
):
if len(t3k_mesh_device.get_devices()) < 8:
if len(mesh_device.get_devices()) < 8:
pytest.skip("Not T3K!")
run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
t3k_mesh_device,
mesh_device,
num_devices1,
per_chip_output_shape1,
ttnn.TensorMemoryLayout.INTERLEAVED,
@@ -685,7 +688,7 @@ def test_line_all_gather_async_on_T3K_back_to_back_cols_and_rows_persistent_fabr
)

run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
t3k_mesh_device,
mesh_device,
num_devices2,
per_chip_output_shape2,
ttnn.TensorMemoryLayout.INTERLEAVED,
26 changes: 15 additions & 11 deletions tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_async.py
@@ -348,8 +348,9 @@ def run_reduce_scatter_test(
@pytest.mark.parametrize("enable_async", [False])
@pytest.mark.parametrize("trace_mode", [False])
@pytest.mark.parametrize("device_params", [{"trace_region_size": 27648}], indirect=True)
@pytest.mark.parametrize("mesh_device", [pytest.param((2, 4), id="2x4_grid")], indirect=True)
def test_line_reduce_scatter_async_post_commit(
t3k_mesh_device,
mesh_device,
num_devices,
per_chip_output_shape,
dim,
@@ -365,7 +366,7 @@ def test_line_reduce_scatter_async_post_commit(
num_iters=16,
):
run_reduce_scatter_test(
t3k_mesh_device,
mesh_device,
num_devices,
per_chip_output_shape,
dim,
@@ -412,8 +413,9 @@ def test_line_reduce_scatter_async_post_commit(
@pytest.mark.parametrize("enable_async", [True])
@pytest.mark.parametrize("replication_factor", [4])
@pytest.mark.parametrize("math_op", [ttnn.ReduceType.Sum])
@pytest.mark.parametrize("mesh_device", [pytest.param((2, 4), id="2x4_grid")], indirect=True)
def test_line_reduce_scatter_async_on_T3K_cols_post_commit(
t3k_mesh_device,
mesh_device,
num_devices,
per_chip_input_shape,
dim,
@@ -428,11 +430,11 @@ def test_line_reduce_scatter_async_on_T3K_cols_post_commit(
replication_factor,
num_iters=1,
):
if len(t3k_mesh_device.get_devices()) < 8:
if len(mesh_device.get_devices()) < 8:
pytest.skip("Not T3K!")

run_line_reduce_scatter_on_TG_with_mesh_tensor_along_rows(
t3k_mesh_device,
mesh_device,
num_devices,
per_chip_input_shape,
ttnn.TensorMemoryLayout.INTERLEAVED,
@@ -480,8 +482,9 @@ def test_line_reduce_scatter_async_on_T3K_cols_post_commit(
@pytest.mark.parametrize("enable_async", [True])
@pytest.mark.parametrize("replication_factor", [2])
@pytest.mark.parametrize("math_op", [ttnn.ReduceType.Sum])
@pytest.mark.parametrize("mesh_device", [pytest.param((2, 4), id="2x4_grid")], indirect=True)
def test_line_reduce_scatter_async_on_T3K_rows_post_commit(
t3k_mesh_device,
mesh_device,
num_devices,
per_chip_input_shape,
dim,
@@ -496,11 +499,11 @@ def test_line_reduce_scatter_async_on_T3K_rows_post_commit(
replication_factor,
num_iters=1,
):
if len(t3k_mesh_device.get_devices()) < 8:
if len(mesh_device.get_devices()) < 8:
pytest.skip("Not T3K!")

run_line_reduce_scatter_on_TG_with_mesh_tensor_along_rows(
t3k_mesh_device,
mesh_device,
num_devices,
per_chip_input_shape,
ttnn.TensorMemoryLayout.INTERLEAVED,
@@ -602,8 +605,9 @@ def test_line_reduce_scatter_async_on_T3K_rows_post_commit(
@pytest.mark.parametrize("math_op", [ttnn.ReduceType.Sum])
@pytest.mark.parametrize("enable_async", [False])
@pytest.mark.parametrize("replication_factor", [1])
@pytest.mark.parametrize("mesh_device", [pytest.param((2, 4), id="2x4_grid")], indirect=True)
def test_line_reduce_scatter_cluster_axis_on_T3K_width_sharded_reduce_scatter_post_commit(
t3k_mesh_device,
mesh_device,
num_devices,
per_chip_input_shape,
input_shard_shape,
@@ -623,7 +627,7 @@ def test_line_reduce_scatter_cluster_axis_on_T3K_width_sharded_reduce_scatter_po
num_iters=1,
trace_mode=False,
):
if len(t3k_mesh_device.get_devices()) < 8:
if len(mesh_device.get_devices()) < 8:
pytest.skip("Not T3K!")

input_shard_spec = ttnn.ShardSpec(
@@ -633,7 +637,7 @@ def test_line_reduce_scatter_cluster_axis_on_T3K_width_sharded_reduce_scatter_po
)

run_line_reduce_scatter_on_TG_with_mesh_tensor_along_rows(
t3k_mesh_device,
mesh_device,
num_devices,
per_chip_input_shape,
tensor_mem_layout,
