diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py
index e0566438b7..c4e806a842 100644
--- a/backends/arm/_passes/arm_pass_manager.py
+++ b/backends/arm/_passes/arm_pass_manager.py
@@ -19,6 +19,9 @@
     ConvertSplitToSlicePass,
 )
 from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass
+from executorch.backends.arm._passes.insert_squeeze_after_sum_pass import (
+    InsertSqueezeAfterSumPass,
+)
 from executorch.backends.arm._passes.meandim_to_averagepool_pass import (
     ConvertMeanDimToAveragePool,
 )
@@ -47,6 +50,7 @@ def transform_to_backend_pipeline(
         self.add_pass(ConvertExpandCopyToRepeatPass())
         self.add_pass(ConvertMeanDimToAveragePool())
         self.add_pass(DecomposeDivPass())
+        self.add_pass(InsertSqueezeAfterSumPass())
         self.add_pass(ConvertSplitToSlicePass())
         for spec in compile_spec:
             if spec.key == "permute_memory_format":
diff --git a/backends/arm/_passes/insert_squeeze_after_sum_pass.py b/backends/arm/_passes/insert_squeeze_after_sum_pass.py
new file mode 100644
index 0000000000..152d5c95f6
--- /dev/null
+++ b/backends/arm/_passes/insert_squeeze_after_sum_pass.py
@@ -0,0 +1,69 @@
+# Copyright 2024 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import cast
+
+import torch
+import torch.fx
+from executorch.backends.arm._passes.arm_pass_utils import create_node, insert_q_dq_pair
+
+from executorch.backends.arm.tosa_quant_utils import get_quant_node_args, is_quant_node
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+class InsertSqueezeAfterSumPass(ExportPass):
+    """
+    In PyTorch, the default behaviour of Tensor.sum is to squeeze
+    the dimension that is summed (keep_dim = False).
+    However, in TOSA, REDUCE_SUM always preserves the
+    rank of the input (keep_dim = True).
+    To get a 1-1 mapping in the sum lowering, normalize the
+    keep_dim = False case to keep_dim = True and add squeeze ops.
+
+    Original:
+        sum(dims, keep_dim = False)
+    After pass:
+        sum(dims, keep_dim = True)
+        (q)
+        (dq)
+        squeeze(dim = dims)
+    """
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        for node in graph_module.graph.nodes:
+            if node.op != "call_function":
+                continue
+            if node.target != exir_ops.edge.aten.sum.dim_IntList:
+                continue
+            sum_node = cast(torch.fx.Node, node)
+            keep_dim = cast(bool, sum_node.args[2] if len(sum_node.args) > 2 else False)
+            if keep_dim:
+                continue
+
+            dim_list = cast(list[int], sum_node.args[1])
+            quantized = is_quant_node(sum_node)
+            if quantized:
+                qparams = get_quant_node_args(sum_node.all_input_nodes[0])
+                qparams = qparams + (torch.int8,)
+            else:
+                qparams = None
+
+            # Add keep_dim = True arg to sum node.
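+            # The TOSA lowering in op_sum.py asserts keep_dim == True; the
+            # squeeze_copy inserted below restores the rank-reduced shape
+            # that the rest of the graph expects.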
+            sum_node.args = sum_node.args[0:2] + (True,)
+
+            with graph_module.graph.inserting_after(sum_node):
+                squeeze_node = create_node(
+                    graph_module.graph, exir_ops.edge.aten.squeeze_copy.dims, ()
+                )
+                sum_node.replace_all_uses_with(squeeze_node)
+                squeeze_node.args = (sum_node, dim_list)
+                if quantized:
+                    sum_node = insert_q_dq_pair(graph_module.graph, sum_node, qparams)
+        graph_module.graph.eliminate_dead_code()
+        graph_module.recompile()
+        graph_module = super().call(graph_module).graph_module
+        return PassResult(graph_module, True)
diff --git a/backends/arm/arm_partitioner.py b/backends/arm/arm_partitioner.py
index 22fb5ac6ac..7db893694b 100644
--- a/backends/arm/arm_partitioner.py
+++ b/backends/arm/arm_partitioner.py
@@ -63,6 +63,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             exir_ops.edge.aten._softmax.default,
             exir_ops.edge.aten.slice_copy.Tensor,
             exir_ops.edge.aten.sub.Tensor,
+            exir_ops.edge.aten.sum.dim_IntList,
             exir_ops.edge.aten.view_copy.default,
             exir_ops.edge.aten.clone.default,
             exir_ops.edge.aten.mean.dim,
diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py
index 6d08290f03..855487cf7f 100644
--- a/backends/arm/operators/__init__.py
+++ b/backends/arm/operators/__init__.py
@@ -34,6 +34,7 @@
     op_softmax,
     op_squeeze,
     op_sub,
+    op_sum,
     op_unsqueeze,
     op_view,
 )
diff --git a/backends/arm/operators/op_sum.py b/backends/arm/operators/op_sum.py
new file mode 100644
index 0000000000..b67f5f92db
--- /dev/null
+++ b/backends/arm/operators/op_sum.py
@@ -0,0 +1,96 @@
+# Copyright 2023-2024 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import cast, List
+
+import executorch.backends.arm.tosa_quant_utils as tqutils
+import executorch.backends.arm.tosa_utils as tutils
+
+import serializer.tosa_serializer as ts
+from executorch.backends.arm.operators.node_visitor import (
+    NodeVisitor,
+    register_node_visitor,
+)
+from executorch.backends.arm.tosa_mapping import TosaArg
+from serializer.tosa_serializer import TosaOp
+from torch.fx import Node
+
+
+@register_node_visitor
+class SumVisitor(NodeVisitor):
+    target = "aten.sum.dim_IntList"
+
+    def __init__(self, *args):
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: Node,
+        tosa_graph: ts.TosaSerializer,
+        inputs: List[TosaArg],
+        output: TosaArg,
+        is_quant_node: bool,
+    ) -> None:
+        input_node = inputs[0]
+        input_shape = list(input_node.shape)
+        dim_list = cast(list[int], inputs[1].special)
+        # Normalize negative dims to positive indices.
+        dim_list = [dim % len(input_node.shape) for dim in dim_list]
+        keep_dim = cast(bool, inputs[2].number if len(inputs) > 2 else False)
+        assert keep_dim, "This case should be handled by InsertSqueezeAfterSumPass"
+
+        if is_quant_node:
+            # Rescale input to 32 bit.
+            rescaled_inputs, scale = tqutils.rescale_nodes_to_int32(
+                [node.all_input_nodes[0]], tosa_graph
+            )
+
+            prev_node = rescaled_inputs[0]
+            reduced_shape = input_shape
+
+            # Reduce all dims in dim_list one-by-one.
+            for dim in dim_list:
+                # When reduced, the size of the dim becomes 1.
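+                # e.g. an input of shape (2, 3, 4) reduced over dim 1 gives
+                # an intermediate of shape (2, 1, 4).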
+                reduced_shape[dim] = 1
+
+                attr = ts.TosaSerializerAttribute()
+                attr.AxisAttribute(input_node.dim_order.index(dim))
+
+                next_node = tosa_graph.addIntermediate(
+                    tutils.tosa_shape(reduced_shape, input_node.dim_order),
+                    dtype=ts.DType.INT32,
+                )
+
+                tosa_graph.addOperator(
+                    TosaOp.Op().REDUCE_SUM, [prev_node.name], [next_node.name], attr
+                )
+
+                prev_node = next_node
+            tqutils.rescale_node_back_to_int8(node, prev_node, scale, tosa_graph)
+        else:
+            input_name = input_node.name
+            reduced_shape = input_shape
+
+            # Reduce all dims in dim_list one-by-one.
+            for dim in dim_list:
+                # When reduced, the size of the dim becomes 1.
+                reduced_shape[dim] = 1
+
+                attr = ts.TosaSerializerAttribute()
+                attr.AxisAttribute(input_node.dim_order.index(dim))
+
+                # The last reduction writes directly to the output tensor.
+                if dim == dim_list[-1]:
+                    output_name = output.name
+                else:
+                    output_name = tosa_graph.addIntermediate(
+                        tutils.tosa_shape(reduced_shape, input_node.dim_order),
+                        dtype=ts.DType.FP32,
+                    ).name
+
+                tosa_graph.addOperator(
+                    TosaOp.Op().REDUCE_SUM, [input_name], [output_name], attr
+                )
+
+                input_name = output_name
diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py
index 180ddcc2d2..6a68eb2eb9 100644
--- a/backends/arm/quantizer/arm_quantizer.py
+++ b/backends/arm/quantizer/arm_quantizer.py
@@ -271,6 +271,7 @@ class ArmQuantizer(Quantizer):
         "cat",
         "one_to_one",
         "generic",
+        "sum",
     ]
 
     def __init__(self) -> None:
diff --git a/backends/arm/quantizer/quantization_annotation/__init__.py b/backends/arm/quantizer/quantization_annotation/__init__.py
index 87d93ce73b..bc3184298f 100644
--- a/backends/arm/quantizer/quantization_annotation/__init__.py
+++ b/backends/arm/quantizer/quantization_annotation/__init__.py
@@ -60,4 +60,5 @@ def decorator(annotator: AnnotatorType):
     mul_annotator,
     one_to_one_annotator,
     sub_annotator,
+    sum_annotator,
 )
diff --git a/backends/arm/quantizer/quantization_annotation/sum_annotator.py b/backends/arm/quantizer/quantization_annotation/sum_annotator.py
new file mode 100644
index 0000000000..1c0399dc34
--- /dev/null
+++ b/backends/arm/quantizer/quantization_annotation/sum_annotator.py
@@ -0,0 +1,57 @@
+# Copyright 2024 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
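+
+# The sum output is annotated with a SharedQuantizationSpec tied to the
+# input, so input and output use the same quantization parameters and
+# op_sum.py can lower sum as an int32 REDUCE_SUM chain with a single
+# rescale back to int8.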
+
+from typing import Callable, cast, List, Optional
+
+import torch
+from executorch.backends.arm.quantizer import arm_quantizer_utils
+from executorch.backends.arm.quantizer.quantization_annotation import register_annotator
+from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
+
+from torch.ao.quantization.quantizer import (
+    QuantizationAnnotation,
+    QuantizationSpecBase,
+    SharedQuantizationSpec,
+)
+from torch.fx import Node
+
+
+@register_annotator("sum")
+def _annotate_sum(
+    gm: torch.fx.GraphModule,
+    quantization_config: QuantizationConfig,
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[List[List[Node]]]:
+    annotated_partitions = []
+    for node in gm.graph.nodes:
+        if node.target is not torch.ops.aten.sum.dim_IntList:
+            continue
+        if filter_fn and not filter_fn(node):
+            continue
+
+        sum_node = node
+        if arm_quantizer_utils.is_annotated(sum_node):
+            continue
+
+        input_act = sum_node.args[0]
+
+        if not isinstance(input_act, Node):
+            continue
+        if not arm_quantizer_utils.is_input_ok_for_quantization(input_act, gm):
+            continue
+
+        input_act_qspec = cast(
+            Optional[QuantizationSpecBase], quantization_config.get_input_act_qspec()
+        )
+        input_qspec_map = {input_act: input_act_qspec}
+        shared_with_input0_qspec = SharedQuantizationSpec((input_act, sum_node))
+
+        sum_node.meta["quantization_annotation"] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            output_qspec=shared_with_input0_qspec,
+            _annotated=True,
+        )
+        annotated_partitions.append([sum_node])
+    return annotated_partitions
diff --git a/backends/arm/runtime/ArmBackendEthosU.cpp b/backends/arm/runtime/ArmBackendEthosU.cpp
index b0452fb9e7..99ce0a9df2 100644
--- a/backends/arm/runtime/ArmBackendEthosU.cpp
+++ b/backends/arm/runtime/ArmBackendEthosU.cpp
@@ -115,7 +115,7 @@ class ArmBackend final : public ::executorch::runtime::BackendInterface {
     ArmBackendExecuteCallbacks ArmBackend_execute_callbacks;
     // Command stream - we know at this point it's aligned
     char* data = (char*)execution_handle->processed->data();
-    ET_LOG(Info, "ArmBackend::execute %p", data);
+    ET_LOG(Debug, "ArmBackend::execute %p", data);
 
     // Read key sections from the vela_bin_stream
     if (vela_bin_read(data, &handles, execution_handle->processed->size()) ==
@@ -295,7 +295,7 @@ class ArmBackend final : public ::executorch::runtime::BackendInterface {
         tensor.size(1) == io->shape[3] && tensor.size(2) == io->shape[1] &&
         tensor.size(3) == io->shape[2];
     if (permuted_shape) {
-      ET_LOG(Info, "Tensor input/output %d will be permuted", index);
+      ET_LOG(Debug, "Tensor input/output %d will be permuted", index);
     }
     if (permuted_io_flag != permuted_shape) {
       ET_LOG(
diff --git a/backends/arm/test/ops/test_sum.py b/backends/arm/test/ops/test_sum.py
new file mode 100644
index 0000000000..73860dfa4a
--- /dev/null
+++ b/backends/arm/test/ops/test_sum.py
@@ -0,0 +1,129 @@
+# Copyright 2024 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+from typing import Tuple
+
+import torch
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.arm_tester import ArmTester
+from executorch.exir import EdgeCompileConfig
+from executorch.exir.backend.compile_spec_schema import CompileSpec
+from parameterized import parameterized
+
+exampledata_t = Tuple[torch.Tensor, int | list[int], bool]
+"""(data, dim(s), keepdim)"""
+
+
+class TestSum(unittest.TestCase):
+    """Tests the sum operator, which sums all elements along the specified
+    dimension(s). keepdim controls whether each summed dimension is kept
+    (with size 1) or squeezed away.
+    """
+
+    class Sum(torch.nn.Module):
+        test_parameters: list[Tuple[exampledata_t]] = [
+            ((torch.rand(10), 0, True),),
+            ((torch.rand(10, 10), 1, False),),
+            ((torch.rand(10, 10, 10), [-3, 1], True),),
+            ((torch.rand(2, 1, 5, 8), 1, False),),
+            ((torch.rand(1, 2, 3, 4), 3, True),),
+            ((torch.rand(1, 2, 8, 8), [2, 3, 0], True),),
+        ]
+
+        def forward(self, x: torch.Tensor, dim: int, keepdim: bool):
+            return x.sum(dim=dim, keepdim=keepdim)
+
+    _edge_compile_config: EdgeCompileConfig = EdgeCompileConfig(
+        _skip_dim_order=True,  # TODO(T182928844): Delegate dim order op to backend.
+    )
+
+    def _test_sum_tosa_MI_pipeline(
+        self, module: torch.nn.Module, test_data: tuple[exampledata_t]
+    ):
+        (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=common.get_tosa_compile_spec(),
+            )
+            .export()
+            .check_count({"torch.ops.aten.sum.dim_IntList": 1})
+            .check_not(["torch.ops.quantized_decomposed"])
+            .to_edge(config=self._edge_compile_config)
+            .partition()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+            .run_method_and_compare_outputs(inputs=test_data)
+        )
+
+    def _test_sum_tosa_BI_pipeline(
+        self, module: torch.nn.Module, test_data: tuple[exampledata_t]
+    ):
+        (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=common.get_tosa_compile_spec(),
+            )
+            .quantize()
+            .export()
+            .check_count({"torch.ops.aten.sum.dim_IntList": 1})
+            .check(["torch.ops.quantized_decomposed"])
+            .to_edge(config=self._edge_compile_config)
+            .partition()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+            .run_method_and_compare_outputs(inputs=test_data, qtol=1)
+        )
+
+    def _test_sum_ethosu_BI_pipeline(
+        self,
+        module: torch.nn.Module,
+        test_data: tuple[exampledata_t],
+        compile_spec: CompileSpec,
+    ):
+        (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=compile_spec,
+            )
+            .quantize()
+            .export()
+            .check_count({"torch.ops.aten.sum.dim_IntList": 1})
+            .check(["torch.ops.quantized_decomposed"])
+            .to_edge()
+            .partition()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+            .serialize()
+        )
+
+    @parameterized.expand(Sum.test_parameters)
+    def test_sum_tosa_MI(self, test_data: tuple[exampledata_t]):
+        self._test_sum_tosa_MI_pipeline(self.Sum(), test_data)
+
+    @parameterized.expand(Sum.test_parameters)
+    def test_sum_tosa_BI(self, test_data: tuple[exampledata_t]):
+        self._test_sum_tosa_BI_pipeline(self.Sum(), test_data)
+
+    @parameterized.expand(Sum.test_parameters)
+    def test_sum_u55_BI(self, test_data: tuple[exampledata_t]):
+        self._test_sum_ethosu_BI_pipeline(
+            self.Sum(),
+            test_data,
+            common.get_u55_compile_spec(permute_memory_to_nhwc=False),
+        )
+
+    @parameterized.expand(Sum.test_parameters)
+    def test_sum_u85_BI(self, test_data: tuple[exampledata_t]):
+        self._test_sum_ethosu_BI_pipeline(
+            self.Sum(),
+            test_data,
+            common.get_u85_compile_spec(permute_memory_to_nhwc=True),
+        )
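
Illustration (not part of the patch): a minimal sketch of what
InsertSqueezeAfterSumPass does to a keepdim=False sum; the module M below is
hypothetical and only restates the transform described in the pass docstring.

    import torch

    class M(torch.nn.Module):
        def forward(self, x):
            # (2, 3, 4) -> (2, 4): keepdim=False drops the summed dim.
            return x.sum(dim=1, keepdim=False)

    # Before the pass (edge dialect, schematically):
    #     out = sum.dim_IntList(x, [1], False)   # shape (2, 4)
    # After the pass:
    #     t   = sum.dim_IntList(x, [1], True)    # shape (2, 1, 4)
    #     out = squeeze_copy.dims(t, [1])        # shape (2, 4)
    # In the quantized (BI) case, a q/dq pair is inserted around t, reusing
    # the quantization parameters of the sum's input.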