#9150: Add optional output tensor support for prod along N,C dimension
VirdhatchaniKN committed Jun 8, 2024
1 parent b081f5e commit 4f68b20
Showing 4 changed files with 84 additions and 35 deletions.
42 changes: 41 additions & 1 deletion tests/tt_eager/python_api_testing/unit_testing/test_prod_nc.py
@@ -5,9 +5,12 @@
import pytest
import torch
from loguru import logger

from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_pt_tt
import tt_lib as ttl
from models.utility_functions import comp_allclose_and_pcc
from tests.tt_eager.python_api_testing.sweep_tests import (
    comparison_funcs,
)

TILE_HEIGHT = 32
TILE_WIDTH = 32
@@ -88,3 +91,40 @@ def test_prod_dims(input_shape, dims, device):
    logger.info(f"Output pcc={output_pcc}")

    assert passing


mem_configs = [
    ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM),
    ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1),
]


@pytest.mark.parametrize(
    "dst_mem_config",
    mem_configs,
)
@pytest.mark.parametrize(
    "input_shapes",
    (
        ([1, 1, 32, 32]),
        ([2, 2, 32, 32]),
        ([4, 3, 32, 32]),
    ),
)
@pytest.mark.parametrize(
    "dim",
    [0, 1],
)
@pytest.mark.parametrize("all_dimensions", [False])
def test_prod_with_output_nc(input_shapes, all_dimensions, dim, dst_mem_config, device):
    in_data, input_tensor = data_gen_pt_tt(input_shapes, device)
    golden_tensor = torch.prod(in_data, dim, keepdim=True)

    output_shape = list(input_shapes)  # copy so the parametrized shape list is not mutated in place
    output_shape[dim] = 1
    out_data, output_tensor = data_gen_pt_tt(output_shape, device)

    tt_output_tensor_on_device = ttl.tensor.prod(input_tensor, all_dimensions, dim, dst_mem_config, output_tensor)
    tt_out_tensor = tt_output_tensor_on_device.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch()
    comp_pass, comp_out = comparison_funcs.comp_pcc(golden_tensor, tt_out_tensor)
    assert comp_pass
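
A side note on the shapes in the test above: the golden path uses ``keepdim=True``, so the preallocated output sets the reduced dimension to 1. A quick self-contained check of that semantics (illustrative, torch-only):

import torch

# Illustrative check: prod over dim 0 with keepdim=True reduces
# shape [4, 3, 32, 32] to [1, 3, 32, 32], matching output_shape[dim] = 1 above.
x = torch.randn(4, 3, 32, 32)
assert torch.prod(x, 0, keepdim=True).shape == torch.Size([1, 3, 32, 32])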
70 changes: 38 additions & 32 deletions tt_eager/tt_dnn/op_library/composite/composite_ops.cpp
@@ -1072,7 +1072,7 @@ Tensor prod_all(const Tensor& input_a, const MemoryConfig& output_mem_config) {
    return tt::operations::primary::prod_all(formatted_input_tensor, output_mem_config);
}

Tensor prod_nc(const Tensor& temp, int64_t dim, const MemoryConfig& output_mem_config) {
Tensor prod_nc(const Tensor& temp, int64_t dim, const MemoryConfig& output_mem_config, std::optional<Tensor> output_tensor = std::nullopt) {
    // layout conversion
    auto formatted_input_tensor = temp;
    if (formatted_input_tensor.get_layout() == Layout::ROW_MAJOR) {
@@ -1086,50 +1086,56 @@ Tensor prod_nc(const Tensor& temp, int64_t dim, const MemoryConfig& output_mem_c
    }
    // Apply prod
    std::vector<int64_t> dimension = {(dim == 1 || dim == -3) ? 1 : 0};
    Shape input_shape = formatted_input_tensor.get_legacy_shape();
    Shape required = {
        ((dim == 1 || dim == -3) ? input_shape[0] : 1),
        ((dim == 1 || dim == -3) ? 1 : input_shape[1]),
        input_shape[2],
        input_shape[3]};
    return tt::operations::primary::prod_nc(
        formatted_input_tensor,
        zeros(
            required,
            formatted_input_tensor.get_dtype(),
            formatted_input_tensor.get_layout(),
            formatted_input_tensor.device(),
            output_mem_config),
        dimension,
        output_mem_config);
    if (output_tensor.has_value()) {
        tt::operations::primary::prod_nc(
            formatted_input_tensor,
            output_tensor.value(),  // optional output tensor
            dimension,
            output_mem_config);
    } else {
        Shape input_shape = formatted_input_tensor.get_legacy_shape();
        Shape required = {
            ((dim == 1 || dim == -3) ? input_shape[0] : 1),
            ((dim == 1 || dim == -3) ? 1 : input_shape[1]),
            input_shape[2],
            input_shape[3]};
        output_tensor = zeros(
            required,
            formatted_input_tensor.get_dtype(),
            formatted_input_tensor.get_layout(),
            formatted_input_tensor.device(),
            output_mem_config);
        output_tensor = tt::operations::primary::prod_nc(
            formatted_input_tensor,
            output_tensor.value(),
            dimension,
            output_mem_config);
    }
    return output_tensor.value();
}
}

Tensor _prod(const Tensor& input_a, bool all_dimensions, int64_t dim, const MemoryConfig& output_mem_config) {
Tensor _prod(const Tensor& input_a, bool all_dimensions, int64_t dim, const MemoryConfig& output_mem_config, std::optional<Tensor> output_tensor) {
    if (all_dimensions) {
        return tt::tt_metal::prod_all(input_a, output_mem_config);
    }
    TT_FATAL(dim >= -4 && dim <= 3 && "Dimension out of range (expected to be in range of [-4, 3])");
    Tensor temp = input_a;
    // Permute for dim 2,3
    if (dim == 2 || dim == -2) {
        std::vector<int64_t> permute_dims = {2, 0, 1, 3};
        temp = permute(input_a, permute_dims, output_mem_config);
    } else if (dim == 3 || dim == -1) {
        std::vector<int64_t> permute_dims = {3, 0, 1, 2};
        temp = permute(input_a, permute_dims, output_mem_config);
    }
    Tensor result = tt::tt_metal::prod_nc(temp, dim, output_mem_config);
    // Permute and unpad result for dim 2,3
    if (dim == 0 || dim == 1 || dim == -4 || dim == -3) {
        return result;
        if (output_tensor.has_value()) {
            tt::tt_metal::prod_nc(input_a, dim, output_mem_config, output_tensor);
        } else {
            output_tensor = tt::tt_metal::prod_nc(input_a, dim, output_mem_config);
        }
        return output_tensor.value();
    } else if (dim == 2 || dim == -2) {
        // Permute
        std::vector<int64_t> permute_dims = {2, 0, 1, 3};
        Tensor result = permute(input_a, permute_dims, output_mem_config);
        // Prod along N,C
        result = tt::tt_metal::prod_nc(result, dim, output_mem_config);
        // Permute and unpad
        std::vector<int64_t> after_permute_dims = {1, 2, 0, 3};
        Tensor required = permute(result, after_permute_dims, output_mem_config);
        Shape input_shape = input_a.get_legacy_shape();
        const Shape start_index = {0, 0, 0, 0};
        const Shape end_index = {input_shape[0] - 1, input_shape[1] - 1, 0, input_shape[3] - 1};
        return unpad(required, start_index, end_index);
    } else {  // dim 3
        // Permute
        std::vector<int64_t> permute_dims = {3, 0, 1, 2};
        Tensor result = permute(input_a, permute_dims, output_mem_config);
        // Prod along N,C
        result = tt::tt_metal::prod_nc(result, dim, output_mem_config);
        // permute
        std::vector<int64_t> after_permute_dims = {1, 2, 0, 3};
        Tensor required = permute(result, after_permute_dims, output_mem_config);
@@ -1144,8 +1150,8 @@ Tensor _prod(const Tensor& input_a, bool all_dimensions, int64_t dim, const Memo
    }
}

Tensor prod(const Tensor& input_a, bool all_dimensions, int64_t dim, const MemoryConfig& output_mem_config) {
    return operation::decorate_as_composite(__func__, _prod)(input_a, all_dimensions, dim, output_mem_config);
Tensor prod(const Tensor& input_a, bool all_dimensions, int64_t dim, const MemoryConfig& output_mem_config, std::optional<Tensor> output_tensor) {
    return operation::decorate_as_composite(__func__, _prod)(input_a, all_dimensions, dim, output_mem_config, output_tensor);
}

Tensor _variance_impl(
3 changes: 2 additions & 1 deletion tt_eager/tt_dnn/op_library/composite/composite_ops.hpp
@@ -235,7 +235,8 @@ Tensor prod(
    const Tensor& input_a,
    bool all_dimensions = false,
    int64_t dim = 0,
    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
    std::optional<Tensor> output_tensor = std::nullopt);


/*
@@ -212,9 +212,10 @@ namespace tt::tt_metal::detail{
R"doc(dimension to split)doc"
);
m_tensor.def("prod", &prod,
py::arg("input").noconvert(), py::arg("all_dimensions") = false, py::arg("dim") = 0, py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
py::arg("input").noconvert(), py::arg("all_dimensions") = false, py::arg("dim") = 0, py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, py::arg("output_tensor").noconvert() = std::nullopt, R"doc(
        Computes the prod function along the specified ``dim`` or over all dimensions of the ``input`` tensor.
        If ``all_dimensions`` is set to ``true``, the product is computed over all dimensions, irrespective of the given ``dim``.
        Optional ``output_tensor`` support is provided only for ``dim = 0, 1``.
        Input tensor must have BFLOAT16 data type.
@@ -227,6 +228,7 @@ namespace tt::tt_metal::detail{
"all_dimensions", "Consider all dimension (ignores ``dim`` param)", "bool", "default to false", "No"
"dim", "Dimension to perform prod", "int", "default to 0", "Yes"
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
"output_tensor", "optional output tensor", "Tensor", "default is None", "No"
)doc");
    detail::bind_unary_op_with_param(
        m_tensor, "geglu", &geglu,
