#14208: Refactor and add bfloat8_b unit test moreh_mean (forward, backward)
mrshaw01 committed Oct 24, 2024
1 parent 8df083a commit f4d95c5
Showing 1 changed file with 94 additions and 104 deletions.
198 changes: 94 additions & 104 deletions tests/ttnn/unit_tests/operations/test_moreh_mean.py
@@ -6,139 +6,115 @@
import torch
from loguru import logger

import tt_lib as ttl
import ttnn
from models.utility_functions import comp_allclose

from tests.ttnn.unit_tests.operations.test_utils import (
get_compute_kernel_options,
compute_kernel_options,
compute_kernel_ids,
to_torch,
to_ttnn,
TILE_HEIGHT,
TILE_WIDTH,
check_dim,
)


def get_torch_tensors(input_shape, use_randint=False):
cpu_dtype = torch.bfloat16
def create_ttnn_tilized_tensor(torch_tensor, device, dtype):
return ttnn.from_torch(torch_tensor, device=device, dtype=dtype, layout=ttnn.TILE_LAYOUT)

if use_randint:
torch_input = torch.randint(-2, 3, input_shape, dtype=cpu_dtype)
else:
torch_input = torch.rand(input_shape, dtype=cpu_dtype)

torch_input.requires_grad_()
return torch_input


def get_tt_tensors(torch_input, output_shape, device):
torch_input = torch_input.bfloat16()
torch_output = torch.empty(output_shape, dtype=torch.bfloat16)

tt_input = to_ttnn(torch_input, device=device)
tt_output = to_ttnn(torch_output, device=device)
return tt_input, tt_output


def get_torch_backward_tensors(output_grad_shape, use_randint=False):
cpu_dtype = torch.bfloat16

if use_randint:
torch_output_grad = torch.randint(-2, 3, output_grad_shape, dtype=cpu_dtype)
else:
torch_output_grad = torch.rand(output_grad_shape, dtype=cpu_dtype)

return torch_output_grad


def get_tt_backward_tensors(torch_output_grad, input_grad_shape, device):
cpu_dtype = torch.bfloat16

torch_input_grad = torch.empty(input_grad_shape, dtype=cpu_dtype)

tt_input_grad = to_ttnn(torch_input_grad, device=device)
tt_output_grad = to_ttnn(torch_output_grad, device=device)
def run_moreh_mean(
input_shape_dim,
device,
*,
keepdim=False,
compute_kernel_options=None,
optional_output=False,
torch_dtype=torch.float32,
ttnn_dtype=ttnn.bfloat16,
):
# TODO @mrshaw01: Support bfloat8_b in kernel
if ttnn_dtype == ttnn.bfloat8_b:
pytest.skip(f"bfloat8_b is not supported in the kernel")

return tt_output_grad, tt_input_grad


def run_moreh_mean(input_shape_dim, device, keepdim=False, compute_kernel_options=None):
input_shape, dim = input_shape_dim

check_dim(input_shape, dim, keepdim)

compute_kernel_config = get_compute_kernel_options(compute_kernel_options)

# run torch
torch_input = get_torch_tensors(input_shape)
torch_input = torch.rand(input_shape, dtype=torch_dtype)
torch_output = torch.mean(torch_input, dim=dim, keepdim=keepdim)

# run tt
(tt_input, tt_output) = get_tt_tensors(torch_input, torch_output.shape, device)

ttnn.operations.moreh.mean(
tt_input, dim=dim, keepdim=keepdim, output=tt_output, compute_kernel_config=compute_kernel_config
# run ttnn
ttnn_input = create_ttnn_tilized_tensor(torch_input, device, ttnn_dtype)
ttnn_output = (
create_ttnn_tilized_tensor(torch.empty_like(torch_output), device, ttnn_dtype) if optional_output else None
)
ttnn_output = ttnn.operations.moreh.mean(
ttnn_input,
dim=dim,
keepdim=keepdim,
output=ttnn_output,
compute_kernel_config=get_compute_kernel_options(compute_kernel_options),
)
tt_output_cpu = to_torch(tt_output, shape=torch_output.shape)
output = ttnn.to_torch(ttnn_output)

# test for equivalence
rtol = atol = 0.1
passing, output_pcc = comp_allclose(torch_output, tt_output_cpu, rtol=rtol, atol=atol)

logger.debug(f"Out passing={passing}")
logger.debug(f"Output pcc={output_pcc}")

passing, out = comp_allclose(torch_output, output, rtol=rtol, atol=atol)
logger.info(f"passing={passing}")
logger.info(f"out={out}")
assert passing


def run_moreh_mean_backward(input_shape_dim, device, keepdim=False, compute_kernel_options=None, create_output=False):
input_shape, dim = input_shape_dim
def run_moreh_mean_backward(
input_shape_dim,
device,
*,
keepdim=False,
compute_kernel_options=None,
create_input_grad=False,
torch_dtype=torch.float32,
ttnn_dtype=ttnn.bfloat16,
):
# TODO @mrshaw01: Support bfloat8_b in kernel
if ttnn_dtype == ttnn.bfloat8_b:
pytest.skip(f"bfloat8_b is not supported in the kernel")

input_shape, dim = input_shape_dim
check_dim(input_shape, dim, keepdim)

compute_kernel_config = get_compute_kernel_options(compute_kernel_options)

# run torch
torch_input = get_torch_tensors(input_shape)
torch_input = torch.rand(input_shape, dtype=torch_dtype)
torch_input.requires_grad_()
torch_output = torch.mean(torch_input, dim=dim, keepdim=keepdim)

torch_output_grad = get_torch_backward_tensors(torch_output.shape)

torch_output_grad = torch.rand(torch_output.shape, dtype=torch_dtype)
torch_output.backward(torch_output_grad)

# run_tt
tt_output_grad, tt_input_grad = get_tt_backward_tensors(torch_output_grad, torch_input.shape, device)

if create_output:
# run ttnn
ttnn_output_grad = create_ttnn_tilized_tensor(torch_output_grad, device, ttnn_dtype)
if create_input_grad:
input_grad_shape = ttnn._ttnn.types.Shape(torch_input.shape)
tt_input_grad = ttnn.operations.moreh.mean_backward(
tt_output_grad,
ttnn_input_grad = ttnn.operations.moreh.mean_backward(
ttnn_output_grad,
dim=dim,
keepdim=keepdim,
input_grad_shape=input_grad_shape,
compute_kernel_config=compute_kernel_config,
compute_kernel_config=get_compute_kernel_options(compute_kernel_options),
)
else:
ttnn_input_grad = create_ttnn_tilized_tensor(torch.empty_like(torch_input), device, ttnn_dtype)
ttnn.operations.moreh.mean_backward(
tt_output_grad,
ttnn_output_grad,
dim=dim,
keepdim=keepdim,
input_grad=tt_input_grad,
compute_kernel_config=compute_kernel_config,
input_grad=ttnn_input_grad,
compute_kernel_config=get_compute_kernel_options(compute_kernel_options),
)
input_grad = ttnn.to_torch(ttnn_input_grad)

tt_input_grad_cpu = to_torch(tt_input_grad, shape=torch_input.grad.shape)

# test for equivalence
rtol = atol = 0.1
passing, output_pcc = comp_allclose(torch_input.grad, tt_input_grad_cpu, rtol=rtol, atol=atol)

logger.debug(f"Out passing={passing}")
logger.debug(f"Output pcc={output_pcc}")

passing, output_pcc = comp_allclose(torch_input.grad, input_grad, rtol=rtol, atol=atol)
logger.info(f"Out passing={passing}")
logger.info(f"Output pcc={output_pcc}")
assert passing


@@ -168,10 +144,10 @@ def run_moreh_mean_backward(input_shape_dim, device, keepdim=False, compute_kern
],
)
@pytest.mark.parametrize("keepdim", [True, False])
def test_moreh_mean(input_shape_dim, keepdim, device):
torch.manual_seed(2023)

run_moreh_mean(input_shape_dim, device, keepdim)
@pytest.mark.parametrize("ttnn_dtype", [ttnn.bfloat8_b, ttnn.bfloat16])
def test_moreh_mean_ttnn_dtype(input_shape_dim, keepdim, ttnn_dtype, device):
torch.manual_seed(2024)
run_moreh_mean(input_shape_dim, device, keepdim=keepdim, ttnn_dtype=ttnn_dtype)


@pytest.mark.parametrize(
@@ -186,11 +162,26 @@ def test_moreh_mean(input_shape_dim, keepdim, device):
)
@pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids)
def test_moreh_mean_compute_kernel_options(input_shape_dim, compute_kernel_options, device):
torch.manual_seed(2023)

torch.manual_seed(2024)
run_moreh_mean(input_shape_dim, device, keepdim=True, compute_kernel_options=compute_kernel_options)


@pytest.mark.parametrize(
"input_shape_dim",
[
# hw multiple tiles
[[TILE_HEIGHT * 4, TILE_WIDTH * 5], [0]], # h
[[TILE_HEIGHT * 4, TILE_WIDTH * 5], [1]], # w
# ncd multiple tile
[[3, 4, 5, TILE_HEIGHT * 3 - 15, TILE_WIDTH * 4 - 10], [1]], # c
],
)
@pytest.mark.parametrize("optional_output", [True, False])
def test_moreh_mean_optional_output(input_shape_dim, optional_output, device):
torch.manual_seed(2024)
run_moreh_mean(input_shape_dim, device, keepdim=True, optional_output=optional_output)


@pytest.mark.parametrize(
"input_shape_dim",
[
@@ -207,7 +198,7 @@ def test_moreh_mean_callback(input_shape_dim, device, use_program_cache):
for i in range(2):
run_moreh_mean(input_shape_dim, device, keepdim=True)
torch_dummy = torch.randn([32, 32])
tt_dummy = to_ttnn(torch_dummy, device=device)
ttnn_dummy = ttnn.from_torch(torch_dummy, device=device)
num_program_cache_entries_list.append(device.num_program_cache_entries())
logger.info(f"num_program_cache_entries_list={num_program_cache_entries_list}")
assert num_program_cache_entries_list[0] > 0
@@ -240,10 +231,10 @@ def test_moreh_mean_callback(input_shape_dim, device, use_program_cache):
],
)
@pytest.mark.parametrize("keepdim", [True, False])
def test_moreh_mean_backward(input_shape_dim, keepdim, device):
torch.manual_seed(2023)

run_moreh_mean_backward(input_shape_dim, device, keepdim=keepdim)
@pytest.mark.parametrize("ttnn_dtype", [ttnn.bfloat8_b, ttnn.bfloat16])
def test_moreh_mean_backward_ttnn_dtype(ttnn_dtype, input_shape_dim, keepdim, device):
torch.manual_seed(2024)
run_moreh_mean_backward(input_shape_dim, device, keepdim=keepdim, ttnn_dtype=ttnn_dtype)


@pytest.mark.parametrize(
@@ -258,8 +249,7 @@ def test_moreh_mean_backward(input_shape_dim, keepdim, device):
)
@pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids)
def test_moreh_mean_backward_compute_kernel_options(input_shape_dim, compute_kernel_options, device):
torch.manual_seed(2023)

torch.manual_seed(2024)
run_moreh_mean_backward(input_shape_dim, device, keepdim=True, compute_kernel_options=compute_kernel_options)


@@ -279,7 +269,7 @@ def test_moreh_mean_backward_callback(input_shape_dim, device, use_program_cache
for i in range(2):
run_moreh_mean_backward(input_shape_dim, device, keepdim=True)
torch_dummy = torch.randn([32, 32])
tt_dummy = to_ttnn(torch_dummy, device=device)
ttnn_dummy = ttnn.from_torch(torch_dummy, device=device)
num_program_cache_entries_list.append(device.num_program_cache_entries())
logger.info(f"num_program_cache_entries_list={num_program_cache_entries_list}")
assert num_program_cache_entries_list[0] > 0
@@ -297,7 +287,7 @@ def test_moreh_mean_backward_callback(input_shape_dim, device, use_program_cache
],
)
@pytest.mark.parametrize("keepdim", [True, False])
def test_moreh_mean_backward_create_output(input_shape_dim, keepdim, device):
torch.manual_seed(2023)

run_moreh_mean_backward(input_shape_dim, device, keepdim=keepdim, create_output=True)
@pytest.mark.parametrize("create_input_grad", [True, False])
def test_moreh_mean_backward_create_input_grad(input_shape_dim, keepdim, create_input_grad, device):
torch.manual_seed(2024)
run_moreh_mean_backward(input_shape_dim, device, keepdim=keepdim, create_input_grad=create_input_grad)
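
Below is a minimal usage sketch (not part of the commit) showing how the refactored create_ttnn_tilized_tensor helper round-trips a tensor through a tilized device tensor. The helper body is copied from the diff above; the ttnn.open_device/ttnn.close_device setup and the [32, 32] shape are assumptions for illustration, and the loose tolerance mirrors the rtol = atol = 0.1 used by the tests.

import torch
import ttnn


def create_ttnn_tilized_tensor(torch_tensor, device, dtype):
    # Same helper as introduced in this commit: create a tilized tensor on the device.
    return ttnn.from_torch(torch_tensor, device=device, dtype=dtype, layout=ttnn.TILE_LAYOUT)


device = ttnn.open_device(device_id=0)  # assumed host setup, not taken from this diff
torch_input = torch.rand([32, 32], dtype=torch.float32)
ttnn_input = create_ttnn_tilized_tensor(torch_input, device, ttnn.bfloat16)
round_tripped = ttnn.to_torch(ttnn_input).to(torch.float32)  # back to torch for comparison
# bfloat16 storage loses precision, so compare with the same loose tolerance the tests use.
assert torch.allclose(torch_input, round_tripped, rtol=0.1, atol=0.1)
ttnn.close_device(device)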
