
Commit

combine two pools
zewenli98 committed Apr 9, 2024
1 parent cbc09ae commit 61b74f3
Showing 2 changed files with 7 additions and 247 deletions.
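The commit folds the separate adaptive_avg_pool2d and adaptive_avg_pool3d converter paths into a single N-dimensional implementation, adaptive_avg_poolNd. For orientation, a minimal sketch (hypothetical module, shapes, and compile call, not part of this commit) of a model whose 2D and 3D adaptive average pools would both lower through that one converter:

import torch
import torch.nn.functional as F

class BothPools(torch.nn.Module):
    def forward(self, x4d, x5d):
        # After this commit both the 2D and the 3D aten ops are registered on the
        # same aten_ops_adaptive_avg_poolNd converter, so they share one lowering path.
        return (
            F.adaptive_avg_pool2d(x4d, (5, 5)),
            F.adaptive_avg_pool3d(x5d, (2, 4, 4)),
        )

# Assumed invocation of the torch_tensorrt dynamo front end:
# trt_model = torch_tensorrt.compile(
#     BothPools().eval().cuda(),
#     ir="dynamo",
#     inputs=[torch.randn(1, 3, 16, 16).cuda(), torch.randn(1, 3, 8, 16, 16).cuda()],
# )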
21 changes: 2 additions & 19 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2204,33 +2204,16 @@ def aten_ops_adaptive_avg_pool1d(

@dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool2d.default)
@dynamo_tensorrt_converter(torch.ops.aten._adaptive_avg_pool2d.default)
-def aten_ops_adaptive_avg_pool2d(
-ctx: ConversionContext,
-target: Target,
-args: Tuple[Argument, ...],
-kwargs: Dict[str, Argument],
-name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-return impl.pool.adaptive_avg_pool2d(
-ctx,
-target,
-source_ir=SourceIR.ATEN,
-name=name,
-input=args[0],
-output_size=args[1],
-)
-
-
@dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool3d.default)
@dynamo_tensorrt_converter(torch.ops.aten._adaptive_avg_pool3d.default)
-def aten_ops_adaptive_avg_pool3d(
+def aten_ops_adaptive_avg_poolNd(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
-return impl.pool.adaptive_avg_pool3d(
+return impl.pool.adaptive_avg_poolNd(
ctx,
target,
source_ir=SourceIR.ATEN,
233 changes: 5 additions & 228 deletions py/torch_tensorrt/dynamo/conversion/impl/pool.py
Expand Up @@ -174,7 +174,7 @@ def end_index(idx: int, out_dim: int, in_dim: int) -> int:
return output


-def adaptive_avg_pool2d(
+def adaptive_avg_poolNd(
ctx: ConversionContext,
target: Union[Target, str],
source_ir: Optional[SourceIR],
@@ -184,9 +184,10 @@ def adaptive_avg_pool2d(
) -> TRTTensor:
input_shape = input.shape
input_rank = len(input_shape)
+output_rank = len(output_size)
need_reshape_back = False

-if input_rank == 3: # reshape to 4D for TRT pooling
+if input_rank == output_rank + 1: # reshape to 4D/5D for TRT pooling
input = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_reshape", input, (1, *input.shape)
)
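# Shape sketch (illustrative, not from the diff): a 3-D (C, H, W) input with a
# 2-D output_size is reshaped to (1, C, H, W), and a 4-D (C, D, H, W) input with
# a 3-D output_size to (1, C, D, H, W), so the TRT pooling layer sees an explicit
# batch dimension; need_reshape_back later strips the leading 1 before returning.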
@@ -198,230 +199,6 @@
output_size = list(output_size)
original_input = input

# repeat_interleave the input if the dim of output is larger than input
insert_axises = []
for axis in range(1, extend_len + 1):
axis = -axis
positive_axis = get_positive_dim(
axis, input_rank
) # convert to positive axis, which is for calculating new shapes below
input_dim = input_shape[axis]
output_dim = output_size[axis]
diff = output_dim - input_dim
if diff > 0: # the dim of output is larger than input
times = output_dim // input_dim
remainder = output_dim % input_dim
if (
diff == 2 and remainder == 2
): # case 1: output_dim - input_dim == 2 and is not an integral multiple
insert_axises.append(axis)
remainder -= 1
output_size[axis] -= 1

if (
remainder + 1 == input_dim
): # case 2: remainder + 1 == input_dim, we will repeat_interleave the whole input
remainder = 0
times += 1

flags = [] # record the axis that needs to be repeated
concat_list = []
for j in range(
input_dim
): # iterate the input dim to see which dim needs to be repeated or not
single_elem = impl.select.select(
ctx, target, source_ir, f"{name}_select_{axis}_{j}", input, axis, j
)
new_shape = list(single_elem.shape)
new_shape.insert(positive_axis, 1)
single_elem = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_reshape_{axis}_{j}",
single_elem,
new_shape,
)
if remainder > 0 or j in flags:
concat_list.extend([single_elem] * (times + 1))
remainder -= 2
flags.append(input_dim - j - 1)
else:
concat_list.extend([single_elem] * times)
out = impl.cat.cat(
ctx, target, source_ir, f"{name}_cat_{axis}", concat_list, axis
)
input = out
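# Illustration with assumed sizes: for an axis with input dim 2 and output dim 4,
# times = 4 // 2 = 2 and remainder = 0, so each slice along that axis is selected
# and concatenated twice, growing the axis to 4 before pooling (which then
# reduces to stride 1, kernel 1 on that axis).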

stride = tuple(
input.shape[-extend_len + i] // output_size[i] for i in range(extend_len)
)
kernel_size = tuple(
input.shape[-extend_len + i] - (output_size[i] - 1) * stride[i]
for i in range(extend_len)
)
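# Worked example with assumed sizes: an input axis of 10 pooled to an output of 3
# gives stride = 10 // 3 = 3 and kernel = 10 - (3 - 1) * 3 = 4, i.e. windows
# [0:4], [3:7], [6:10] that tile the axis exactly; an output of 1 yields a kernel
# spanning the whole axis (a global average).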

# Don't have to pool, directly return
if all(s == 1 for s in stride) and all(k == 1 for k in kernel_size):
if need_reshape_back: # reshape back to 3D
input = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_reshape_back",
input,
(*input.shape[1:],),
)
return input

layer = ctx.net.add_pooling_nd(
input=input, type=trt.PoolingType.AVERAGE, window_size=kernel_size
)
layer.stride_nd = stride
set_layer_name(layer, target, f"{name}_pooling_{extend_len}d", source_ir)

output = layer.get_output(0)

# For case 1, we need to split the output and insert the mid of input
for axis in insert_axises:
positive_axis = get_positive_dim(axis, input_rank)
input_dim = input_shape[axis]
output_dim = output_size[axis]
if input_dim % 2 == 1:
prev_one = impl.select.select(
ctx,
target,
source_ir,
f"{name}_select_prev_one_{axis}",
output,
axis,
output_dim // 2 - 1,
)
extend_shape = list(prev_one.shape)
extend_shape.insert(positive_axis, 1)
prev_one = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_reshape_extend_shape_{axis}",
prev_one,
extend_shape,
)
prev_two = impl.select.select(
ctx,
target,
source_ir,
f"{name}_select_prev_two_{axis}",
output,
axis,
output_dim // 2 - 2,
)
prev_two = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_two_shape_reshape_{axis}",
prev_two,
extend_shape,
)
prev_one_two_diff = impl.elementwise.sub(
ctx,
target,
source_ir,
f"{name}_prev_one_two_diff_{axis}",
prev_one,
prev_two,
)

mid = impl.elementwise.add(
ctx,
target,
source_ir,
f"{name}_mid_{axis}",
prev_one,
prev_one_two_diff,
)
split_output = impl.split.split(
ctx, target, source_ir, f"{name}_split_{axis}", output, 2, axis
)
split_output.insert(1, mid)
output = impl.cat.cat(
ctx, target, source_ir, f"{name}_cat_{axis}", split_output, axis
)
else:
mid1 = impl.select.select(
ctx,
target,
source_ir,
f"{name}_select_{axis}",
original_input,
axis,
input_dim // 2 - 1,
)
new_shape = list(mid1.shape)
new_shape.insert(positive_axis, 1)
mid1 = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_reshape_{axis}", mid1, new_shape
)
mid2 = impl.select.select(
ctx,
target,
source_ir,
f"{name}_select_{axis}",
original_input,
axis,
input_dim // 2,
)
mid2 = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_reshape_{axis}", mid2, new_shape
)
split_output = impl.split.split(
ctx,
target,
source_ir,
f"{name}_split_{axis}",
output,
[output_dim // 2, 1, output_dim // 2],
axis,
)
split_output[1] = mid1
split_output.insert(2, mid2)
output = impl.cat.cat(
ctx, target, source_ir, f"{name}_cat_{axis}", split_output, axis
)

if need_reshape_back: # reshape back to 3D
output = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_reshape_back", output, (*output.shape[1:],)
)

return output


def adaptive_avg_pool3d(
ctx: ConversionContext,
target: Union[Target, str],
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
output_size: Sequence[int],
) -> TRTTensor:
input_shape = input.shape
input_rank = len(input_shape)
need_reshape_back = False

if input_rank == 4: # reshape to 5D for TRT pooling
input = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_reshape", input, (1, *input_shape)
)
need_reshape_back = True
input_shape = input.shape
input_rank = len(input_shape)

extend_len = len(output_size)
output_size = list(output_size)
original_input = input

# repeat_interleave the input if the dim of output is larger than input
insert_axises = []
for axis in range(1, extend_len + 1):
@@ -487,7 +264,7 @@ def adaptive_avg_pool3d(

# Don't have to pool, directly return
if all(s == 1 for s in stride) and all(k == 1 for k in kernel_size):
-if need_reshape_back: # reshape back to 4D
+if need_reshape_back: # reshape back
input = impl.shuffle.reshape(
ctx,
target,
@@ -614,7 +391,7 @@ def adaptive_avg_pool3d(
ctx, target, source_ir, f"{name}_cat_{axis}", split_output, axis
)

-if need_reshape_back: # reshape back to 4D
+if need_reshape_back: # reshape back
output = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_reshape_back", output, (*output.shape[1:],)
)
