diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 45b6f1ad53..712e76cf3c 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -2603,3 +2603,26 @@ def aten_ops_remainder( args[0], args[1], ) + + +@dynamo_tensorrt_converter(torch.ops.aten._pdist_forward.default) +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) +def aten_ops_pdist( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.normalization.pdist( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + args_bounds_check(args, 1, 2), + ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py index 05821f8d90..f45d067349 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py @@ -8,6 +8,7 @@ from torch_tensorrt.dynamo.conversion import impl from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import ( + cast_trt_tensor, get_positive_dim, get_trt_tensor, to_numpy, @@ -440,3 +441,70 @@ def get_softmax_dim(ndim: int) -> int: layer.axes = 1 << dim set_layer_name(layer, target, name, source_ir) return layer.get_output(0) + + +def pdist( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + p: float = 2, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + shape = input.shape + extend_input = impl.shuffle.reshape( + ctx, + target, + source_ir, + f"{name}_reshape", + input, + shape=shape[0:1] + (1,) + shape[1:], + ) + x = impl.elementwise.sub(ctx, target, source_ir, f"{name}_sub", extend_input, input) + + if p == 0: + # norm = torch.sum(x!=0, dim=2) + nonzero_val = impl.elementwise.ne(ctx, target, source_ir, f"{name}_ne", x, 0) + norm = impl.reduce.sum( + ctx, target, source_ir, f"{name}_sum", nonzero_val, dim=2, keepdim=False + ) + norm = cast_trt_tensor( + ctx, norm, torch.float32, f"{name}_cast", target, source_ir + ) + elif p == 1: + # norm = torch.sum(torch.abs(x), dim=2) + abs_val = impl.unary.abs(ctx, target, source_ir, f"{name}_abs", x) + norm = impl.reduce.sum( + ctx, target, source_ir, f"{name}_sum", abs_val, dim=2, keepdim=False + ) + elif 0 < p < 1 or 1 < p < float("inf"): + # norm = torch.pow(torch.sum(torch.pow(torch.abs(x), p), dim=2), 1/p) + abs_val = impl.unary.abs(ctx, target, source_ir, f"{name}_abs", x) + pow_val = impl.elementwise.pow( + ctx, target, source_ir, f"{name}_pow1", abs_val, p + ) + sum_val = impl.reduce.sum( + ctx, target, source_ir, f"{name}_sum", pow_val, dim=2, keepdim=False + ) + norm = impl.elementwise.pow( + ctx, target, source_ir, f"{name}_pow2", sum_val, 1 / p + ) + elif p == float("inf"): + # norm = torch.max(torch.abs(x)) + abs_val = impl.unary.abs(ctx, target, source_ir, f"{name}_abs", x) + norm = impl.reduce.max( + ctx, + target, + source_ir, + f"{name}_max", + abs_val, + dim=2, + keepdim=False, + return_indices=False, + ) + else: + raise RuntimeError( + f"p should between [0, inf], currently p={p} is not supported!" + ) + indices = np.triu_indices(shape[0], k=1) + return impl.select.index(ctx, target, source_ir, f"{name}_index", norm, indices) diff --git a/tests/py/dynamo/conversion/test_pdist_aten.py b/tests/py/dynamo/conversion/test_pdist_aten.py new file mode 100644 index 0000000000..67e547faf2 --- /dev/null +++ b/tests/py/dynamo/conversion/test_pdist_aten.py @@ -0,0 +1,36 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestPdistConverter(DispatchTestCase): + @parameterized.expand( + [ + ((2, 3), 0), + ((2, 3), 0.4), + ((2, 3), 1), + ((2, 3), 1.5), + ((3, 4), 2), + ((3, 4), 2.99), + ((4, 5), 3), + ((4, 5), 3.3), + ((5, 6), float("inf")), + ] + ) + def test_pdist_float(self, shape, p): + class Pdist(nn.Module): + def forward(self, input): + return torch.ops.aten._pdist_forward.default(input, p) + + inputs = [torch.randn(shape)] + self.run_test( + Pdist(), + inputs, + ) + + +if __name__ == "__main__": + run_tests()