diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index e847f3a298..3f8a9661f8 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -2204,33 +2204,16 @@ def aten_ops_adaptive_avg_pool1d( @dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool2d.default) @dynamo_tensorrt_converter(torch.ops.aten._adaptive_avg_pool2d.default) -def aten_ops_adaptive_avg_pool2d( - ctx: ConversionContext, - target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - name: str, -) -> Union[TRTTensor, Sequence[TRTTensor]]: - return impl.pool.adaptive_avg_pool2d( - ctx, - target, - source_ir=SourceIR.ATEN, - name=name, - input=args[0], - output_size=args[1], - ) - - @dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool3d.default) @dynamo_tensorrt_converter(torch.ops.aten._adaptive_avg_pool3d.default) -def aten_ops_adaptive_avg_pool3d( +def aten_ops_adaptive_avg_poolNd( ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return impl.pool.adaptive_avg_pool3d( + return impl.pool.adaptive_avg_poolNd( ctx, target, source_ir=SourceIR.ATEN, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/pool.py b/py/torch_tensorrt/dynamo/conversion/impl/pool.py index 137cd8c9ea..c21ccc1c59 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/pool.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/pool.py @@ -174,7 +174,7 @@ def end_index(idx: int, out_dim: int, in_dim: int) -> int: return output -def adaptive_avg_pool2d( +def adaptive_avg_poolNd( ctx: ConversionContext, target: Union[Target, str], source_ir: Optional[SourceIR], @@ -184,9 +184,10 @@ def adaptive_avg_pool2d( ) -> TRTTensor: input_shape = input.shape input_rank = len(input_shape) + output_rank = len(output_size) need_reshape_back = False - if input_rank == 3: # reshape to 4D for TRT pooling + if input_rank == output_rank + 1: # reshape to 4D/5D for TRT pooling input = impl.shuffle.reshape( ctx, target, source_ir, f"{name}_reshape", input, (1, *input.shape) ) @@ -198,230 +199,6 @@ def adaptive_avg_pool2d( output_size = list(output_size) original_input = input - # repeat_interleave the input if the dim of output is larger than input - insert_axises = [] - for axis in range(1, extend_len + 1): - axis = -axis - positive_axis = get_positive_dim( - axis, input_rank - ) # convert to positive axis, which is for calculating new shapes below - input_dim = input_shape[axis] - output_dim = output_size[axis] - diff = output_dim - input_dim - if diff > 0: # the dim of output is larger than input - times = output_dim // input_dim - remainder = output_dim % input_dim - if ( - diff == 2 and remainder == 2 - ): # case 1: output_dim - input_dim == 2 and is not an integral multiple - insert_axises.append(axis) - remainder -= 1 - output_size[axis] -= 1 - - if ( - remainder + 1 == input_dim - ): # case 2: remainder + 1 == input_dim, we will repeat_interleave the whole input - remainder = 0 - times += 1 - - flags = [] # record the axis that needs to be repeated - concat_list = [] - for j in range( - input_dim - ): # iterate the input dim to see which dim needs to be repeated or not - single_elem = impl.select.select( - ctx, target, source_ir, f"{name}_select_{axis}_{j}", input, axis, j - ) - new_shape = list(single_elem.shape) - new_shape.insert(positive_axis, 1) - single_elem = impl.shuffle.reshape( - ctx, - target, - source_ir, - f"{name}_reshape_{axis}_{j}", - single_elem, - new_shape, - ) - if remainder > 0 or j in flags: - concat_list.extend([single_elem] * (times + 1)) - remainder -= 2 - flags.append(input_dim - j - 1) - else: - concat_list.extend([single_elem] * times) - out = impl.cat.cat( - ctx, target, source_ir, f"{name}_cat_{axis}", concat_list, axis - ) - input = out - - stride = tuple( - input.shape[-extend_len + i] // output_size[i] for i in range(extend_len) - ) - kernel_size = tuple( - input.shape[-extend_len + i] - (output_size[i] - 1) * stride[i] - for i in range(extend_len) - ) - - # Don't have to pool, directly return - if all(s == 1 for s in stride) and all(k == 1 for k in kernel_size): - if need_reshape_back: # reshape back to 3D - input = impl.shuffle.reshape( - ctx, - target, - source_ir, - f"{name}_reshape_back", - input, - (*input.shape[1:],), - ) - return input - - layer = ctx.net.add_pooling_nd( - input=input, type=trt.PoolingType.AVERAGE, window_size=kernel_size - ) - layer.stride_nd = stride - set_layer_name(layer, target, f"{name}_pooling_{extend_len}d", source_ir) - - output = layer.get_output(0) - - # For case 1, we need to split the output and insert the mid of input - for axis in insert_axises: - positive_axis = get_positive_dim(axis, input_rank) - input_dim = input_shape[axis] - output_dim = output_size[axis] - if input_dim % 2 == 1: - prev_one = impl.select.select( - ctx, - target, - source_ir, - f"{name}_select_prev_one_{axis}", - output, - axis, - output_dim // 2 - 1, - ) - extend_shape = list(prev_one.shape) - extend_shape.insert(positive_axis, 1) - prev_one = impl.shuffle.reshape( - ctx, - target, - source_ir, - f"{name}_reshape_extend_shape_{axis}", - prev_one, - extend_shape, - ) - prev_two = impl.select.select( - ctx, - target, - source_ir, - f"{name}_select_prev_two_{axis}", - output, - axis, - output_dim // 2 - 2, - ) - prev_two = impl.shuffle.reshape( - ctx, - target, - source_ir, - f"{name}_two_shape_reshape_{axis}", - prev_two, - extend_shape, - ) - prev_one_two_diff = impl.elementwise.sub( - ctx, - target, - source_ir, - f"{name}_prev_one_two_diff_{axis}", - prev_one, - prev_two, - ) - - mid = impl.elementwise.add( - ctx, - target, - source_ir, - f"{name}_mid_{axis}", - prev_one, - prev_one_two_diff, - ) - split_output = impl.split.split( - ctx, target, source_ir, f"{name}_split_{axis}", output, 2, axis - ) - split_output.insert(1, mid) - output = impl.cat.cat( - ctx, target, source_ir, f"{name}_cat_{axis}", split_output, axis - ) - else: - mid1 = impl.select.select( - ctx, - target, - source_ir, - f"{name}_select_{axis}", - original_input, - axis, - input_dim // 2 - 1, - ) - new_shape = list(mid1.shape) - new_shape.insert(positive_axis, 1) - mid1 = impl.shuffle.reshape( - ctx, target, source_ir, f"{name}_reshape_{axis}", mid1, new_shape - ) - mid2 = impl.select.select( - ctx, - target, - source_ir, - f"{name}_select_{axis}", - original_input, - axis, - input_dim // 2, - ) - mid2 = impl.shuffle.reshape( - ctx, target, source_ir, f"{name}_reshape_{axis}", mid2, new_shape - ) - split_output = impl.split.split( - ctx, - target, - source_ir, - f"{name}_split_{axis}", - output, - [output_dim // 2, 1, output_dim // 2], - axis, - ) - split_output[1] = mid1 - split_output.insert(2, mid2) - output = impl.cat.cat( - ctx, target, source_ir, f"{name}_cat_{axis}", split_output, axis - ) - - if need_reshape_back: # reshape back to 3D - output = impl.shuffle.reshape( - ctx, target, source_ir, f"{name}_reshape_back", output, (*output.shape[1:],) - ) - - return output - - -def adaptive_avg_pool3d( - ctx: ConversionContext, - target: Union[Target, str], - source_ir: Optional[SourceIR], - name: str, - input: TRTTensor, - output_size: Sequence[int], -) -> TRTTensor: - input_shape = input.shape - input_rank = len(input_shape) - need_reshape_back = False - - if input_rank == 4: # reshape to 5D for TRT pooling - input = impl.shuffle.reshape( - ctx, target, source_ir, f"{name}_reshape", input, (1, *input_shape) - ) - need_reshape_back = True - input_shape = input.shape - input_rank = len(input_shape) - - extend_len = len(output_size) - output_size = list(output_size) - original_input = input - # repeat_interleave the input if the dim of output is larger than input insert_axises = [] for axis in range(1, extend_len + 1): @@ -487,7 +264,7 @@ def adaptive_avg_pool3d( # Don't have to pool, directly return if all(s == 1 for s in stride) and all(k == 1 for k in kernel_size): - if need_reshape_back: # reshape back to 4D + if need_reshape_back: # reshape back input = impl.shuffle.reshape( ctx, target, @@ -614,7 +391,7 @@ def adaptive_avg_pool3d( ctx, target, source_ir, f"{name}_cat_{axis}", split_output, axis ) - if need_reshape_back: # reshape back to 4D + if need_reshape_back: # reshape back output = impl.shuffle.reshape( ctx, target, source_ir, f"{name}_reshape_back", output, (*output.shape[1:],) )