feat: support adaptive avg pool 2d and 3d dynamo converters #2632

Merged
Changes from 4 commits
23 changes: 22 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2185,7 +2185,7 @@ def aten_ops_avg_pool(


@dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool1d.default)
-def aten_ops_adaptive_avg_pool(
+def aten_ops_adaptive_avg_pool1d(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
@@ -2202,6 +2202,27 @@ def aten_ops_adaptive_avg_pool(
    )


@dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool2d.default)
@dynamo_tensorrt_converter(torch.ops.aten._adaptive_avg_pool2d.default)
@dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool3d.default)
@dynamo_tensorrt_converter(torch.ops.aten._adaptive_avg_pool3d.default)
def aten_ops_adaptive_avg_poolNd(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.pool.adaptive_avg_poolNd(
        ctx,
        target,
        source_ir=SourceIR.ATEN,
        name=name,
        input=args[0],
        output_size=args[1],
    )


def max_pool_param_validator(pool_node: Node) -> bool:
    dilation = args_bounds_check(pool_node.args, 4, 1)
    ceil_mode = args_bounds_check(pool_node.args, 5, False)
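
For context, a usage sketch (not part of the diff; the module and shapes are made up) of how these registrations get exercised: `adaptive_avg_pool2d`/`_adaptive_avg_pool2d` and the 3d variants in a compiled graph should now lower to TRT pooling layers instead of falling back to PyTorch.

import torch
import torch_tensorrt

class Pool(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # traces to aten.adaptive_avg_pool2d, handled by the converter above
        return torch.nn.functional.adaptive_avg_pool2d(x, (3, 3))

model = Pool().eval().cuda()
inputs = [torch.randn(1, 8, 7, 7).cuda()]
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_model(*inputs).shape)  # torch.Size([1, 8, 3, 3])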
230 changes: 229 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/pool.py
@@ -6,7 +6,10 @@
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import extend_attr_to_tuple
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    extend_attr_to_tuple,
+    get_positive_dim,
+)
from torch_tensorrt.fx.converters.converter_utils import (
    has_dynamic_shape,
    set_layer_name,
@@ -169,3 +172,228 @@ def end_index(idx: int, out_dim: int, in_dim: int) -> int:

    output = impl.cat.cat(ctx, target, source_ir, f"{name}_cat", output_list, dim=-1)
    return output


def adaptive_avg_poolNd(
    ctx: ConversionContext,
    target: Union[Target, str],
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    output_size: Sequence[int],
) -> TRTTensor:
    input_shape = input.shape
    input_rank = len(input_shape)
    output_rank = len(output_size)
    need_reshape_back = False

    if input_rank == output_rank + 1:  # reshape to 4D/5D for TRT pooling
        input = impl.shuffle.reshape(
            ctx, target, source_ir, f"{name}_reshape", input, (1, *input.shape)
        )
        need_reshape_back = True
        input_shape = input.shape
        input_rank = len(input_shape)

    extend_len = len(output_size)
    output_size = list(output_size)
    original_input = input

    # repeat_interleave the input if the dim of output is larger than input
    insert_axises = []
    for axis in range(1, extend_len + 1):
        axis = -axis
        positive_axis = get_positive_dim(
            axis, input_rank
        )  # convert to positive axis, which is for calculating new shapes below
        input_dim = input_shape[axis]
        output_dim = output_size[axis]
        diff = output_dim - input_dim
        if diff > 0:  # the dim of output is larger than input
            times = output_dim // input_dim
            remainder = output_dim % input_dim
            if (
                diff == 2 and remainder == 2
            ):  # case 1: output_dim - input_dim == 2 and is not an integral multiple
                insert_axises.append(axis)
                remainder -= 1
                output_size[axis] -= 1

            if (
                remainder + 1 == input_dim
            ):  # case 2: remainder + 1 == input_dim, we will repeat_interleave the whole input
                remainder = 0
                times += 1
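            # Illustration (not from the diff): for input_dim=3, output_size=5,
            # times=1 and remainder=2, so remainder + 1 == input_dim and the
            # whole input [a, b, c] is doubled to [a, a, b, b, c, c]; the fixed
            # pooling below (kernel=2, stride=1) then reduces it to
            # [a, (a+b)/2, b, (b+c)/2, c], matching PyTorch's adaptive output.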

flags = [] # record the axis that needs to be repeated
concat_list = []
for j in range(
input_dim
): # iterate the input dim to see which dim needs to be repeated or not
single_elem = impl.select.select(
ctx, target, source_ir, f"{name}_select_{axis}_{j}", input, axis, j
)
new_shape = list(single_elem.shape)
new_shape.insert(positive_axis, 1)
single_elem = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_reshape_{axis}_{j}",
single_elem,
new_shape,
)
if remainder > 0 or j in flags:
concat_list.extend([single_elem] * (times + 1))
remainder -= 2
flags.append(input_dim - j - 1)
else:
concat_list.extend([single_elem] * times)
out = impl.cat.cat(
ctx, target, source_ir, f"{name}_cat_{axis}_{j}", concat_list, axis
)
input = out

stride = tuple(
input.shape[-extend_len + i] // output_size[i] for i in range(extend_len)
)
kernel_size = tuple(
input.shape[-extend_len + i] - (output_size[i] - 1) * stride[i]
for i in range(extend_len)
)
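    # Illustration (not from the diff): once every pooled dim divides evenly,
    # e.g. input_dim=10 and output_size=5, stride = 10 // 5 = 2 and
    # kernel = 10 - (5 - 1) * 2 = 2, so the adaptive pool collapses to a
    # plain strided average pool.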

    # Don't have to pool, directly return
    if all(s == 1 for s in stride) and all(k == 1 for k in kernel_size):
        if need_reshape_back:  # reshape back
            input = impl.shuffle.reshape(
                ctx,
                target,
                source_ir,
                f"{name}_reshape_back",
                input,
                (*input.shape[1:],),
            )
        return input

    layer = ctx.net.add_pooling_nd(
        input=input, type=trt.PoolingType.AVERAGE, window_size=kernel_size
    )
    layer.stride_nd = stride
    set_layer_name(layer, target, f"{name}_pooling_{extend_len}d", source_ir)

    output = layer.get_output(0)

    # For case 1, we need to split the output and insert the mid of input
    for axis in insert_axises:
        positive_axis = get_positive_dim(axis, input_rank)
        input_dim = input_shape[axis]
        output_dim = output_size[axis]
        if input_dim % 2 == 1:
            prev_one = impl.select.select(
                ctx,
                target,
                source_ir,
                f"{name}_select_prev_one_{axis}",
                output,
                axis,
                output_dim // 2 - 1,
            )
            extend_shape = list(prev_one.shape)
            extend_shape.insert(positive_axis, 1)
            prev_one = impl.shuffle.reshape(
                ctx,
                target,
                source_ir,
                f"{name}_reshape_extend_shape_{axis}",
                prev_one,
                extend_shape,
            )
            prev_two = impl.select.select(
                ctx,
                target,
                source_ir,
                f"{name}_select_prev_two_{axis}",
                output,
                axis,
                output_dim // 2 - 2,
            )
            prev_two = impl.shuffle.reshape(
                ctx,
                target,
                source_ir,
                f"{name}_two_shape_reshape_{axis}",
                prev_two,
                extend_shape,
            )
            prev_one_two_diff = impl.elementwise.sub(
                ctx,
                target,
                source_ir,
                f"{name}_prev_one_two_diff_{axis}",
                prev_one,
                prev_two,
            )

            mid = impl.elementwise.add(
                ctx,
                target,
                source_ir,
                f"{name}_mid_{axis}",
                prev_one,
                prev_one_two_diff,
            )
            split_output = impl.split.split(
                ctx, target, source_ir, f"{name}_split_{axis}", output, 2, axis
            )
            split_output.insert(1, mid)
            output = impl.cat.cat(
                ctx, target, source_ir, f"{name}_cat_{axis}", split_output, axis
            )
        else:
            # use distinct layer names for the two selects/reshapes so the
            # TRT layer names stay unique
            mid1 = impl.select.select(
                ctx,
                target,
                source_ir,
                f"{name}_select_mid1_{axis}",
                original_input,
                axis,
                input_dim // 2 - 1,
            )
            new_shape = list(mid1.shape)
            new_shape.insert(positive_axis, 1)
            mid1 = impl.shuffle.reshape(
                ctx, target, source_ir, f"{name}_reshape_mid1_{axis}", mid1, new_shape
            )
            mid2 = impl.select.select(
                ctx,
                target,
                source_ir,
                f"{name}_select_mid2_{axis}",
                original_input,
                axis,
                input_dim // 2,
            )
            mid2 = impl.shuffle.reshape(
                ctx, target, source_ir, f"{name}_reshape_mid2_{axis}", mid2, new_shape
            )
            split_output = impl.split.split(
                ctx,
                target,
                source_ir,
                f"{name}_split_{axis}",
                output,
                [output_dim // 2, 1, output_dim // 2],
                axis,
            )
            split_output[1] = mid1
            split_output.insert(2, mid2)
            output = impl.cat.cat(
                ctx, target, source_ir, f"{name}_cat_{axis}", split_output, axis
            )

    if need_reshape_back:  # reshape back
        output = impl.shuffle.reshape(
            ctx, target, source_ir, f"{name}_reshape_back", output, (*output.shape[1:],)
        )

    return output
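
The reduction above can be sanity-checked against PyTorch directly. A minimal sketch (not part of this PR; plain PyTorch, no TensorRT needed) covering both the divisible case and the repeat_interleave case:

import torch
import torch.nn.functional as F

# Divisible case: input_dim=10, output_size=5 gives stride=2, kernel=2,
# so adaptive average pooling equals a plain strided average pool.
x = torch.randn(1, 3, 10, 10)
assert torch.allclose(
    F.adaptive_avg_pool2d(x, (5, 5)),
    F.avg_pool2d(x, kernel_size=2, stride=2),
)

# repeat_interleave case: input length 2, output length 3. Repeating each
# element ([a, b] -> [a, a, b, b]) and pooling with kernel=2, stride=1
# yields [a, (a+b)/2, b], matching PyTorch's adaptive result.
y = torch.tensor([[[1.0, 3.0]]])  # shape (N, C, L) = (1, 1, 2)
repeated = torch.repeat_interleave(y, 2, dim=-1)  # [[[1., 1., 3., 3.]]]
assert torch.allclose(
    F.avg_pool1d(repeated, kernel_size=2, stride=1),
    F.adaptive_avg_pool1d(y, 3),
)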