Skip to content

Commit

Permalink
feat: support adaptive avg pool 2d and 3d dynamo converters (#2632)
Browse files Browse the repository at this point in the history
  • Loading branch information
zewenli98 authored and laikhtewari committed May 24, 2024
1 parent 182344a commit 1e56b61
Show file tree
Hide file tree
Showing 3 changed files with 477 additions and 3 deletions.
33 changes: 32 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2226,7 +2226,12 @@ def aten_ops_avg_pool(


@dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool1d.default)
def aten_ops_adaptive_avg_pool(
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_adaptive_avg_pool1d(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
Expand All @@ -2243,6 +2248,32 @@ def aten_ops_adaptive_avg_pool(
)


@dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool2d.default)
@dynamo_tensorrt_converter(torch.ops.aten._adaptive_avg_pool2d.default)
@dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool3d.default)
@dynamo_tensorrt_converter(torch.ops.aten._adaptive_avg_pool3d.default)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_adaptive_avg_poolNd(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.pool.adaptive_avg_poolNd(
ctx,
target,
source_ir=SourceIR.ATEN,
name=name,
input=args[0],
output_size=args[1],
)


def max_pool_param_validator(pool_node: Node) -> bool:
dilation = args_bounds_check(pool_node.args, 4, 1)
ceil_mode = args_bounds_check(pool_node.args, 5, False)
Expand Down
230 changes: 229 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import extend_attr_to_tuple
from torch_tensorrt.dynamo.conversion.converter_utils import (
extend_attr_to_tuple,
get_positive_dim,
)
from torch_tensorrt.fx.converters.converter_utils import (
has_dynamic_shape,
set_layer_name,
Expand Down Expand Up @@ -169,3 +172,228 @@ def end_index(idx: int, out_dim: int, in_dim: int) -> int:

output = impl.cat.cat(ctx, target, source_ir, f"{name}_cat", output_list, dim=-1)
return output


def adaptive_avg_poolNd(
ctx: ConversionContext,
target: Union[Target, str],
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
output_size: Sequence[int],
) -> TRTTensor:
input_shape = input.shape
input_rank = len(input_shape)
output_rank = len(output_size)
need_reshape_back = False

if input_rank == output_rank + 1: # reshape to 4D/5D for TRT pooling
input = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_reshape", input, (1, *input.shape)
)
need_reshape_back = True
input_shape = input.shape
input_rank = len(input_shape)

extend_len = len(output_size)
output_size = list(output_size)
original_input = input

# repeat_interleave the input if the dim of output is larger than input
insert_axises = []
for axis in range(1, extend_len + 1):
axis = -axis
positive_axis = get_positive_dim(
axis, input_rank
) # convert to positive axis, which is for calculating new shapes below
input_dim = input_shape[axis]
output_dim = output_size[axis]
diff = output_dim - input_dim
if diff > 0: # the dim of output is larger than input
times = output_dim // input_dim
remainder = output_dim % input_dim
if (
diff == 2 and remainder == 2
): # case 1: output_dim - input_dim == 2 and is not an integral multiple
insert_axises.append(axis)
remainder -= 1
output_size[axis] -= 1

if (
remainder + 1 == input_dim
): # case 2: remainder + 1 == input_dim, we will repeat_interleave the whole input
remainder = 0
times += 1

flags = [] # record the axis that needs to be repeated
concat_list = []
for j in range(
input_dim
): # iterate the input dim to see which dim needs to be repeated or not
single_elem = impl.select.select(
ctx, target, source_ir, f"{name}_select_{axis}_{j}", input, axis, j
)
new_shape = list(single_elem.shape)
new_shape.insert(positive_axis, 1)
single_elem = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_reshape_{axis}_{j}",
single_elem,
new_shape,
)
if remainder > 0 or j in flags:
concat_list.extend([single_elem] * (times + 1))
remainder -= 2
flags.append(input_dim - j - 1)
else:
concat_list.extend([single_elem] * times)
out = impl.cat.cat(
ctx, target, source_ir, f"{name}_cat_{axis}_{j}", concat_list, axis
)
input = out

stride = tuple(
input.shape[-extend_len + i] // output_size[i] for i in range(extend_len)
)
kernel_size = tuple(
input.shape[-extend_len + i] - (output_size[i] - 1) * stride[i]
for i in range(extend_len)
)

# Don't have to pool, directly return
if all(s == 1 for s in stride) and all(k == 1 for k in kernel_size):
if need_reshape_back: # reshape back
input = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_reshape_back",
input,
(*input.shape[1:],),
)
return input

layer = ctx.net.add_pooling_nd(
input=input, type=trt.PoolingType.AVERAGE, window_size=kernel_size
)
layer.stride_nd = stride
set_layer_name(layer, target, f"{name}_pooling_{extend_len}d", source_ir)

output = layer.get_output(0)

# For case 1, we need to split the output and insert the mid of input
for axis in insert_axises:
positive_axis = get_positive_dim(axis, input_rank)
input_dim = input_shape[axis]
output_dim = output_size[axis]
if input_dim % 2 == 1:
prev_one = impl.select.select(
ctx,
target,
source_ir,
f"{name}_select_prev_one_{axis}",
output,
axis,
output_dim // 2 - 1,
)
extend_shape = list(prev_one.shape)
extend_shape.insert(positive_axis, 1)
prev_one = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_reshape_extend_shape_{axis}",
prev_one,
extend_shape,
)
prev_two = impl.select.select(
ctx,
target,
source_ir,
f"{name}_select_prev_two_{axis}",
output,
axis,
output_dim // 2 - 2,
)
prev_two = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_two_shape_reshape_{axis}",
prev_two,
extend_shape,
)
prev_one_two_diff = impl.elementwise.sub(
ctx,
target,
source_ir,
f"{name}_prev_one_two_diff_{axis}",
prev_one,
prev_two,
)

mid = impl.elementwise.add(
ctx,
target,
source_ir,
f"{name}_mid_{axis}",
prev_one,
prev_one_two_diff,
)
split_output = impl.split.split(
ctx, target, source_ir, f"{name}_split_{axis}", output, 2, axis
)
split_output.insert(1, mid)
output = impl.cat.cat(
ctx, target, source_ir, f"{name}_cat_{axis}", split_output, axis
)
else:
mid1 = impl.select.select(
ctx,
target,
source_ir,
f"{name}_select_{axis}",
original_input,
axis,
input_dim // 2 - 1,
)
new_shape = list(mid1.shape)
new_shape.insert(positive_axis, 1)
mid1 = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_reshape_{axis}", mid1, new_shape
)
mid2 = impl.select.select(
ctx,
target,
source_ir,
f"{name}_select_{axis}",
original_input,
axis,
input_dim // 2,
)
mid2 = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_reshape_{axis}", mid2, new_shape
)
split_output = impl.split.split(
ctx,
target,
source_ir,
f"{name}_split_{axis}",
output,
[output_dim // 2, 1, output_dim // 2],
axis,
)
split_output[1] = mid1
split_output.insert(2, mid2)
output = impl.cat.cat(
ctx, target, source_ir, f"{name}_cat_{axis}", split_output, axis
)

if need_reshape_back: # reshape back
output = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_reshape_back", output, (*output.shape[1:],)
)

return output
Loading

0 comments on commit 1e56b61

Please sign in to comment.