From 4f68b2013a6c142f41db3710261779d2fc17bfc3 Mon Sep 17 00:00:00 2001
From: VirdhatchaniKN
Date: Fri, 7 Jun 2024 12:55:28 +0000
Subject: [PATCH] #9150: Add optional output tensor support for prod along N,C
 dimension

---
 .../unit_testing/test_prod_nc.py              | 42 ++++++++++-
 .../op_library/composite/composite_ops.cpp    | 70 ++++++++++---------
 .../op_library/composite/composite_ops.hpp    |  3 +-
 .../tt_lib_bindings_tensor_composite_ops.cpp  |  4 +-
 4 files changed, 84 insertions(+), 35 deletions(-)

diff --git a/tests/tt_eager/python_api_testing/unit_testing/test_prod_nc.py b/tests/tt_eager/python_api_testing/unit_testing/test_prod_nc.py
index c761082e37a..a837893ac1f 100644
--- a/tests/tt_eager/python_api_testing/unit_testing/test_prod_nc.py
+++ b/tests/tt_eager/python_api_testing/unit_testing/test_prod_nc.py
@@ -5,9 +5,12 @@
 import pytest
 import torch
 from loguru import logger
-
+from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_pt_tt
 import tt_lib as ttl
 from models.utility_functions import comp_allclose_and_pcc
+from tests.tt_eager.python_api_testing.sweep_tests import (
+    comparison_funcs,
+)
 
 TILE_HEIGHT = 32
 TILE_WIDTH = 32
@@ -88,3 +91,40 @@ def test_prod_dims(input_shape, dims, device):
     logger.info(f"Output pcc={output_pcc}")
 
     assert passing
+
+
+mem_configs = [
+    ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM),
+    ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1),
+]
+
+
+@pytest.mark.parametrize(
+    "dst_mem_config",
+    mem_configs,
+)
+@pytest.mark.parametrize(
+    "input_shapes",
+    (
+        ([1, 1, 32, 32]),
+        ([2, 2, 32, 32]),
+        ([4, 3, 32, 32]),
+    ),
+)
+@pytest.mark.parametrize(
+    "dim",
+    [0, 1],
+)
+@pytest.mark.parametrize("all_dimensions", [False])
+def test_prod_with_output_nc(input_shapes, all_dimensions, dim, dst_mem_config, device):
+    in_data, input_tensor = data_gen_pt_tt(input_shapes, device)
+    golden_tensor = torch.prod(in_data, dim, keepdim=True)
+
+    output_shape = input_shapes
+    output_shape[dim] = 1
+    out_data, output_tensor = data_gen_pt_tt(output_shape, device)
+
+    tt_output_tensor_on_device = ttl.tensor.prod(input_tensor, all_dimensions, dim, dst_mem_config, output_tensor)
+    tt_out_tensor = tt_output_tensor_on_device.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch()
+    comp_pass, comp_out = comparison_funcs.comp_pcc(golden_tensor, tt_out_tensor)
+    assert comp_pass
diff --git a/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp b/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp
index f781f1a6683..f93c094000e 100644
--- a/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp
+++ b/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp
@@ -1072,7 +1072,7 @@ Tensor prod_all(const Tensor& input_a, const MemoryConfig& output_mem_config) {
     return tt::operations::primary::prod_all(formatted_input_tensor, output_mem_config);
 }
 
-Tensor prod_nc(const Tensor& temp, int64_t dim, const MemoryConfig& output_mem_config) {
+Tensor prod_nc(const Tensor& temp, int64_t dim, const MemoryConfig& output_mem_config, std::optional<Tensor> output_tensor = std::nullopt) {
     // layout conversion
     auto formatted_input_tensor = temp;
     if (formatted_input_tensor.get_layout() == Layout::ROW_MAJOR) {
@@ -1086,43 +1086,44 @@ Tensor prod_nc(const Tensor& temp, int64_t dim, const MemoryConfig& output_mem_c
     }
     // Apply prod
     std::vector<int64_t> dimension = {(dim == 1 || dim == -3) ? 1 : 0};
-    Shape input_shape = formatted_input_tensor.get_legacy_shape();
-    Shape required = {
-        ((dim == 1 || dim == -3) ? input_shape[0] : 1),
-        ((dim == 1 || dim == -3) ? 1 : input_shape[1]),
-        input_shape[2],
-        input_shape[3]};
-    return tt::operations::primary::prod_nc(
-        formatted_input_tensor,
-        zeros(
-            required,
-            formatted_input_tensor.get_dtype(),
-            formatted_input_tensor.get_layout(),
-            formatted_input_tensor.device(),
-            output_mem_config),
-        dimension,
-        output_mem_config);
+    if(output_tensor.has_value()){
+        tt::operations::primary::prod_nc(
+            formatted_input_tensor,
+            output_tensor.value(), //optional output tensor
+            dimension,
+            output_mem_config);
+    } else {
+        Shape input_shape = formatted_input_tensor.get_legacy_shape();
+        Shape required = { ((dim == 1 || dim == -3) ? input_shape[0] : 1), ((dim == 1 || dim == -3) ? 1 : input_shape[1]), input_shape[2], input_shape[3]};
+        output_tensor = zeros( required, formatted_input_tensor.get_dtype(), formatted_input_tensor.get_layout(), formatted_input_tensor.device(), output_mem_config);
+        output_tensor = tt::operations::primary::prod_nc(
+            formatted_input_tensor,
+            output_tensor.value(),
+            dimension,
+            output_mem_config);
+    }
+    return output_tensor.value();
 }
 
-Tensor _prod(const Tensor& input_a, bool all_dimensions, int64_t dim, const MemoryConfig& output_mem_config) {
+Tensor _prod(const Tensor& input_a, bool all_dimensions, int64_t dim, const MemoryConfig& output_mem_config, std::optional<Tensor> output_tensor) {
     if (all_dimensions) {
         return tt::tt_metal::prod_all(input_a, output_mem_config);
     }
     TT_FATAL(dim >= -4 && dim <= 3 && "Dimension out of range (expected to be in range of [-4, 3]");
-    Tensor temp = input_a;
-    // Permute for dim 2,3
-    if (dim == 2 || dim == -2) {
-        std::vector<int64_t> permute_dims = {2, 0, 1, 3};
-        temp = permute(input_a, permute_dims, output_mem_config);
-    } else if (dim == 3 || dim == -1) {
-        std::vector<int64_t> permute_dims = {3, 0, 1, 2};
-        temp = permute(input_a, permute_dims, output_mem_config);
-    }
-    Tensor result = tt::tt_metal::prod_nc(temp, dim, output_mem_config);
-    // Permute and unpad result for dim 2,3
     if (dim == 0 || dim == 1 || dim == -4 || dim == -3) {
-        return result;
+        if(output_tensor.has_value()){
+            tt::tt_metal::prod_nc(input_a, dim, output_mem_config, output_tensor);
+        } else {
+            output_tensor = tt::tt_metal::prod_nc(input_a, dim, output_mem_config);
+        }
+        return output_tensor.value();
     } else if (dim == 2 || dim == -2) {
+        //Permute
+        std::vector<int64_t> permute_dims = {2, 0, 1, 3};
+        Tensor result = permute(input_a, permute_dims, output_mem_config);
+        //Prod along N,C
+        result = tt::tt_metal::prod_nc(result, dim, output_mem_config);
+        // Permute and unpad
         std::vector<int64_t> after_permute_dims = {1, 2, 0, 3};
         Tensor required = permute(result, after_permute_dims, output_mem_config);
         Shape input_shape = input_a.get_legacy_shape();
@@ -1130,6 +1131,11 @@ Tensor _prod(const Tensor& input_a, bool all_dimensions, int64_t dim, const Memo
         const Shape end_index = {input_shape[0] - 1, input_shape[1] - 1, 0, input_shape[3] - 1};
         return unpad(required, start_index, end_index);
     } else {  // dim 3
+        //Permute
+        std::vector<int64_t> permute_dims = {3, 0, 1, 2};
+        Tensor result = permute(input_a, permute_dims, output_mem_config);
+        //Prod along N,C
+        result = tt::tt_metal::prod_nc(result, dim, output_mem_config);
         // permute
         std::vector<int64_t> after_permute_dims = {1, 2, 0, 3};
         Tensor required = permute(result, after_permute_dims, output_mem_config);
@@ -1144,8 +1150,8 @@ Tensor _prod(const Tensor& input_a, bool all_dimensions, int64_t dim, const Memo
     }
 }
 
-Tensor prod(const Tensor& input_a, bool all_dimensions, int64_t dim, const MemoryConfig& output_mem_config) {
-    return operation::decorate_as_composite(__func__, _prod)(input_a, all_dimensions, dim, output_mem_config);
+Tensor prod(const Tensor& input_a, bool all_dimensions, int64_t dim, const MemoryConfig& output_mem_config, std::optional<Tensor> output_tensor) {
+    return operation::decorate_as_composite(__func__, _prod)(input_a, all_dimensions, dim, output_mem_config, output_tensor);
 }
 
 Tensor _variance_impl(
diff --git a/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp b/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp
index 2b5bd0514ec..5b41f6828e3 100644
--- a/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp
+++ b/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp
@@ -235,7 +235,8 @@ Tensor prod(
     const Tensor& input_a,
     bool all_dimensions = false,
     int64_t dim = 0,
-    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
+    std::optional<Tensor> output_tensor = std::nullopt);
 
 /*
diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp
index 274710fe439..2b0406f8558 100644
--- a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp
+++ b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp
@@ -212,9 +212,10 @@ namespace tt::tt_metal::detail{
                 R"doc(dimension to split)doc"
         );
         m_tensor.def("prod", &prod,
-            py::arg("input").noconvert(), py::arg("all_dimensions") = false, py::arg("dim") = 0, py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
+            py::arg("input").noconvert(), py::arg("all_dimensions") = false, py::arg("dim") = 0, py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, py::arg("output_tensor").noconvert() = std::nullopt, R"doc(
             Computes the prod function along specified ``dim`` or all dimensions on the ``input`` tensor.
             If ``all_dimensions`` is set to ``true`` irrespective of given dimension it will prod along all dimensions.
+            Optional ``output_tensor`` support is provided only for ``dim = 0, 1``.
 
             Input tensor must have BFLOAT16 data type.
 
@@ -227,6 +228,7 @@ namespace tt::tt_metal::detail{
                 "all_dimensions", "Consider all dimension (ignores ``dim`` param)", "bool", "default to false", "No"
                 "dim", "Dimension to perform prod", "int", "default to 0", "Yes"
                 "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
+                "output_tensor", "Optional output tensor", "Tensor", "default is None", "No"
        )doc");
        detail::bind_unary_op_with_param(
            m_tensor, "geglu", &geglu,
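
Usage sketch (illustrative, not part of the applied diff): with this change a preallocated device
tensor can be passed as the trailing output_tensor argument of ttl.tensor.prod when reducing along
dim 0 or 1; for the other dims, or when the argument is omitted, prod behaves as before and allocates
its own output. The sketch below mirrors test_prod_with_output_nc; the wrapper function name is
hypothetical, and the shape, memory config, and data_gen_pt_tt helper are taken from the test above.

    # Illustrative sketch only; mirrors test_prod_with_output_nc.
    # Assumes a device handle `device` such as the pytest fixture used by the test.
    import torch
    import tt_lib as ttl
    from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_pt_tt


    def prod_into_preallocated_output(device, dim=0):
        # Random host input and its device-resident copy (helper from the test file).
        in_data, input_tensor = data_gen_pt_tt([2, 2, 32, 32], device)

        # Preallocate the output with the reduced dimension collapsed to 1.
        out_shape = [2, 2, 32, 32]
        out_shape[dim] = 1
        _, output_tensor = data_gen_pt_tt(out_shape, device)

        mem_config = ttl.tensor.MemoryConfig(
            ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM
        )

        # New trailing argument: prod reuses output_tensor instead of allocating a zeros() buffer.
        tt_result = ttl.tensor.prod(input_tensor, False, dim, mem_config, output_tensor)

        # Compare against the PyTorch reference, as the test does.
        golden = torch.prod(in_data, dim, keepdim=True)
        return golden, tt_result.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch()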