From 35fdf8b16c3cad396dc2d21efe2bc0fc871a2285 Mon Sep 17 00:00:00 2001
From: Krishna Bindumadhavan <31140965+f2013519@users.noreply.github.com>
Date: Mon, 9 Sep 2024 00:33:12 +0530
Subject: [PATCH] [relay][qnn]: Fix qnn.avg_pool2d layout inference (#17339)

---
 src/relay/qnn/op/avg_pool2d.cc               |  8 +-
 .../relay/test_pass_convert_op_layout.py     | 79 +++++++++++++++++++
 2 files changed, 84 insertions(+), 3 deletions(-)

diff --git a/src/relay/qnn/op/avg_pool2d.cc b/src/relay/qnn/op/avg_pool2d.cc
index b2dc08b85686..e1a28169ccda 100644
--- a/src/relay/qnn/op/avg_pool2d.cc
+++ b/src/relay/qnn/op/avg_pool2d.cc
@@ -132,9 +132,11 @@ InferCorrectLayoutOutput QnnAvgPoolInferCorrectLayout(const Attrs& attrs,
   auto avgpool_new_layouts =
       PoolInferCorrectLayout(attrs, new_in_layouts, old_in_layouts, old_in_types);
 
-  // Scales and zero points are scalars, use the "undef" layout for them.
-  Array<Layout> input_layouts = {avgpool_new_layouts->input_layouts[0], Layout::Undef(),
-                                 Layout::Undef(), Layout::Undef(), Layout::Undef()};
+  // Scales and zero points are scalars, the layouts of these tensors can be treated as channel
+  // layout.
+  Layout channel_layout = Layout("C");
+  Array<Layout> input_layouts = {avgpool_new_layouts->input_layouts[0], channel_layout,
+                                 channel_layout, channel_layout, channel_layout};
   Array<Layout> output_layouts = avgpool_new_layouts->output_layouts;
   return InferCorrectLayoutOutput(input_layouts, output_layouts, attrs);
 }
diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py
index 49afe492a121..5450f1aa6906 100644
--- a/tests/python/relay/test_pass_convert_op_layout.py
+++ b/tests/python/relay/test_pass_convert_op_layout.py
@@ -1542,6 +1542,85 @@ def expected():
     tvm.ir.assert_structural_equal(a, b)
 
 
+def test_qnn_conv_avgpool_2d_convert_layout():
+    def before():
+        x = relay.var("x", shape=(1, 56, 56, 64), dtype="int8")
+        weight = relay.var("weight", shape=(3, 3, 64, 64), dtype="int8")
+        y = relay.qnn.op.conv2d(
+            x,
+            weight,
+            relay.const(1, "int32"),
+            relay.const(1, "int32"),
+            relay.const(1, "float32"),
+            relay.const(1, "float32"),
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
+        y = relay.cast(y, "int8")
+        y = relay.qnn.op.avg_pool2d(
+            y,
+            relay.const(1, "float32"),
+            relay.const(1, "int32"),
+            relay.const(1, "float32"),
+            relay.const(1, "int32"),
+            layout="NHWC",
+            out_layout="NHWC",
+            pool_size=(3, 3),
+            padding=(0, 0),
+            strides=(1, 1),
+            dilation=(1, 1),
+        )
+        y = relay.Function([x, weight], y)
+        return y
+
+    def expected():
+        x = relay.var("x", shape=(1, 56, 56, 64), dtype="int8")
+        weight = relay.var("weight", shape=(3, 3, 64, 64), dtype="int8")
+        x = relay.layout_transform(x, "NHWC", "NCHW")
+        weight = relay.layout_transform(weight, "HWIO", "OIHW")
+        y = relay.qnn.op.conv2d(
+            x,
+            weight,
+            relay.const(1, "int32"),
+            relay.const(1, "int32"),
+            relay.const(1, "float32"),
+            relay.const(1, "float32"),
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NCHW",
+            kernel_layout="OIHW",
+        )
+        y = relay.cast(y, "int8")
+        y = relay.qnn.op.avg_pool2d(
+            y,
+            relay.const(1, "float32"),
+            relay.const(1, "int32"),
+            relay.const(1, "float32"),
+            relay.const(1, "int32"),
+            layout="NCHW",
+            out_layout="NCHW",
+            pool_size=(3, 3),
+            padding=(0, 0),
+            strides=(1, 1),
+            dilation=(1, 1),
+        )
+        y = relay.layout_transform(y, "NCHW", "NHWC")
+        y = relay.Function(relay.analysis.free_vars(y), y)
+        return y
+
+    a = before()
+    a = run_opt_pass(
+        a, transform.ConvertLayout({"qnn.conv2d": ["NCHW", "default"], "qnn.avg_pool2d": ["NCHW"]})
+    )
+    b = run_opt_pass(expected(), transform.InferType())
+
+    tvm.ir.assert_structural_equal(a, b)
+
+
 def test_conv_roi_align_convert_layout():
     def before():
         x = relay.var("x", shape=(1, 64, 56, 56))