From 65bbf459a0b9629eb02337eae5cfbb7e1fd3fc84 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Fri, 16 Feb 2024 16:17:21 -0800 Subject: [PATCH] add validator to keep static inputs --- .../dynamo/conversion/aten_ops_converters.py | 25 ++++++++++++++++--- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index d0f25c4624..8c3d5ce088 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -2184,10 +2184,27 @@ def aten_ops_avg_pool( ) -@dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool2d.default) -@dynamo_tensorrt_converter(torch.ops.aten._adaptive_avg_pool2d.default) -@dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool3d.default) -@dynamo_tensorrt_converter(torch.ops.aten._adaptive_avg_pool3d.default) +def adaptive_pool_static_input_validator(pool_node: Node) -> bool: + output_size = args_bounds_check(pool_node.args, 1) + return all([x > 0 for x in output_size]) + + +@dynamo_tensorrt_converter( + torch.ops.aten.adaptive_avg_pool2d.default, + capability_validator=adaptive_pool_static_input_validator, +) +@dynamo_tensorrt_converter( + torch.ops.aten._adaptive_avg_pool2d.default, + capability_validator=adaptive_pool_static_input_validator, +) +@dynamo_tensorrt_converter( + torch.ops.aten.adaptive_avg_pool3d.default, + capability_validator=adaptive_pool_static_input_validator, +) +@dynamo_tensorrt_converter( + torch.ops.aten._adaptive_avg_pool3d.default, + capability_validator=adaptive_pool_static_input_validator, +) def aten_ops_adaptive_avg_poolNd( ctx: ConversionContext, target: Target,