From feb4d84ffb399f920eb524818ed8c5106494d02a Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Tue, 2 Jul 2024 15:38:30 -0700 Subject: [PATCH] Add dynamic shape support for bitwise_and/or/xor/not, exp/expm1/recip/log/log2/log10 (#2973) --- .../dynamo/conversion/aten_ops_converters.py | 57 +++++++++---- .../conversion/test_bitwise_and_aten.py | 84 +++++++++++++++++++ .../conversion/test_bitwise_not_aten.py | 29 +++++++ .../dynamo/conversion/test_bitwise_or_aten.py | 45 ++++++++++ .../conversion/test_bitwise_xor_aten.py | 45 ++++++++++ tests/py/dynamo/conversion/test_exp_aten.py | 27 ++++++ tests/py/dynamo/conversion/test_expm1_aten.py | 27 ++++++ tests/py/dynamo/conversion/test_log10.py | 27 ++++++ tests/py/dynamo/conversion/test_log2.py | 27 ++++++ tests/py/dynamo/conversion/test_log_aten.py | 27 ++++++ tests/py/dynamo/conversion/test_recip_aten.py | 27 ++++++ 11 files changed, 405 insertions(+), 17 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 930923d23c..8fe67d6507 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -1324,7 +1324,7 @@ def aten_ops_mean( ) -@dynamo_tensorrt_converter(torch.ops.aten.exp.default) +@dynamo_tensorrt_converter(torch.ops.aten.exp.default, supports_dynamic_shapes=True) def aten_ops_exp( ctx: ConversionContext, target: Target, @@ -1341,7 +1341,7 @@ def aten_ops_exp( ) -@dynamo_tensorrt_converter(torch.ops.aten.expm1.default) +@dynamo_tensorrt_converter(torch.ops.aten.expm1.default, supports_dynamic_shapes=True) def aten_ops_expm1( ctx: ConversionContext, target: Target, @@ -1358,7 +1358,7 @@ def aten_ops_expm1( ) -@dynamo_tensorrt_converter(torch.ops.aten.log.default) +@dynamo_tensorrt_converter(torch.ops.aten.log.default, supports_dynamic_shapes=True) def aten_ops_log( ctx: ConversionContext, target: Target, @@ -1375,7 +1375,7 @@ def aten_ops_log( ) -@dynamo_tensorrt_converter(torch.ops.aten.log2.default) +@dynamo_tensorrt_converter(torch.ops.aten.log2.default, supports_dynamic_shapes=True) def aten_ops_log2( ctx: ConversionContext, target: Target, @@ -1392,7 +1392,7 @@ def aten_ops_log2( ) -@dynamo_tensorrt_converter(torch.ops.aten.log10.default) +@dynamo_tensorrt_converter(torch.ops.aten.log10.default, supports_dynamic_shapes=True) def aten_ops_log10( ctx: ConversionContext, target: Target, @@ -1443,7 +1443,9 @@ def aten_ops_sqrt( ) -@dynamo_tensorrt_converter(torch.ops.aten.reciprocal.default) +@dynamo_tensorrt_converter( + torch.ops.aten.reciprocal.default, supports_dynamic_shapes=True +) def aten_ops_recip( ctx: ConversionContext, target: Target, @@ -2054,7 +2056,9 @@ def aten_ops_logical_and( ) -@dynamo_tensorrt_converter(torch.ops.aten.logical_or.default) +@dynamo_tensorrt_converter( + torch.ops.aten.logical_or.default, supports_dynamic_shapes=True +) def aten_ops_logical_or( ctx: ConversionContext, target: Target, @@ -2072,7 +2076,9 @@ def aten_ops_logical_or( ) -@dynamo_tensorrt_converter(torch.ops.aten.logical_xor.default) +@dynamo_tensorrt_converter( + torch.ops.aten.logical_xor.default, supports_dynamic_shapes=True +) def aten_ops_logical_xor( ctx: ConversionContext, target: Target, @@ -2108,7 +2114,6 @@ def bitwise_type_validator(node: Node) -> bool: torch.ops.aten.bitwise_or.Scalar_Tensor, torch.ops.aten.bitwise_xor.Scalar_Tensor, ] - if node.target in tensor_targets: lhs_val = node.args[0] rhs_val = node.args[1] @@ -2139,14 +2144,19 @@ def bitwise_type_validator(node: Node) -> bool: @dynamo_tensorrt_converter( - torch.ops.aten.bitwise_and.Tensor, capability_validator=bitwise_type_validator + torch.ops.aten.bitwise_and.Tensor, + capability_validator=bitwise_type_validator, + supports_dynamic_shapes=True, ) @dynamo_tensorrt_converter( - torch.ops.aten.bitwise_and.Scalar, capability_validator=bitwise_type_validator + torch.ops.aten.bitwise_and.Scalar, + capability_validator=bitwise_type_validator, + supports_dynamic_shapes=True, ) @dynamo_tensorrt_converter( torch.ops.aten.bitwise_and.Scalar_Tensor, capability_validator=bitwise_type_validator, + supports_dynamic_shapes=True, ) def aten_ops_bitwise_and( ctx: ConversionContext, @@ -2166,13 +2176,19 @@ def aten_ops_bitwise_and( @dynamo_tensorrt_converter( - torch.ops.aten.bitwise_or.Tensor, capability_validator=bitwise_type_validator + torch.ops.aten.bitwise_or.Tensor, + capability_validator=bitwise_type_validator, + supports_dynamic_shapes=True, ) @dynamo_tensorrt_converter( - torch.ops.aten.bitwise_or.Scalar, capability_validator=bitwise_type_validator + torch.ops.aten.bitwise_or.Scalar, + capability_validator=bitwise_type_validator, + supports_dynamic_shapes=True, ) @dynamo_tensorrt_converter( - torch.ops.aten.bitwise_or.Scalar_Tensor, capability_validator=bitwise_type_validator + torch.ops.aten.bitwise_or.Scalar_Tensor, + capability_validator=bitwise_type_validator, + supports_dynamic_shapes=True, ) def aten_ops_bitwise_or( ctx: ConversionContext, @@ -2192,14 +2208,19 @@ def aten_ops_bitwise_or( @dynamo_tensorrt_converter( - torch.ops.aten.bitwise_xor.Tensor, capability_validator=bitwise_type_validator + torch.ops.aten.bitwise_xor.Tensor, + capability_validator=bitwise_type_validator, + supports_dynamic_shapes=True, ) @dynamo_tensorrt_converter( - torch.ops.aten.bitwise_xor.Scalar, capability_validator=bitwise_type_validator + torch.ops.aten.bitwise_xor.Scalar, + capability_validator=bitwise_type_validator, + supports_dynamic_shapes=True, ) @dynamo_tensorrt_converter( torch.ops.aten.bitwise_xor.Scalar_Tensor, capability_validator=bitwise_type_validator, + supports_dynamic_shapes=True, ) def aten_ops_bitwise_xor( ctx: ConversionContext, @@ -2230,7 +2251,9 @@ def bitwise_not_type_validator(node: Node) -> bool: @dynamo_tensorrt_converter( - torch.ops.aten.bitwise_not.default, capability_validator=bitwise_not_type_validator + torch.ops.aten.bitwise_not.default, + capability_validator=bitwise_not_type_validator, + supports_dynamic_shapes=True, ) @enforce_tensor_types( { diff --git a/tests/py/dynamo/conversion/test_bitwise_and_aten.py b/tests/py/dynamo/conversion/test_bitwise_and_aten.py index 22564346ef..8bd0415bdc 100644 --- a/tests/py/dynamo/conversion/test_bitwise_and_aten.py +++ b/tests/py/dynamo/conversion/test_bitwise_and_aten.py @@ -1,7 +1,10 @@ import torch import torch.nn as nn +import torch_tensorrt from parameterized import parameterized +from torch.export import Dim from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input from .harness import DispatchTestCase @@ -32,6 +35,50 @@ def forward(self, lhs_val, rhs_val): use_dynamo_tracer=True, ) + @parameterized.expand( + [ + ("2d-2d", (2, 3), (3, 3), (5, 3), (2, 3), (3, 3), (5, 3)), + ("3d-3d", (2, 2, 2), (2, 3, 2), (2, 4, 2), (1, 2, 2), (1, 3, 2), (1, 4, 2)), + ] + ) + def test_bitwise_and_tensor_dynamic_shape( + self, + _, + lhs_min_shape, + lhs_opt_shape, + lhs_max_shape, + rhs_min_shape, + rhs_opt_shape, + rhs_max_shape, + ): + class bitwise_and(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.ops.aten.bitwise_and.Tensor(lhs_val, rhs_val) + + inputs = [ + Input( + dtype=torch.bool, + min_shape=lhs_min_shape, + opt_shape=lhs_opt_shape, + max_shape=lhs_max_shape, + torch_tensor=torch.randint(0, 2, lhs_opt_shape, dtype=bool), + ), + Input( + dtype=torch.bool, + min_shape=rhs_min_shape, + opt_shape=rhs_opt_shape, + max_shape=rhs_max_shape, + torch_tensor=torch.randint(0, 2, rhs_opt_shape, dtype=bool), + ), + ] + self.run_test_with_dynamic_shape( + bitwise_and(), + inputs, + enable_passes=True, + use_dynamo_tracer=True, + use_example_tensors=False, + ) + @parameterized.expand( [ ("2d", (5, 3), True), @@ -74,6 +121,43 @@ def forward(self, tensor): use_dynamo_tracer=True, ) + # this test case is to test the bitwise_and with different ranks + # it cannot use the normal test_with_dynamic_shape due to the + # torch_tensorrt.dynamo.trace doesn't automatically handle it + # hence has to manually export the graph and run the test. + def test_bitwise_and_dynamic_shape_with_different_ranks(self): + class bitwise_and(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.ops.aten.bitwise_and.Tensor(lhs_val, rhs_val) + + dyn_dim = Dim("dyn_dim", min=2, max=6) + inputs = ( + torch.randint(0, 2, (2, 4, 2), dtype=bool), + torch.randint(0, 2, (4, 2), dtype=bool), + ) + mod = bitwise_and() + fx_mod = torch.export.export( + mod, inputs, dynamic_shapes=({1: dyn_dim}, {0: dyn_dim}) + ) + trt_mod = torch_tensorrt.dynamo.compile( + fx_mod, inputs=inputs, enable_precisions={torch.bool}, min_block_size=1 + ) + with torch.no_grad(): + cuda_inputs = [] + for i in inputs: + cuda_inputs.append(i.cuda()) + ref_outputs = mod(*cuda_inputs) + outputs = trt_mod(*cuda_inputs) + for out, ref in zip(outputs, ref_outputs): + torch.testing.assert_close( + out, + ref, + rtol=0.001, + atol=0.001, + equal_nan=True, + check_dtype=True, + ) + if __name__ == "__main__": run_tests() diff --git a/tests/py/dynamo/conversion/test_bitwise_not_aten.py b/tests/py/dynamo/conversion/test_bitwise_not_aten.py index 33d8629aff..2ff8510a5c 100644 --- a/tests/py/dynamo/conversion/test_bitwise_not_aten.py +++ b/tests/py/dynamo/conversion/test_bitwise_not_aten.py @@ -2,6 +2,7 @@ import torch.nn as nn from parameterized import parameterized from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input from .harness import DispatchTestCase @@ -28,6 +29,34 @@ def forward(self, val): use_dynamo_tracer=True, ) + @parameterized.expand( + [ + ("2d", (2, 3), (5, 3), (6, 4)), + ("3d", (2, 3, 2), (3, 4, 2), (5, 4, 2)), + ] + ) + def test_bitwise_not_tensor_dynamic_shape(self, _, min_shape, opt_shape, max_shape): + class bitwise_not(nn.Module): + def forward(self, val): + return torch.ops.aten.bitwise_not.default(val) + + inputs = [ + Input( + dtype=torch.bool, + min_shape=min_shape, + opt_shape=opt_shape, + max_shape=max_shape, + torch_tensor=torch.randint(0, 2, opt_shape, dtype=bool), + ) + ] + self.run_test_with_dynamic_shape( + bitwise_not(), + inputs, + enable_passes=True, + use_dynamo_tracer=True, + use_example_tensors=False, + ) + if __name__ == "__main__": run_tests() diff --git a/tests/py/dynamo/conversion/test_bitwise_or_aten.py b/tests/py/dynamo/conversion/test_bitwise_or_aten.py index a94d27e3fd..0d571f7bc7 100644 --- a/tests/py/dynamo/conversion/test_bitwise_or_aten.py +++ b/tests/py/dynamo/conversion/test_bitwise_or_aten.py @@ -2,6 +2,7 @@ import torch.nn as nn from parameterized import parameterized from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input from .harness import DispatchTestCase @@ -32,6 +33,50 @@ def forward(self, lhs_val, rhs_val): use_dynamo_tracer=True, ) + @parameterized.expand( + [ + ("2d-2d", (2, 3), (3, 3), (5, 3), (2, 3), (3, 3), (5, 3)), + ("3d-3d", (2, 2, 2), (2, 3, 2), (2, 4, 2), (1, 2, 2), (1, 3, 2), (1, 4, 2)), + ] + ) + def test_bitwise_or_tensor_dynamic_shape( + self, + _, + lhs_min_shape, + lhs_opt_shape, + lhs_max_shape, + rhs_min_shape, + rhs_opt_shape, + rhs_max_shape, + ): + class bitwise_or(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.ops.aten.bitwise_or.Tensor(lhs_val, rhs_val) + + inputs = [ + Input( + dtype=torch.bool, + min_shape=lhs_min_shape, + opt_shape=lhs_opt_shape, + max_shape=lhs_max_shape, + torch_tensor=torch.randint(0, 2, lhs_opt_shape, dtype=bool), + ), + Input( + dtype=torch.bool, + min_shape=rhs_min_shape, + opt_shape=rhs_opt_shape, + max_shape=rhs_max_shape, + torch_tensor=torch.randint(0, 2, rhs_opt_shape, dtype=bool), + ), + ] + self.run_test_with_dynamic_shape( + bitwise_or(), + inputs, + enable_passes=True, + use_dynamo_tracer=True, + use_example_tensors=False, + ) + @parameterized.expand( [ ("2d", (5, 3), True), diff --git a/tests/py/dynamo/conversion/test_bitwise_xor_aten.py b/tests/py/dynamo/conversion/test_bitwise_xor_aten.py index b07399ecf7..768c277e54 100644 --- a/tests/py/dynamo/conversion/test_bitwise_xor_aten.py +++ b/tests/py/dynamo/conversion/test_bitwise_xor_aten.py @@ -2,6 +2,7 @@ import torch.nn as nn from parameterized import parameterized from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input from .harness import DispatchTestCase @@ -32,6 +33,50 @@ def forward(self, lhs_val, rhs_val): use_dynamo_tracer=True, ) + @parameterized.expand( + [ + ("2d-2d", (2, 3), (3, 3), (5, 3), (2, 3), (3, 3), (5, 3)), + ("3d-3d", (2, 2, 2), (2, 3, 2), (2, 4, 2), (1, 2, 2), (1, 3, 2), (1, 4, 2)), + ] + ) + def test_bitwise_xor_tensor_dynamic_shape( + self, + _, + lhs_min_shape, + lhs_opt_shape, + lhs_max_shape, + rhs_min_shape, + rhs_opt_shape, + rhs_max_shape, + ): + class bitwise_xor(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.ops.aten.bitwise_xor.Tensor(lhs_val, rhs_val) + + inputs = [ + Input( + dtype=torch.bool, + min_shape=lhs_min_shape, + opt_shape=lhs_opt_shape, + max_shape=lhs_max_shape, + torch_tensor=torch.randint(0, 2, lhs_opt_shape, dtype=bool), + ), + Input( + dtype=torch.bool, + min_shape=rhs_min_shape, + opt_shape=rhs_opt_shape, + max_shape=rhs_max_shape, + torch_tensor=torch.randint(0, 2, rhs_opt_shape, dtype=bool), + ), + ] + self.run_test_with_dynamic_shape( + bitwise_xor(), + inputs, + enable_passes=True, + use_dynamo_tracer=True, + use_example_tensors=False, + ) + @parameterized.expand( [ ("2d", (5, 3), True), diff --git a/tests/py/dynamo/conversion/test_exp_aten.py b/tests/py/dynamo/conversion/test_exp_aten.py index ac1c5dfbcb..6cecac476f 100644 --- a/tests/py/dynamo/conversion/test_exp_aten.py +++ b/tests/py/dynamo/conversion/test_exp_aten.py @@ -2,6 +2,7 @@ import torch.nn as nn from parameterized import parameterized from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input from .harness import DispatchTestCase @@ -26,6 +27,32 @@ def forward(self, input): inputs, ) + @parameterized.expand( + [ + ((1,), (3,), (5,)), + ((1, 20), (2, 20), (3, 20)), + ((2, 3, 4), (3, 4, 5), (4, 5, 6)), + ((2, 3, 4, 5), (3, 5, 5, 6), (4, 5, 6, 7)), + ] + ) + def test_exp_float_dynamic_shape(self, min_shape, opt_shape, max_shape): + class exp(nn.Module): + def forward(self, input): + return torch.ops.aten.exp.default(input) + + input_specs = [ + Input( + dtype=torch.float32, + min_shape=min_shape, + opt_shape=opt_shape, + max_shape=max_shape, + ), + ] + self.run_test_with_dynamic_shape( + exp(), + input_specs, + ) + @parameterized.expand( [ ((10,), torch.int, 0, 5), diff --git a/tests/py/dynamo/conversion/test_expm1_aten.py b/tests/py/dynamo/conversion/test_expm1_aten.py index e695a27475..dd256bbc74 100644 --- a/tests/py/dynamo/conversion/test_expm1_aten.py +++ b/tests/py/dynamo/conversion/test_expm1_aten.py @@ -4,6 +4,7 @@ import torch.nn as nn from parameterized import parameterized from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input from .harness import DispatchTestCase @@ -28,6 +29,32 @@ def forward(self, input): inputs, ) + @parameterized.expand( + [ + ((1,), (3,), (5,)), + ((1, 20), (2, 20), (3, 20)), + ((2, 3, 4), (3, 4, 5), (4, 5, 6)), + ((2, 3, 4, 5), (3, 5, 5, 6), (4, 5, 6, 7)), + ] + ) + def test_expm1_float_dynamic_shape(self, min_shape, opt_shape, max_shape): + class expm1(nn.Module): + def forward(self, input): + return torch.ops.aten.expm1.default(input) + + input_specs = [ + Input( + dtype=torch.float32, + min_shape=min_shape, + opt_shape=opt_shape, + max_shape=max_shape, + ), + ] + self.run_test_with_dynamic_shape( + expm1(), + input_specs, + ) + @parameterized.expand( [ (torch.full((1, 20), exp(1), dtype=torch.float),), diff --git a/tests/py/dynamo/conversion/test_log10.py b/tests/py/dynamo/conversion/test_log10.py index 9094f6b278..68468da45c 100644 --- a/tests/py/dynamo/conversion/test_log10.py +++ b/tests/py/dynamo/conversion/test_log10.py @@ -2,6 +2,7 @@ import torch.nn as nn from parameterized import parameterized from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input from .harness import DispatchTestCase @@ -26,6 +27,32 @@ def forward(self, input): inputs, ) + @parameterized.expand( + [ + ((1,), (3,), (5,)), + ((1, 20), (2, 20), (3, 20)), + ((2, 3, 4), (3, 4, 5), (4, 5, 6)), + ((2, 3, 4, 5), (3, 5, 5, 6), (4, 5, 6, 7)), + ] + ) + def test_log10_float_dynamic_shape(self, min_shape, opt_shape, max_shape): + class log10(nn.Module): + def forward(self, input): + return torch.ops.aten.log10.default(input) + + input_specs = [ + Input( + dtype=torch.float32, + min_shape=min_shape, + opt_shape=opt_shape, + max_shape=max_shape, + ), + ] + self.run_test_with_dynamic_shape( + log10(), + input_specs, + ) + @parameterized.expand( [ ((10,), torch.int, 0, 5), diff --git a/tests/py/dynamo/conversion/test_log2.py b/tests/py/dynamo/conversion/test_log2.py index c641423da7..225f2938fd 100644 --- a/tests/py/dynamo/conversion/test_log2.py +++ b/tests/py/dynamo/conversion/test_log2.py @@ -2,6 +2,7 @@ import torch.nn as nn from parameterized import parameterized from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input from .harness import DispatchTestCase @@ -26,6 +27,32 @@ def forward(self, input): inputs, ) + @parameterized.expand( + [ + ((1,), (3,), (5,)), + ((1, 20), (2, 20), (3, 20)), + ((2, 3, 4), (3, 4, 5), (4, 5, 6)), + ((2, 3, 4, 5), (3, 5, 5, 6), (4, 5, 6, 7)), + ] + ) + def test_log_float_dynamic_shape(self, min_shape, opt_shape, max_shape): + class log2(nn.Module): + def forward(self, input): + return torch.ops.aten.log2.default(input) + + input_specs = [ + Input( + dtype=torch.float32, + min_shape=min_shape, + opt_shape=opt_shape, + max_shape=max_shape, + ), + ] + self.run_test_with_dynamic_shape( + log2(), + input_specs, + ) + @parameterized.expand( [ ((10,), torch.int, 0, 5), diff --git a/tests/py/dynamo/conversion/test_log_aten.py b/tests/py/dynamo/conversion/test_log_aten.py index 662b7ab99d..177e0707fe 100644 --- a/tests/py/dynamo/conversion/test_log_aten.py +++ b/tests/py/dynamo/conversion/test_log_aten.py @@ -2,6 +2,7 @@ import torch.nn as nn from parameterized import parameterized from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input from .harness import DispatchTestCase @@ -26,6 +27,32 @@ def forward(self, input): inputs, ) + @parameterized.expand( + [ + ((1,), (3,), (5,)), + ((1, 20), (2, 20), (3, 20)), + ((2, 3, 4), (3, 4, 5), (4, 5, 6)), + ((2, 3, 4, 5), (3, 5, 5, 6), (4, 5, 6, 7)), + ] + ) + def test_log_float_dynamic_shape(self, min_shape, opt_shape, max_shape): + class log(nn.Module): + def forward(self, input): + return torch.ops.aten.log.default(input) + + input_specs = [ + Input( + dtype=torch.float32, + min_shape=min_shape, + opt_shape=opt_shape, + max_shape=max_shape, + ), + ] + self.run_test_with_dynamic_shape( + log(), + input_specs, + ) + @parameterized.expand( [ ((10,), torch.int, 0, 5), diff --git a/tests/py/dynamo/conversion/test_recip_aten.py b/tests/py/dynamo/conversion/test_recip_aten.py index c34fcb2f08..1aed9d2182 100644 --- a/tests/py/dynamo/conversion/test_recip_aten.py +++ b/tests/py/dynamo/conversion/test_recip_aten.py @@ -2,6 +2,7 @@ import torch.nn as nn from parameterized import parameterized from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input from .harness import DispatchTestCase @@ -26,6 +27,32 @@ def forward(self, input): inputs, ) + @parameterized.expand( + [ + ((1,), (3,), (5,)), + ((1, 20), (2, 20), (3, 20)), + ((2, 3, 4), (3, 4, 5), (4, 5, 6)), + ((2, 3, 4, 5), (3, 5, 5, 6), (4, 5, 6, 7)), + ] + ) + def test_recip_float_dynamic_shape(self, min_shape, opt_shape, max_shape): + class recip(nn.Module): + def forward(self, input): + return torch.ops.aten.reciprocal.default(input) + + input_specs = [ + Input( + dtype=torch.float32, + min_shape=min_shape, + opt_shape=opt_shape, + max_shape=max_shape, + ), + ] + self.run_test_with_dynamic_shape( + recip(), + input_specs, + ) + @parameterized.expand( [ ((10,), torch.int, 0, 5),