Skip to content

Commit

Permalink
feat: support aten.pixel_shuffle dynamo converter
Browse files Browse the repository at this point in the history
  • Loading branch information
zewenli98 committed Jan 20, 2024
1 parent 4b608f0 commit e2c7081
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 0 deletions.
23 changes: 23 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2278,6 +2278,29 @@ def aten_ops_reshape(
)


@dynamo_tensorrt_converter(torch.ops.aten.pixel_shuffle.default)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_pixel_shuffle(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.shuffle.pixel_shuffle(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
)


@enforce_tensor_types({0: (TRTTensor,)})
@dynamo_tensorrt_converter(torch.ops.aten.argmax.default)
def aten_ops_argmax(
Expand Down
41 changes: 41 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Optional, Sequence, Union

import torch_tensorrt.dynamo.conversion.impl as impl
from torch.fx.node import Target
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR
Expand All @@ -19,3 +20,43 @@ def reshape(
layer.reshape_dims = tuple(shape)
set_layer_name(layer, target, name, source_ir)
return layer.get_output(0)


def pixel_shuffle(
ctx: ConversionContext,
target: Union[Target, str],
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
upscale_factor: int,
) -> TRTTensor:
shape = input.shape
in_channels, in_height, in_width = shape[-3:]
out_channels = in_channels // (upscale_factor**2)
out_height = in_height * upscale_factor
out_width = in_width * upscale_factor
new_shape = shape[:-3] + (
out_channels,
upscale_factor,
upscale_factor,
in_height,
in_width,
)
reshaped_tensor = reshape(
ctx, target, source_ir, f"{name}_reshape1", input, new_shape
)
rank = len(shape)
permute_shape = list(range(rank))
permute_shape.insert(-2, rank)
permute_shape.insert(-1, rank + 1)
permuted_tensor = impl.permutation.permute(
ctx, target, source_ir, f"{name}_permute", reshaped_tensor, permute_shape
)
return reshape(
ctx,
target,
source_ir,
f"{name}_reshape2",
permuted_tensor,
shape[:-3] + (out_channels, out_height, out_width),
)
31 changes: 31 additions & 0 deletions tests/py/dynamo/conversion/test_pixel_shuffle_aten.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import torch
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase


class TestPixelShuffleConverter(DispatchTestCase):
@parameterized.expand(
[
((1, 1, 1), 1),
((12, 3, 4), 2),
((1, 9, 4, 4), 3),
((2, 32, 2, 3), 4),
((1, 10, 36, 2, 4), 6),
]
)
def test_pixel_shuffle(self, shape, upscale_factor):
class PixelShuffle(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.pixel_shuffle.default(x, upscale_factor)

inputs = [torch.randn(shape)]
self.run_test(
PixelShuffle(),
inputs,
)


if __name__ == "__main__":
run_tests()

0 comments on commit e2c7081

Please sign in to comment.