Add dynamic shape support for bitwise_and/or/xor/not, exp/expm1/recip/log/log2/log10 (#2973)
lanluo-nvidia authored Jul 2, 2024
1 parent 9f46d39 commit feb4d84
Showing 11 changed files with 405 additions and 17 deletions.
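For context, here is a minimal sketch (not part of the diff) of what dynamic shape support means for a caller once a converter such as aten.exp is registered with supports_dynamic_shapes=True. The module name ExpModule and the concrete shape ranges are illustrative assumptions; torch_tensorrt.Input with min/opt/max shapes and the ir="dynamo" compile path are the existing public APIs this change feeds into.

import torch
import torch_tensorrt


class ExpModule(torch.nn.Module):  # hypothetical module exercising torch.ops.aten.exp
    def forward(self, x):
        return torch.exp(x)


# Declare an input whose batch dimension may vary between 1 and 8 at runtime.
dyn_input = torch_tensorrt.Input(
    min_shape=(1, 16),
    opt_shape=(4, 16),
    max_shape=(8, 16),
    dtype=torch.float32,
)

# Compile through the dynamo frontend; with this commit the exp converter
# can handle the dynamic range instead of falling back to PyTorch.
trt_mod = torch_tensorrt.compile(
    ExpModule().eval().cuda(),
    ir="dynamo",
    inputs=[dyn_input],
    min_block_size=1,
)

# Any batch size inside [1, 8] should run through the same TensorRT engine.
print(trt_mod(torch.randn(2, 16).cuda()).shape)
print(trt_mod(torch.randn(7, 16).cuda()).shape)
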
57 changes: 40 additions & 17 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -1324,7 +1324,7 @@ def aten_ops_mean(
    )


@dynamo_tensorrt_converter(torch.ops.aten.exp.default)
@dynamo_tensorrt_converter(torch.ops.aten.exp.default, supports_dynamic_shapes=True)
def aten_ops_exp(
    ctx: ConversionContext,
    target: Target,
@@ -1341,7 +1341,7 @@ def aten_ops_exp(
    )


@dynamo_tensorrt_converter(torch.ops.aten.expm1.default)
@dynamo_tensorrt_converter(torch.ops.aten.expm1.default, supports_dynamic_shapes=True)
def aten_ops_expm1(
    ctx: ConversionContext,
    target: Target,
@@ -1358,7 +1358,7 @@ def aten_ops_expm1(
    )


@dynamo_tensorrt_converter(torch.ops.aten.log.default)
@dynamo_tensorrt_converter(torch.ops.aten.log.default, supports_dynamic_shapes=True)
def aten_ops_log(
    ctx: ConversionContext,
    target: Target,
@@ -1375,7 +1375,7 @@ def aten_ops_log(
    )


@dynamo_tensorrt_converter(torch.ops.aten.log2.default)
@dynamo_tensorrt_converter(torch.ops.aten.log2.default, supports_dynamic_shapes=True)
def aten_ops_log2(
    ctx: ConversionContext,
    target: Target,
@@ -1392,7 +1392,7 @@ def aten_ops_log2(
    )


@dynamo_tensorrt_converter(torch.ops.aten.log10.default)
@dynamo_tensorrt_converter(torch.ops.aten.log10.default, supports_dynamic_shapes=True)
def aten_ops_log10(
    ctx: ConversionContext,
    target: Target,
@@ -1443,7 +1443,7 @@ def aten_ops_sqrt(
    )


@dynamo_tensorrt_converter(torch.ops.aten.reciprocal.default)
@dynamo_tensorrt_converter(
    torch.ops.aten.reciprocal.default, supports_dynamic_shapes=True
)
def aten_ops_recip(
    ctx: ConversionContext,
    target: Target,
@@ -2054,7 +2056,9 @@ def aten_ops_logical_and(
    )


@dynamo_tensorrt_converter(torch.ops.aten.logical_or.default)
@dynamo_tensorrt_converter(
    torch.ops.aten.logical_or.default, supports_dynamic_shapes=True
)
def aten_ops_logical_or(
    ctx: ConversionContext,
    target: Target,
@@ -2072,7 +2076,9 @@ def aten_ops_logical_or(
    )


@dynamo_tensorrt_converter(torch.ops.aten.logical_xor.default)
@dynamo_tensorrt_converter(
    torch.ops.aten.logical_xor.default, supports_dynamic_shapes=True
)
def aten_ops_logical_xor(
    ctx: ConversionContext,
    target: Target,
@@ -2108,7 +2114,6 @@ def bitwise_type_validator(node: Node) -> bool:
        torch.ops.aten.bitwise_or.Scalar_Tensor,
        torch.ops.aten.bitwise_xor.Scalar_Tensor,
    ]

    if node.target in tensor_targets:
        lhs_val = node.args[0]
        rhs_val = node.args[1]
@@ -2139,14 +2144,19 @@ def bitwise_type_validator(node: Node) -> bool:


@dynamo_tensorrt_converter(
    torch.ops.aten.bitwise_and.Tensor, capability_validator=bitwise_type_validator
    torch.ops.aten.bitwise_and.Tensor,
    capability_validator=bitwise_type_validator,
    supports_dynamic_shapes=True,
)
@dynamo_tensorrt_converter(
    torch.ops.aten.bitwise_and.Scalar, capability_validator=bitwise_type_validator
    torch.ops.aten.bitwise_and.Scalar,
    capability_validator=bitwise_type_validator,
    supports_dynamic_shapes=True,
)
@dynamo_tensorrt_converter(
    torch.ops.aten.bitwise_and.Scalar_Tensor,
    capability_validator=bitwise_type_validator,
    supports_dynamic_shapes=True,
)
def aten_ops_bitwise_and(
    ctx: ConversionContext,
@@ -2166,13 +2176,19 @@ def aten_ops_bitwise_and(


@dynamo_tensorrt_converter(
    torch.ops.aten.bitwise_or.Tensor, capability_validator=bitwise_type_validator
    torch.ops.aten.bitwise_or.Tensor,
    capability_validator=bitwise_type_validator,
    supports_dynamic_shapes=True,
)
@dynamo_tensorrt_converter(
    torch.ops.aten.bitwise_or.Scalar, capability_validator=bitwise_type_validator
    torch.ops.aten.bitwise_or.Scalar,
    capability_validator=bitwise_type_validator,
    supports_dynamic_shapes=True,
)
@dynamo_tensorrt_converter(
    torch.ops.aten.bitwise_or.Scalar_Tensor, capability_validator=bitwise_type_validator
    torch.ops.aten.bitwise_or.Scalar_Tensor,
    capability_validator=bitwise_type_validator,
    supports_dynamic_shapes=True,
)
def aten_ops_bitwise_or(
    ctx: ConversionContext,
@@ -2192,14 +2208,19 @@ def aten_ops_bitwise_or(


@dynamo_tensorrt_converter(
    torch.ops.aten.bitwise_xor.Tensor, capability_validator=bitwise_type_validator
    torch.ops.aten.bitwise_xor.Tensor,
    capability_validator=bitwise_type_validator,
    supports_dynamic_shapes=True,
)
@dynamo_tensorrt_converter(
    torch.ops.aten.bitwise_xor.Scalar, capability_validator=bitwise_type_validator
    torch.ops.aten.bitwise_xor.Scalar,
    capability_validator=bitwise_type_validator,
    supports_dynamic_shapes=True,
)
@dynamo_tensorrt_converter(
    torch.ops.aten.bitwise_xor.Scalar_Tensor,
    capability_validator=bitwise_type_validator,
    supports_dynamic_shapes=True,
)
def aten_ops_bitwise_xor(
    ctx: ConversionContext,
@@ -2230,7 +2251,9 @@ def bitwise_not_type_validator(node: Node) -> bool:


@dynamo_tensorrt_converter(
    torch.ops.aten.bitwise_not.default, capability_validator=bitwise_not_type_validator
    torch.ops.aten.bitwise_not.default,
    capability_validator=bitwise_not_type_validator,
    supports_dynamic_shapes=True,
)
@enforce_tensor_types(
    {
84 changes: 84 additions & 0 deletions tests/py/dynamo/conversion/test_bitwise_and_aten.py
@@ -1,7 +1,10 @@
import torch
import torch.nn as nn
import torch_tensorrt
from parameterized import parameterized
from torch.export import Dim
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase

@@ -32,6 +35,50 @@ def forward(self, lhs_val, rhs_val):
            use_dynamo_tracer=True,
        )

    @parameterized.expand(
        [
            ("2d-2d", (2, 3), (3, 3), (5, 3), (2, 3), (3, 3), (5, 3)),
            ("3d-3d", (2, 2, 2), (2, 3, 2), (2, 4, 2), (1, 2, 2), (1, 3, 2), (1, 4, 2)),
        ]
    )
    def test_bitwise_and_tensor_dynamic_shape(
        self,
        _,
        lhs_min_shape,
        lhs_opt_shape,
        lhs_max_shape,
        rhs_min_shape,
        rhs_opt_shape,
        rhs_max_shape,
    ):
        class bitwise_and(nn.Module):
            def forward(self, lhs_val, rhs_val):
                return torch.ops.aten.bitwise_and.Tensor(lhs_val, rhs_val)

        inputs = [
            Input(
                dtype=torch.bool,
                min_shape=lhs_min_shape,
                opt_shape=lhs_opt_shape,
                max_shape=lhs_max_shape,
                torch_tensor=torch.randint(0, 2, lhs_opt_shape, dtype=bool),
            ),
            Input(
                dtype=torch.bool,
                min_shape=rhs_min_shape,
                opt_shape=rhs_opt_shape,
                max_shape=rhs_max_shape,
                torch_tensor=torch.randint(0, 2, rhs_opt_shape, dtype=bool),
            ),
        ]
        self.run_test_with_dynamic_shape(
            bitwise_and(),
            inputs,
            enable_passes=True,
            use_dynamo_tracer=True,
            use_example_tensors=False,
        )

    @parameterized.expand(
        [
            ("2d", (5, 3), True),
@@ -74,6 +121,43 @@ def forward(self, tensor):
            use_dynamo_tracer=True,
        )

    # This test case exercises bitwise_and with inputs of different ranks.
    # It cannot use the normal run_test_with_dynamic_shape helper because
    # torch_tensorrt.dynamo.trace does not handle this case automatically,
    # so the graph is exported manually and the compiled module is tested directly.
    def test_bitwise_and_dynamic_shape_with_different_ranks(self):
        class bitwise_and(nn.Module):
            def forward(self, lhs_val, rhs_val):
                return torch.ops.aten.bitwise_and.Tensor(lhs_val, rhs_val)

        dyn_dim = Dim("dyn_dim", min=2, max=6)
        inputs = (
            torch.randint(0, 2, (2, 4, 2), dtype=bool),
            torch.randint(0, 2, (4, 2), dtype=bool),
        )
        mod = bitwise_and()
        fx_mod = torch.export.export(
            mod, inputs, dynamic_shapes=({1: dyn_dim}, {0: dyn_dim})
        )
        trt_mod = torch_tensorrt.dynamo.compile(
            fx_mod, inputs=inputs, enabled_precisions={torch.bool}, min_block_size=1
        )
        with torch.no_grad():
            cuda_inputs = []
            for i in inputs:
                cuda_inputs.append(i.cuda())
            ref_outputs = mod(*cuda_inputs)
            outputs = trt_mod(*cuda_inputs)
            for out, ref in zip(outputs, ref_outputs):
                torch.testing.assert_close(
                    out,
                    ref,
                    rtol=0.001,
                    atol=0.001,
                    equal_nan=True,
                    check_dtype=True,
                )


if __name__ == "__main__":
    run_tests()
29 changes: 29 additions & 0 deletions tests/py/dynamo/conversion/test_bitwise_not_aten.py
@@ -2,6 +2,7 @@
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase

@@ -28,6 +29,34 @@ def forward(self, val):
            use_dynamo_tracer=True,
        )

    @parameterized.expand(
        [
            ("2d", (2, 3), (5, 3), (6, 4)),
            ("3d", (2, 3, 2), (3, 4, 2), (5, 4, 2)),
        ]
    )
    def test_bitwise_not_tensor_dynamic_shape(self, _, min_shape, opt_shape, max_shape):
        class bitwise_not(nn.Module):
            def forward(self, val):
                return torch.ops.aten.bitwise_not.default(val)

        inputs = [
            Input(
                dtype=torch.bool,
                min_shape=min_shape,
                opt_shape=opt_shape,
                max_shape=max_shape,
                torch_tensor=torch.randint(0, 2, opt_shape, dtype=bool),
            )
        ]
        self.run_test_with_dynamic_shape(
            bitwise_not(),
            inputs,
            enable_passes=True,
            use_dynamo_tracer=True,
            use_example_tensors=False,
        )


if __name__ == "__main__":
    run_tests()
45 changes: 45 additions & 0 deletions tests/py/dynamo/conversion/test_bitwise_or_aten.py
@@ -2,6 +2,7 @@
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase

@@ -32,6 +33,50 @@ def forward(self, lhs_val, rhs_val):
            use_dynamo_tracer=True,
        )

    @parameterized.expand(
        [
            ("2d-2d", (2, 3), (3, 3), (5, 3), (2, 3), (3, 3), (5, 3)),
            ("3d-3d", (2, 2, 2), (2, 3, 2), (2, 4, 2), (1, 2, 2), (1, 3, 2), (1, 4, 2)),
        ]
    )
    def test_bitwise_or_tensor_dynamic_shape(
        self,
        _,
        lhs_min_shape,
        lhs_opt_shape,
        lhs_max_shape,
        rhs_min_shape,
        rhs_opt_shape,
        rhs_max_shape,
    ):
        class bitwise_or(nn.Module):
            def forward(self, lhs_val, rhs_val):
                return torch.ops.aten.bitwise_or.Tensor(lhs_val, rhs_val)

        inputs = [
            Input(
                dtype=torch.bool,
                min_shape=lhs_min_shape,
                opt_shape=lhs_opt_shape,
                max_shape=lhs_max_shape,
                torch_tensor=torch.randint(0, 2, lhs_opt_shape, dtype=bool),
            ),
            Input(
                dtype=torch.bool,
                min_shape=rhs_min_shape,
                opt_shape=rhs_opt_shape,
                max_shape=rhs_max_shape,
                torch_tensor=torch.randint(0, 2, rhs_opt_shape, dtype=bool),
            ),
        ]
        self.run_test_with_dynamic_shape(
            bitwise_or(),
            inputs,
            enable_passes=True,
            use_dynamo_tracer=True,
            use_example_tensors=False,
        )

    @parameterized.expand(
        [
            ("2d", (5, 3), True),

