diff --git a/tests/ttnn/unit_tests/operations/test_moreh_mean.py b/tests/ttnn/unit_tests/operations/test_moreh_mean.py
index 2f8ca1319b8..961fe74fb79 100644
--- a/tests/ttnn/unit_tests/operations/test_moreh_mean.py
+++ b/tests/ttnn/unit_tests/operations/test_moreh_mean.py
@@ -6,7 +6,6 @@
 import torch
 from loguru import logger

-import tt_lib as ttl
 import ttnn
 from models.utility_functions import comp_allclose
@@ -14,131 +13,108 @@
     get_compute_kernel_options,
     compute_kernel_options,
     compute_kernel_ids,
-    to_torch,
-    to_ttnn,
     TILE_HEIGHT,
     TILE_WIDTH,
     check_dim,
 )


-def get_torch_tensors(input_shape, use_randint=False):
-    cpu_dtype = torch.bfloat16
+def create_ttnn_tilized_tensor(torch_tensor, device, dtype):
+    return ttnn.from_torch(torch_tensor, device=device, dtype=dtype, layout=ttnn.TILE_LAYOUT)

-    if use_randint:
-        torch_input = torch.randint(-2, 3, input_shape, dtype=cpu_dtype)
-    else:
-        torch_input = torch.rand(input_shape, dtype=cpu_dtype)
-
-    torch_input.requires_grad_()
-    return torch_input
-
-
-def get_tt_tensors(torch_input, output_shape, device):
-    torch_input = torch_input.bfloat16()
-    torch_output = torch.empty(output_shape, dtype=torch.bfloat16)
-
-    tt_input = to_ttnn(torch_input, device=device)
-    tt_output = to_ttnn(torch_output, device=device)
-    return tt_input, tt_output
-
-
-def get_torch_backward_tensors(output_grad_shape, use_randint=False):
-    cpu_dtype = torch.bfloat16
-
-    if use_randint:
-        torch_output_grad = torch.randint(-2, 3, output_grad_shape, dtype=cpu_dtype)
-    else:
-        torch_output_grad = torch.rand(output_grad_shape, dtype=cpu_dtype)
-
-    return torch_output_grad
-
-
-def get_tt_backward_tensors(torch_output_grad, input_grad_shape, device):
-    cpu_dtype = torch.bfloat16
-
-    torch_input_grad = torch.empty(input_grad_shape, dtype=cpu_dtype)
-    tt_input_grad = to_ttnn(torch_input_grad, device=device)
-    tt_output_grad = to_ttnn(torch_output_grad, device=device)
+def run_moreh_mean(
+    input_shape_dim,
+    device,
+    *,
+    keepdim=False,
+    compute_kernel_options=None,
+    optional_output=False,
+    torch_dtype=torch.float32,
+    ttnn_dtype=ttnn.bfloat16,
+):
+    # TODO @mrshaw01: Support bfloat8_b in kernel
+    if ttnn_dtype == ttnn.bfloat8_b:
+        pytest.skip(f"bfloat8_b is not supported in the kernel")

-    return tt_output_grad, tt_input_grad
-
-
-def run_moreh_mean(input_shape_dim, device, keepdim=False, compute_kernel_options=None):
     input_shape, dim = input_shape_dim
-
     check_dim(input_shape, dim, keepdim)

-    compute_kernel_config = get_compute_kernel_options(compute_kernel_options)
-
     # run torch
-    torch_input = get_torch_tensors(input_shape)
+    torch_input = torch.rand(input_shape, dtype=torch_dtype)
     torch_output = torch.mean(torch_input, dim=dim, keepdim=keepdim)

-    # run tt
-    (tt_input, tt_output) = get_tt_tensors(torch_input, torch_output.shape, device)
-
-    ttnn.operations.moreh.mean(
-        tt_input, dim=dim, keepdim=keepdim, output=tt_output, compute_kernel_config=compute_kernel_config
+    # run ttnn
+    ttnn_input = create_ttnn_tilized_tensor(torch_input, device, ttnn_dtype)
+    ttnn_output = (
+        create_ttnn_tilized_tensor(torch.empty_like(torch_output), device, ttnn_dtype) if optional_output else None
+    )
+    ttnn_output = ttnn.operations.moreh.mean(
+        ttnn_input,
+        dim=dim,
+        keepdim=keepdim,
+        output=ttnn_output,
+        compute_kernel_config=get_compute_kernel_options(compute_kernel_options),
     )
-    tt_output_cpu = to_torch(tt_output, shape=torch_output.shape)
+    output = ttnn.to_torch(ttnn_output)

-    # test for equivalance
     rtol = atol = 0.1
-    passing, output_pcc = comp_allclose(torch_output, tt_output_cpu, rtol=rtol, atol=atol)
-
-    logger.debug(f"Out passing={passing}")
-    logger.debug(f"Output pcc={output_pcc}")
-
+    passing, out = comp_allclose(torch_output, output, rtol=rtol, atol=atol)
+    logger.info(f"passing={passing}")
+    logger.info(f"out={out}")
     assert passing


-def run_moreh_mean_backward(input_shape_dim, device, keepdim=False, compute_kernel_options=None, create_output=False):
-    input_shape, dim = input_shape_dim
+def run_moreh_mean_backward(
+    input_shape_dim,
+    device,
+    *,
+    keepdim=False,
+    compute_kernel_options=None,
+    create_input_grad=False,
+    torch_dtype=torch.float32,
+    ttnn_dtype=ttnn.bfloat16,
+):
+    # TODO @mrshaw01: Support bfloat8_b in kernel
+    if ttnn_dtype == ttnn.bfloat8_b:
+        pytest.skip(f"bfloat8_b is not supported in the kernel")

+    input_shape, dim = input_shape_dim
     check_dim(input_shape, dim, keepdim)

-    compute_kernel_config = get_compute_kernel_options(compute_kernel_options)
-
     # run torch
-    torch_input = get_torch_tensors(input_shape)
+    torch_input = torch.rand(input_shape, dtype=torch_dtype)
+    torch_input.requires_grad_()
     torch_output = torch.mean(torch_input, dim=dim, keepdim=keepdim)
-
-    torch_output_grad = get_torch_backward_tensors(torch_output.shape)
-
+    torch_output_grad = torch.rand(torch_output.shape, dtype=torch_dtype)
     torch_output.backward(torch_output_grad)

-    # run_tt
-    tt_output_grad, tt_input_grad = get_tt_backward_tensors(torch_output_grad, torch_input.shape, device)
-
-    if create_output:
+    # run ttnn
+    ttnn_output_grad = create_ttnn_tilized_tensor(torch_output_grad, device, ttnn_dtype)
+    if create_input_grad:
         input_grad_shape = ttnn._ttnn.types.Shape(torch_input.shape)
-        tt_input_grad = ttnn.operations.moreh.mean_backward(
-            tt_output_grad,
+        ttnn_input_grad = ttnn.operations.moreh.mean_backward(
+            ttnn_output_grad,
             dim=dim,
             keepdim=keepdim,
             input_grad_shape=input_grad_shape,
-            compute_kernel_config=compute_kernel_config,
+            compute_kernel_config=get_compute_kernel_options(compute_kernel_options),
         )
     else:
+        ttnn_input_grad = create_ttnn_tilized_tensor(torch.empty_like(torch_input), device, ttnn_dtype)
         ttnn.operations.moreh.mean_backward(
-            tt_output_grad,
+            ttnn_output_grad,
             dim=dim,
             keepdim=keepdim,
-            input_grad=tt_input_grad,
-            compute_kernel_config=compute_kernel_config,
+            input_grad=ttnn_input_grad,
+            compute_kernel_config=get_compute_kernel_options(compute_kernel_options),
         )
+    input_grad = ttnn.to_torch(ttnn_input_grad)

-    tt_input_grad_cpu = to_torch(tt_input_grad, shape=torch_input.grad.shape)
-
-    # test for equivalance
     rtol = atol = 0.1
-    passing, output_pcc = comp_allclose(torch_input.grad, tt_input_grad_cpu, rtol=rtol, atol=atol)
-
-    logger.debug(f"Out passing={passing}")
-    logger.debug(f"Output pcc={output_pcc}")
-
+    passing, output_pcc = comp_allclose(torch_input.grad, input_grad, rtol=rtol, atol=atol)
+    logger.info(f"Out passing={passing}")
+    logger.info(f"Output pcc={output_pcc}")
     assert passing
@@ -168,10 +144,10 @@ def run_moreh_mean_backward(input_shape_dim, device, keepdim=False, compute_kern
     ],
 )
 @pytest.mark.parametrize("keepdim", [True, False])
-def test_moreh_mean(input_shape_dim, keepdim, device):
-    torch.manual_seed(2023)
-
-    run_moreh_mean(input_shape_dim, device, keepdim)
+@pytest.mark.parametrize("ttnn_dtype", [ttnn.bfloat8_b, ttnn.bfloat16])
+def test_moreh_mean_ttnn_dtype(input_shape_dim, keepdim, ttnn_dtype, device):
+    torch.manual_seed(2024)
+    run_moreh_mean(input_shape_dim, device, keepdim=keepdim, ttnn_dtype=ttnn_dtype)
@@ -186,11 +162,26 @@ def test_moreh_mean(input_shape_dim, keepdim, device):
 )
 @pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids)
 def test_moreh_mean_compute_kernel_options(input_shape_dim, compute_kernel_options, device):
-    torch.manual_seed(2023)
-
+    torch.manual_seed(2024)
     run_moreh_mean(input_shape_dim, device, keepdim=True, compute_kernel_options=compute_kernel_options)


+@pytest.mark.parametrize(
+    "input_shape_dim",
+    [
+        # hw multiple tiles
+        [[TILE_HEIGHT * 4, TILE_WIDTH * 5], [0]],  # h
+        [[TILE_HEIGHT * 4, TILE_WIDTH * 5], [1]],  # w
+        # ncd multiple tile
+        [[3, 4, 5, TILE_HEIGHT * 3 - 15, TILE_WIDTH * 4 - 10], [1]],  # c
+    ],
+)
+@pytest.mark.parametrize("optional_output", [True, False])
+def test_moreh_mean_optional_output(input_shape_dim, optional_output, device):
+    torch.manual_seed(2024)
+    run_moreh_mean(input_shape_dim, device, keepdim=True, optional_output=optional_output)
+
+
 @pytest.mark.parametrize(
     "input_shape_dim",
     [
@@ -207,7 +198,7 @@ def test_moreh_mean_callback(input_shape_dim, device, use_program_cache):
     for i in range(2):
         run_moreh_mean(input_shape_dim, device, keepdim=True)
         torch_dummy = torch.randn([32, 32])
-        tt_dummy = to_ttnn(torch_dummy, device=device)
+        ttnn_dummy = ttnn.from_torch(torch_dummy, device=device)
         num_program_cache_entries_list.append(device.num_program_cache_entries())
     logger.info(f"num_program_cache_entries_list={num_program_cache_entries_list}")
     assert num_program_cache_entries_list[0] > 0
@@ -240,10 +231,10 @@ def test_moreh_mean_callback(input_shape_dim, device, use_program_cache):
     ],
 )
 @pytest.mark.parametrize("keepdim", [True, False])
-def test_moreh_mean_backward(input_shape_dim, keepdim, device):
-    torch.manual_seed(2023)
-
-    run_moreh_mean_backward(input_shape_dim, device, keepdim=keepdim)
+@pytest.mark.parametrize("ttnn_dtype", [ttnn.bfloat8_b, ttnn.bfloat16])
+def test_moreh_mean_backward_ttnn_dtype(ttnn_dtype, input_shape_dim, keepdim, device):
+    torch.manual_seed(2024)
+    run_moreh_mean_backward(input_shape_dim, device, keepdim=keepdim, ttnn_dtype=ttnn_dtype)
@@ -258,8 +249,7 @@ def test_moreh_mean_backward(input_shape_dim, keepdim, device):
 )
 @pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids)
 def test_moreh_mean_backward_compute_kernel_options(input_shape_dim, compute_kernel_options, device):
-    torch.manual_seed(2023)
-
+    torch.manual_seed(2024)
     run_moreh_mean_backward(input_shape_dim, device, keepdim=True, compute_kernel_options=compute_kernel_options)
@@ -279,7 +269,7 @@ def test_moreh_mean_backward_callback(input_shape_dim, device, use_program_cache
     for i in range(2):
         run_moreh_mean_backward(input_shape_dim, device, keepdim=True)
         torch_dummy = torch.randn([32, 32])
-        tt_dummy = to_ttnn(torch_dummy, device=device)
+        ttnn_dummy = ttnn.from_torch(torch_dummy, device=device)
         num_program_cache_entries_list.append(device.num_program_cache_entries())
     logger.info(f"num_program_cache_entries_list={num_program_cache_entries_list}")
     assert num_program_cache_entries_list[0] > 0
@@ -297,7 +287,7 @@ def test_moreh_mean_backward_callback(input_shape_dim, device, use_program_cache
     ],
 )
 @pytest.mark.parametrize("keepdim", [True, False])
-def test_moreh_mean_backward_create_output(input_shape_dim, keepdim, device):
-    torch.manual_seed(2023)
-
-    run_moreh_mean_backward(input_shape_dim, device, keepdim=keepdim, create_output=True)
+@pytest.mark.parametrize("create_input_grad", [True, False])
+def test_moreh_mean_backward_create_input_grad(input_shape_dim, keepdim, create_input_grad, device):
+    torch.manual_seed(2024)
+    run_moreh_mean_backward(input_shape_dim, device, keepdim=keepdim, create_input_grad=create_input_grad)
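
For reviewers who want to exercise the refactored flow outside pytest, here is a minimal sketch of the torch -> ttnn -> torch round trip these tests rely on. It assumes a Tenstorrent device is available as device 0 and uses only public ttnn APIs (open_device, from_torch, to_torch); the shape, dim, and tolerances below are illustrative and not taken from the patch.

import torch
import ttnn


def create_ttnn_tilized_tensor(torch_tensor, device, dtype):
    # Same helper as in the patch: host tensor -> device tensor in TILE_LAYOUT.
    return ttnn.from_torch(torch_tensor, device=device, dtype=dtype, layout=ttnn.TILE_LAYOUT)


device = ttnn.open_device(device_id=0)
try:
    torch.manual_seed(2024)
    # Tile-aligned shape (tiles are 32x32), reduced over dim 0 as in the "h" test cases.
    torch_input = torch.rand([64, 96], dtype=torch.float32)
    torch_expected = torch.mean(torch_input, dim=[0], keepdim=True)

    ttnn_input = create_ttnn_tilized_tensor(torch_input, device, ttnn.bfloat16)
    ttnn_output = ttnn.operations.moreh.mean(ttnn_input, dim=[0], keepdim=True)
    torch_actual = ttnn.to_torch(ttnn_output).to(torch.float32)

    # Loose tolerances, matching the bfloat16 round trip used by the tests.
    assert torch.allclose(torch_expected, torch_actual, rtol=0.1, atol=0.1)
finally:
    ttnn.close_device(device)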