Add resharding to post all gather layernorm/ rms norm op (#17156)
### Problem description
For the fast-inference TG Llama model we need to fuse resharding with RMS norm for perf.

### What's changed
Extended the layernorm / RMS norm post-all-gather ops to take an output memory config and reshard their output to the respective specs. The arbitrary core grids we use for the prefetcher matmuls are also specifically tested.

Perf: for the TG Llama shapes, implicit resharding adds 1.2 us of overhead (no overhead if no output memory config is passed in), whereas explicit resharding takes 1.7 us.
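
In practice the change means the reshard can be expressed as an argument to the op instead of a separate call. A minimal sketch of the two paths (the tensor/config names here are placeholders, not taken from this PR; the TG test added below shows a full setup):

```python
# Fused path: the op reshards its own output to `out_memcfg` (~1.2 us overhead for the TG llama shapes).
tt_out = ttnn.rms_norm_post_all_gather(
    tt_x,
    epsilon=1e-05,
    weight=gamma_tensor,
    program_config=ln_prg_cfg,
    stats=tt_stats,
    memory_config=out_memcfg,
)

# Previous path (illustrative): run the op in its input sharding, then reshard explicitly (~1.7 us).
tt_out = ttnn.rms_norm_post_all_gather(
    tt_x,
    epsilon=1e-05,
    weight=gamma_tensor,
    program_config=ln_prg_cfg,
    stats=tt_stats,
)
tt_out = ttnn.to_memory_config(tt_out, out_memcfg)
```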

### Checklist
- [x] Post commit CI passes:
  - https://github.com/tenstorrent/tt-metal/actions/runs/13012519756/job/36293983255
  - https://github.com/tenstorrent/tt-metal/actions/runs/13015080021/job/36302343886
- [ ] Blackhole Post commit (if applicable)
- [x] TG tests passing: https://github.com/tenstorrent/tt-metal/actions/runs/12996050710
- [x] T3K tests passing: https://github.com/tenstorrent/tt-metal/actions/runs/12996093660
- [ ] Device performance regression CI testing passes (if applicable)
- [ ] **(For models and ops writers)** Full [new models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml) tests pass
- [x] New/Existing tests provide coverage for changes
johanna-rock-tt authored Jan 29, 2025
1 parent 25d77a7 commit 4034423
Showing 10 changed files with 600 additions and 30 deletions.
1 change: 1 addition & 0 deletions .github/workflows/tg-unit-tests-impl.yaml
@@ -45,6 +45,7 @@ jobs:
{ name: "TG Llama3-small unit tests", arch: wormhole_b0, model: llama3-small, timeout: 45, owner_id: U06F3ER8X9A}, # Stuti Raizada
{ name: "TG Llama3-70b unit tests", arch: wormhole_b0, model: llama3-70b, timeout: 45, owner_id: U06F3ER8X9A}, # Stuti Raizada
{ name: "TG DRAM Prefetcher unit tests", arch: wormhole_b0, model: prefetcher, timeout: 30, owner_id: U071CKL4AFK}, # Ammar Vora, Yu Gao
{ name: "TG distributed ops tests", arch: wormhole_b0, model: distributed-ops, timeout: 15, owner_id: U044T8U8DEF}, # Johanna Rock
]
name: ${{ matrix.test-group.name }}
env:
21 changes: 21 additions & 0 deletions tests/scripts/tg/run_tg_unit_tests.sh
@@ -66,6 +66,24 @@ run_tg_llama3.1-70b_tests() {
  fi
}

run_tg_distributed_op_tests() {
  # Record the start time
  fail=0
  start_time=$(date +%s)

  echo "LOG_METAL: Running run_tg_distributed_op_tests"

  pytest tests/ttnn/distributed/test_distributed_layernorm_TG.py ; fail+=$?

  # Record the end time
  end_time=$(date +%s)
  duration=$((end_time - start_time))
  echo "LOG_METAL: run_tg_distributed_op_tests $duration seconds to complete"
  if [[ $fail -ne 0 ]]; then
    exit 1
  fi
}

run_tg_prefetcher_tests() {
  # Record the start time
  fail=0
@@ -109,6 +127,9 @@ run_tg_tests() {
elif [[ "$1" == "prefetcher" ]]; then
run_tg_prefetcher_tests

elif [[ "$1" == "distributed-ops" ]]; then
run_tg_distributed_op_tests

else
echo "LOG_METAL: Unknown model type: $1"
return 1
191 changes: 191 additions & 0 deletions tests/ttnn/distributed/test_distributed_layernorm_TG.py
@@ -0,0 +1,191 @@
# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
# SPDX-License-Identifier: Apache-2.0
from loguru import logger
import ttnn
import pytest
import torch
import math
from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import (
    comp_pcc,
)


def rms_norm(x, dim, gamma, beta, eps):
    # Torch reference: RMS-normalize x over the last len(dim) dims, then scale by gamma and shift by beta.
    return x * torch.rsqrt(x.pow(2).mean([-i for i in range(1, len(dim) + 1)], keepdim=True) + eps) * gamma + beta


# Arbitrary (non-rectangular) core grid used for the prefetcher matmuls on TG;
# exercises resharding onto irregular core range sets.
PREFETCHER_NOC1_GRID = [
    (6, 6),
    (6, 7),
    (6, 9),
    (6, 0),
    (6, 1),
    (6, 2),
    (6, 4),
    (6, 5),
    (5, 5),
    (5, 6),
    (5, 7),
    (5, 9),
    (5, 0),
    (5, 1),
    (5, 2),
    (5, 4),
    (1, 4),
    (1, 5),
    (1, 9),
    (1, 0),
    (2, 0),
    (2, 4),
    (2, 5),
    (2, 9),
]


@pytest.mark.parametrize(
    "num_devices_fractured, input_dim, input_core_grid, output_core_grid",
    [
        (4, 8192, ttnn.CoreGrid(x=2, y=8), PREFETCHER_NOC1_GRID),  # TG llama use case; 4 tiles per core input
        (4, 8192, ttnn.CoreGrid(x=2, y=8), None),
    ],
)
@pytest.mark.parametrize("device_params", [{"dispatch_core_axis": ttnn.DispatchCoreAxis.COL}], indirect=True)
@pytest.mark.parametrize(
    "mesh_device",
    [(8, 4)],
    indirect=True,
)
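# Note: the (8, 4) mesh is the 32-device TG system; the activation is width-fractured
# across the 4-device cluster axis (num_devices_fractured).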
def test_layernorm_perf(mesh_device, num_devices_fractured, input_dim, input_core_grid, output_core_grid):
    torch.manual_seed(1234)

    num_cores = input_core_grid.num_cores
    dim = int(
        math.ceil(input_dim / num_devices_fractured / num_cores / 32) * num_devices_fractured * num_cores * 32
    )  # padded
    input_shape = (1, 1, 32, dim)
    if isinstance(input_core_grid, ttnn.CoreGrid):
        input_core_range_set = ttnn.CoreRangeSet(
            [
                ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(input_core_grid.x - 1, input_core_grid.y - 1)),
            ]
        )
    else:
        input_core_range_set = ttnn.CoreRangeSet(
            [
                ttnn.CoreRange(
                    ttnn.CoreCoord(x, y),
                    ttnn.CoreCoord(x, y),
                )
                for x, y in input_core_grid
            ]
        )
    size_per_device = dim // num_devices_fractured
    # Input memory config
    input_memory_config = ttnn.create_sharded_memory_config(
        shape=(
            input_shape[0] * input_shape[1] * input_shape[2],
            input_shape[3] // num_devices_fractured // input_core_range_set.num_cores(),
        ),
        core_grid=input_core_range_set,
        strategy=ttnn.ShardStrategy.WIDTH,
        orientation=ttnn.ShardOrientation.ROW_MAJOR,
        use_height_and_width_as_shard_shape=True,
    )
    # Create input tensor with input memory config
    input_tensor_torch = torch.randn(input_shape)
    gamma_torch = torch.randn((1, 1, 1, input_shape[3]))
    input_tensor = ttnn.as_tensor(
        input_tensor_torch,
        dtype=ttnn.bfloat16,
        device=mesh_device,
        mesh_mapper=ttnn.ShardTensor2dMesh(mesh_device=mesh_device, dims=(None, 3), mesh_shape=list(mesh_device.shape)),
        layout=ttnn.TILE_LAYOUT,
        memory_config=input_memory_config,
    )
    gamma_tensor = ttnn.as_tensor(
        gamma_torch.reshape([1, 1, dim // 32, 32]),
        device=mesh_device,
        dtype=ttnn.bfloat16,
        layout=ttnn.ROW_MAJOR_LAYOUT,
        memory_config=ttnn.DRAM_MEMORY_CONFIG,
        mesh_mapper=ttnn.ShardTensor2dMesh(mesh_device, dims=(None, 2), mesh_shape=list(mesh_device.shape)),
    )
    ln_prg_cfg = ttnn.LayerNormShardedMultiCoreProgramConfig(
        compute_with_storage_grid_size=(input_core_grid.x, input_core_grid.y),
        subblock_w=1,
        block_h=1,
        block_w=(size_per_device // num_cores) // 32,
        inplace=False,
    )
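    # Gathered stats layout: one 32x32 tile per width-fractured device, all placed on a single core.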
    ln_sharded_stats_memcfg = ttnn.create_sharded_memory_config(
        shape=[1, 1, 32, 32 * num_devices_fractured],
        core_grid=ttnn.CoreGrid(y=1, x=1),
        strategy=ttnn.ShardStrategy.WIDTH,
    )
    # Run distributed rmsnorm part 1
    tt_stats = ttnn.rms_norm_pre_all_gather(input_tensor, program_config=ln_prg_cfg)

    # All gather stats
    tt_stats = ttnn.all_gather(
        tt_stats,
        3,
        num_links=1,
        cluster_axis=1,
        mesh_device=mesh_device,
        memory_config=ln_sharded_stats_memcfg,
        topology=ttnn.Topology.Linear,
    )

    # Output memory config
    if output_core_grid is None:
        output_core_grid = input_core_grid

    if isinstance(output_core_grid, ttnn.CoreGrid):
        output_core_range_set = ttnn.CoreRangeSet(
            [
                ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(output_core_grid.x - 1, output_core_grid.y - 1)),
            ]
        )
    else:
        output_core_range_set = ttnn.CoreRangeSet(
            [
                ttnn.CoreRange(
                    ttnn.CoreCoord(x, y),
                    ttnn.CoreCoord(x, y),
                )
                for x, y in output_core_grid
            ]
        )
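    # Pad the per-core output shard width up to a full tile (32).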
    padded_out_w = math.ceil(input_shape[3] / num_devices_fractured / output_core_range_set.num_cores() / 32) * 32
    output_memory_config = ttnn.create_sharded_memory_config(
        shape=(
            input_shape[0] * input_shape[1] * input_shape[2],
            padded_out_w,
        ),
        core_grid=output_core_range_set,
        strategy=ttnn.ShardStrategy.WIDTH,
        orientation=ttnn.ShardOrientation.ROW_MAJOR,
        use_height_and_width_as_shard_shape=True,
    )

    # Run distributed rmsnorm part 2
    tt_out = ttnn.rms_norm_post_all_gather(
        input_tensor,
        epsilon=1e-05,
        weight=gamma_tensor,
        program_config=ln_prg_cfg,
        stats=tt_stats,
        memory_config=output_memory_config,
        dtype=ttnn.bfloat8_b,
    )

    tt_stats.deallocate(True)
    tt_out_torch = ttnn.to_torch(
        tt_out, mesh_composer=ttnn.ConcatMesh2dToTensor(mesh_device, dims=(0, 3), mesh_shape=(8, 4))
    )[0].unsqueeze(0)

    ref_lnorm = rms_norm(input_tensor_torch, [3], gamma_torch, torch.zeros_like(gamma_torch), 1e-5)
    passing, output = comp_pcc(tt_out_torch, ref_lnorm, 0.999)
    logger.info(output)

    assert passing