diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 91983e3e6b..8c20b06223 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2242,7 +2242,12 @@ def aten_ops_avg_pool(
 
 
 @dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool1d.default)
-def aten_ops_adaptive_avg_pool(
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
+def aten_ops_adaptive_avg_pool1d(
     ctx: ConversionContext,
     target: Target,
     args: Tuple[Argument, ...],
@@ -2259,6 +2264,32 @@ def aten_ops_adaptive_avg_pool(
     )
 
 
+@dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool2d.default)
+@dynamo_tensorrt_converter(torch.ops.aten._adaptive_avg_pool2d.default)
+@dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool3d.default)
+@dynamo_tensorrt_converter(torch.ops.aten._adaptive_avg_pool3d.default)
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
+def aten_ops_adaptive_avg_poolNd(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.pool.adaptive_avg_poolNd(
+        ctx,
+        target,
+        source_ir=SourceIR.ATEN,
+        name=name,
+        input=args[0],
+        output_size=args[1],
+    )
+
+
 def max_pool_param_validator(pool_node: Node) -> bool:
     dilation = args_bounds_check(pool_node.args, 4, 1)
     ceil_mode = args_bounds_check(pool_node.args, 5, False)
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/pool.py b/py/torch_tensorrt/dynamo/conversion/impl/pool.py
index 8c16f59030..c21ccc1c59 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/pool.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/pool.py
@@ -6,7 +6,10 @@ from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import extend_attr_to_tuple
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    extend_attr_to_tuple,
+    get_positive_dim,
+)
 from torch_tensorrt.fx.converters.converter_utils import (
     has_dynamic_shape,
     set_layer_name,
 )
@@ -169,3 +172,228 @@ def end_index(idx: int, out_dim: int, in_dim: int) -> int:
     output = impl.cat.cat(ctx, target, source_ir, f"{name}_cat", output_list, dim=-1)
 
     return output
+
+
+def adaptive_avg_poolNd(
+    ctx: ConversionContext,
+    target: Union[Target, str],
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    output_size: Sequence[int],
+) -> TRTTensor:
+    input_shape = input.shape
+    input_rank = len(input_shape)
+    output_rank = len(output_size)
+    need_reshape_back = False
+
+    if input_rank == output_rank + 1:  # reshape to 4D/5D for TRT pooling
+        input = impl.shuffle.reshape(
+            ctx, target, source_ir, f"{name}_reshape", input, (1, *input.shape)
+        )
+        need_reshape_back = True
+        input_shape = input.shape
+        input_rank = len(input_shape)
+
+    extend_len = len(output_size)
+    output_size = list(output_size)
+    original_input = input
+
+    # repeat_interleave the input if the dim of output is larger than input
+    insert_axises = []
+    for axis in range(1, extend_len + 1):
+        axis = -axis
+        positive_axis = get_positive_dim(
+            axis, input_rank
+        )  # convert to positive axis, which is for calculating new shapes below
+        input_dim = input_shape[axis]
+        output_dim = output_size[axis]
+        diff = output_dim - input_dim
+        if diff > 0:  # the dim of output is larger than input
+            times = output_dim // input_dim
+            remainder = output_dim % input_dim
+            if (
+                diff == 2 and remainder == 2
+            ):  # case 1: output_dim - input_dim == 2 and is not an integral multiple
+                insert_axises.append(axis)
+                remainder -= 1
+                output_size[axis] -= 1
+
+            if (
+                remainder + 1 == input_dim
+            ):  # case 2: remainder + 1 == input_dim, we will repeat_interleave the whole input
+                remainder = 0
+                times += 1
+
+            flags = []  # record the axis that needs to be repeated
+            concat_list = []
+            for j in range(
+                input_dim
+            ):  # iterate the input dim to see which dim needs to be repeated or not
+                single_elem = impl.select.select(
+                    ctx, target, source_ir, f"{name}_select_{axis}_{j}", input, axis, j
+                )
+                new_shape = list(single_elem.shape)
+                new_shape.insert(positive_axis, 1)
+                single_elem = impl.shuffle.reshape(
+                    ctx,
+                    target,
+                    source_ir,
+                    f"{name}_reshape_{axis}_{j}",
+                    single_elem,
+                    new_shape,
+                )
+                if remainder > 0 or j in flags:
+                    concat_list.extend([single_elem] * (times + 1))
+                    remainder -= 2
+                    flags.append(input_dim - j - 1)
+                else:
+                    concat_list.extend([single_elem] * times)
+                out = impl.cat.cat(
+                    ctx, target, source_ir, f"{name}_cat_{axis}_{j}", concat_list, axis
+                )
+            input = out
+
+    stride = tuple(
+        input.shape[-extend_len + i] // output_size[i] for i in range(extend_len)
+    )
+    kernel_size = tuple(
+        input.shape[-extend_len + i] - (output_size[i] - 1) * stride[i]
+        for i in range(extend_len)
+    )
+
+    # Don't have to pool, directly return
+    if all(s == 1 for s in stride) and all(k == 1 for k in kernel_size):
+        if need_reshape_back:  # reshape back
+            input = impl.shuffle.reshape(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_reshape_back",
+                input,
+                (*input.shape[1:],),
+            )
+        return input
+
+    layer = ctx.net.add_pooling_nd(
+        input=input, type=trt.PoolingType.AVERAGE, window_size=kernel_size
+    )
+    layer.stride_nd = stride
+    set_layer_name(layer, target, f"{name}_pooling_{extend_len}d", source_ir)
+
+    output = layer.get_output(0)
+
+    # For case 1, we need to split the output and insert the mid of input
+    for axis in insert_axises:
+        positive_axis = get_positive_dim(axis, input_rank)
+        input_dim = input_shape[axis]
+        output_dim = output_size[axis]
+        if input_dim % 2 == 1:
+            prev_one = impl.select.select(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_select_prev_one_{axis}",
+                output,
+                axis,
+                output_dim // 2 - 1,
+            )
+            extend_shape = list(prev_one.shape)
+            extend_shape.insert(positive_axis, 1)
+            prev_one = impl.shuffle.reshape(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_reshape_extend_shape_{axis}",
+                prev_one,
+                extend_shape,
+            )
+            prev_two = impl.select.select(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_select_prev_two_{axis}",
+                output,
+                axis,
+                output_dim // 2 - 2,
+            )
+            prev_two = impl.shuffle.reshape(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_two_shape_reshape_{axis}",
+                prev_two,
+                extend_shape,
+            )
+            prev_one_two_diff = impl.elementwise.sub(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_prev_one_two_diff_{axis}",
+                prev_one,
+                prev_two,
+            )
+
+            mid = impl.elementwise.add(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_mid_{axis}",
+                prev_one,
+                prev_one_two_diff,
+            )
+            split_output = impl.split.split(
+                ctx, target, source_ir, f"{name}_split_{axis}", output, 2, axis
+            )
+            split_output.insert(1, mid)
+            output = impl.cat.cat(
+                ctx, target, source_ir, f"{name}_cat_{axis}", split_output, axis
+            )
+        else:
+            mid1 = impl.select.select(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_select_{axis}",
+                original_input,
+                axis,
+                input_dim // 2 - 1,
+            )
+            new_shape = list(mid1.shape)
+            new_shape.insert(positive_axis, 1)
+            mid1 = impl.shuffle.reshape(
+                ctx, target, source_ir, f"{name}_reshape_{axis}", mid1, new_shape
+            )
+            mid2 = impl.select.select(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_select_{axis}",
+                original_input,
+                axis,
+                input_dim // 2,
+            )
+            mid2 = impl.shuffle.reshape(
+                ctx, target, source_ir, f"{name}_reshape_{axis}", mid2, new_shape
+            )
+            split_output = impl.split.split(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_split_{axis}",
+                output,
+                [output_dim // 2, 1, output_dim // 2],
+                axis,
+            )
+            split_output[1] = mid1
+            split_output.insert(2, mid2)
+            output = impl.cat.cat(
+                ctx, target, source_ir, f"{name}_cat_{axis}", split_output, axis
+            )
+
+    if need_reshape_back:  # reshape back
+        output = impl.shuffle.reshape(
+            ctx, target, source_ir, f"{name}_reshape_back", output, (*output.shape[1:],)
+        )
+
+    return output
diff --git a/tests/py/dynamo/conversion/test_adaptive_avgpool_aten.py b/tests/py/dynamo/conversion/test_adaptive_avgpool_aten.py
index 3d48409631..b8dc1e1968 100644
--- a/tests/py/dynamo/conversion/test_adaptive_avgpool_aten.py
+++ b/tests/py/dynamo/conversion/test_adaptive_avgpool_aten.py
@@ -76,10 +76,225 @@ def forward(self, x):
         self.run_test(
             TestModule(),
             inputs,
-            # use_dynamo_tracer=True,
             enable_passes=True,
         )
 
+    @parameterized.expand(
+        [
+            # 3d input
+            (
+                (1, 2, 3),
+                (1, 2),
+            ),
+            (
+                (1, 2, 3),
+                (2, 3),
+            ),
+            (
+                (1, 2, 8),
+                (4, 4),
+            ),
+            (
+                (2, 3, 2),
+                (5, 3),
+            ),
+            (
+                (2, 8, 16),
+                (4, 8),
+            ),
+            (
+                (2, 8, 16),
+                (8, 8),
+            ),
+            # 4d input
+            (
+                (1, 1, 4, 3),
+                (4, 8),
+            ),
+            (
+                (3, 2, 3, 2),
+                (1, 5),
+            ),
+            (
+                (4, 2, 2, 8),
+                (5, 2),
+            ),
+            (
+                (3, 2, 3, 3),
+                (6, 4),
+            ),
+            (
+                (1, 2, 3, 2),
+                (2, 2),
+            ),
+            (
+                (2, 2, 32, 16),
+                (8, 8),
+            ),
+            (
+                (2, 2, 32, 32),
+                (31, 16),
+            ),
+            (
+                (1, 1, 64, 64),
+                (64, 16),
+            ),
+        ]
+    )
+    def test_adaptive_avg_pool2d(
+        self,
+        input_shape,
+        output_size,
+    ):
+        class TestModule(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.adaptive_avg_pool2d.default(x, output_size)
+
+        inputs = [torch.randn(input_shape)]
+        self.run_test(
+            TestModule(),
+            inputs,
+            enable_passes=True,
+        )
+
+    @parameterized.expand(
+        [
+            ((1, 2),),
+        ]
+    )
+    def test_adaptive_avg_pool2d_dynamic(self, output_size):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                out = torch.ops.aten.adaptive_avg_pool2d.default(x, output_size)
+                return out
+
+        input_specs = [
+            Input(
+                shape=(-1, 2, 3, 2),
+                dtype=torch.float32,
+                shape_ranges=[((1, 2, 3, 2), (3, 2, 3, 2), (10, 2, 3, 2))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestModule(),
+            input_specs,
+        )
+
+    @parameterized.expand(
+        [
+            # 4d input
+            (
+                (1, 1, 4, 3),
+                (4, 8, 2),
+            ),
+            (
+                (1, 2, 3, 1),
+                (1, 5, 2),
+            ),
+            (
+                (1, 2, 3, 2),
+                (1, 5, 3),
+            ),
+            (
+                (4, 2, 2, 8),
+                (8, 5, 2),
+            ),
+            (
+                (3, 2, 3, 3),
+                (6, 4, 1),
+            ),
+            (
+                (1, 2, 3, 2),
+                (2, 2, 2),
+            ),
+            (
+                (2, 2, 32, 16),
+                (8, 8, 8),
+            ),
+            (
+                (2, 2, 32, 32),
+                (31, 16, 64),
+            ),
+            (
+                (1, 1, 64, 64),
+                (64, 16, 1),
+            ),
+            # 5d input
+            (
+                (1, 1, 1, 4, 3),
+                (4, 8, 2),
+            ),
+            (
+                (4, 3, 1, 2, 3),
+                (2, 4, 6),
+            ),
+            (
+                (1, 4, 2, 2, 2),
+                (5, 2, 4),
+            ),
+            (
+                (3, 2, 3, 3, 2),
+                (6, 4, 1),
+            ),
+            (
+                (2, 2, 32, 16, 8),
+                (8, 8, 8),
+            ),
+            (
+                (2, 2, 32, 32, 32),
+                (31, 16, 64),
+            ),
+            (
+                (1, 1, 64, 64, 64),
+                (64, 16, 1),
+            ),
+        ]
+    )
+    def test_adaptive_avgpool3d(
+        self,
+        input_shape,
+        output_size,
+    ):
+        class TestModule(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.adaptive_avg_pool3d.default(x, output_size)
+
+        inputs = [torch.randn(input_shape)]
+        self.run_test(
+            TestModule(),
+            inputs,
+            enable_passes=True,
+        )
+
+    @parameterized.expand(
+        [
+            ((1, 2, 3),),
+        ]
+    )
+    def test_adaptive_avg_pool3d_dynamic(self, output_size):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                out = torch.ops.aten.adaptive_avg_pool3d.default(x, output_size)
+                return out
+
+        input_specs = [
+            Input(
+                shape=(-1, 2, 3, 1, 4),
+                dtype=torch.float32,
+                shape_ranges=[((1, 2, 3, 1, 4), (3, 2, 3, 1, 4), (10, 2, 3, 1, 4))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestModule(),
+            input_specs,
+        )
+
 
 if __name__ == "__main__":
     run_tests()
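
Reviewer note (not part of the patch): the converter above lowers adaptive average pooling to TensorRT's fixed-window pooling layer. The plain-PyTorch sketch below illustrates the equivalence it relies on, using the same stride/kernel_size formulas as adaptive_avg_poolNd; the tensor shape and output size are illustrative only, chosen so each output dimension evenly divides the (possibly repeat_interleave-expanded) input.

# Minimal sketch of the reduction, assuming evenly divisible dimensions:
#     stride[i] = in_dim[i] // out_dim[i]
#     kernel[i] = in_dim[i] - (out_dim[i] - 1) * stride[i]
import torch
import torch.nn.functional as F

x = torch.randn(2, 2, 32, 16)  # illustrative NCHW input
out_size = (8, 8)              # illustrative output size; divides H and W evenly

# Same derivation as the converter's stride/kernel_size computation
stride = tuple(x.shape[-2 + i] // out_size[i] for i in range(2))
kernel = tuple(x.shape[-2 + i] - (out_size[i] - 1) * stride[i] for i in range(2))

ref = F.adaptive_avg_pool2d(x, out_size)                  # reference result
emu = F.avg_pool2d(x, kernel_size=kernel, stride=stride)  # fixed-window emulation
assert torch.allclose(ref, emu, atol=1e-6)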