feat: support _pdist_forward dynamo converter (#2570)
zewenli98 authored Jan 14, 2024
1 parent 8a7fb23 commit b8403b8
Showing 3 changed files with 127 additions and 0 deletions.
23 changes: 23 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2603,3 +2603,26 @@ def aten_ops_remainder(
args[0],
args[1],
)


@dynamo_tensorrt_converter(torch.ops.aten._pdist_forward.default)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_pdist(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.normalization.pdist(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
args_bounds_check(args, 1, 2),
)
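
For reference, a minimal plain-PyTorch sketch (illustrative input sizes) of what the registered op computes: torch.ops.aten._pdist_forward is the op behind torch.nn.functional.pdist, returning the N*(N-1)/2 pairwise p-norm distances between the rows of an (N, M) input.

import torch

x = torch.randn(4, 3)
ref = torch.nn.functional.pdist(x, p=2)               # upper triangle of the pairwise distance matrix
out = torch.ops.aten._pdist_forward.default(x, 2.0)   # the op this converter handles
assert torch.allclose(out, ref)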
68 changes: 68 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
@@ -8,6 +8,7 @@
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import (
cast_trt_tensor,
get_positive_dim,
get_trt_tensor,
to_numpy,
@@ -440,3 +441,70 @@ def get_softmax_dim(ndim: int) -> int:
layer.axes = 1 << dim
set_layer_name(layer, target, name, source_ir)
return layer.get_output(0)


def pdist(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
p: float = 2,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    shape = input.shape
    # reshape (N, M) -> (N, 1, M) so the subtraction below broadcasts to the
    # (N, N, M) tensor of pairwise row differences
    extend_input = impl.shuffle.reshape(
        ctx,
        target,
        source_ir,
        f"{name}_reshape",
        input,
        shape=shape[0:1] + (1,) + shape[1:],
    )
    x = impl.elementwise.sub(ctx, target, source_ir, f"{name}_sub", extend_input, input)

if p == 0:
# norm = torch.sum(x!=0, dim=2)
nonzero_val = impl.elementwise.ne(ctx, target, source_ir, f"{name}_ne", x, 0)
norm = impl.reduce.sum(
ctx, target, source_ir, f"{name}_sum", nonzero_val, dim=2, keepdim=False
)
norm = cast_trt_tensor(
ctx, norm, torch.float32, f"{name}_cast", target, source_ir
)
elif p == 1:
# norm = torch.sum(torch.abs(x), dim=2)
abs_val = impl.unary.abs(ctx, target, source_ir, f"{name}_abs", x)
norm = impl.reduce.sum(
ctx, target, source_ir, f"{name}_sum", abs_val, dim=2, keepdim=False
)
elif 0 < p < 1 or 1 < p < float("inf"):
# norm = torch.pow(torch.sum(torch.pow(torch.abs(x), p), dim=2), 1/p)
abs_val = impl.unary.abs(ctx, target, source_ir, f"{name}_abs", x)
pow_val = impl.elementwise.pow(
ctx, target, source_ir, f"{name}_pow1", abs_val, p
)
sum_val = impl.reduce.sum(
ctx, target, source_ir, f"{name}_sum", pow_val, dim=2, keepdim=False
)
norm = impl.elementwise.pow(
ctx, target, source_ir, f"{name}_pow2", sum_val, 1 / p
)
elif p == float("inf"):
        # norm = torch.max(torch.abs(x), dim=2).values
abs_val = impl.unary.abs(ctx, target, source_ir, f"{name}_abs", x)
norm = impl.reduce.max(
ctx,
target,
source_ir,
f"{name}_max",
abs_val,
dim=2,
keepdim=False,
return_indices=False,
)
    else:
        raise RuntimeError(
            f"p should be in the range [0, inf], but p={p} is not supported!"
        )
    # keep only the strict upper triangle of the (N, N) distance matrix
    indices = np.triu_indices(shape[0], k=1)
    return impl.select.index(ctx, target, source_ir, f"{name}_index", norm, indices)
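
A plain-PyTorch sketch of the layer graph built above (illustrative only, not the converter code): broadcasted row differences, a p-norm reduced over the last dimension, then the strict upper triangle selected with triu indices.

import torch

x = torch.randn(4, 3)
p = 2.0
diff = x.unsqueeze(1) - x                          # (N, 1, M) - (N, M) -> (N, N, M)
norm = diff.abs().pow(p).sum(dim=2).pow(1 / p)     # (N, N) pairwise distance matrix
rows, cols = torch.triu_indices(x.shape[0], x.shape[0], offset=1)
assert torch.allclose(norm[rows, cols], torch.nn.functional.pdist(x, p=p))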
36 changes: 36 additions & 0 deletions tests/py/dynamo/conversion/test_pdist_aten.py
@@ -0,0 +1,36 @@
import torch
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase


class TestPdistConverter(DispatchTestCase):
@parameterized.expand(
[
((2, 3), 0),
((2, 3), 0.4),
((2, 3), 1),
((2, 3), 1.5),
((3, 4), 2),
((3, 4), 2.99),
((4, 5), 3),
((4, 5), 3.3),
((5, 6), float("inf")),
]
)
def test_pdist_float(self, shape, p):
class Pdist(nn.Module):
def forward(self, input):
return torch.ops.aten._pdist_forward.default(input, p)

inputs = [torch.randn(shape)]
self.run_test(
Pdist(),
inputs,
)


if __name__ == "__main__":
run_tests()
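
Since the module ends with run_tests(), it can be executed directly (python tests/py/dynamo/conversion/test_pdist_aten.py) and is presumably also collected by pytest along with the rest of the dynamo conversion tests.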
