Skip to content

Commit

Permalink
add validator to keep static inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
zewenli98 committed Feb 17, 2024
1 parent 8c41f14 commit 65bbf45
Showing 1 changed file with 21 additions and 4 deletions.
25 changes: 21 additions & 4 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2184,10 +2184,27 @@ def aten_ops_avg_pool(
)


@dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool2d.default)
@dynamo_tensorrt_converter(torch.ops.aten._adaptive_avg_pool2d.default)
@dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool3d.default)
@dynamo_tensorrt_converter(torch.ops.aten._adaptive_avg_pool3d.default)
def adaptive_pool_static_input_validator(pool_node: Node) -> bool:
output_size = args_bounds_check(pool_node.args, 1)
return all([x > 0 for x in output_size])


@dynamo_tensorrt_converter(
torch.ops.aten.adaptive_avg_pool2d.default,
capability_validator=adaptive_pool_static_input_validator,
)
@dynamo_tensorrt_converter(
torch.ops.aten._adaptive_avg_pool2d.default,
capability_validator=adaptive_pool_static_input_validator,
)
@dynamo_tensorrt_converter(
torch.ops.aten.adaptive_avg_pool3d.default,
capability_validator=adaptive_pool_static_input_validator,
)
@dynamo_tensorrt_converter(
torch.ops.aten._adaptive_avg_pool3d.default,
capability_validator=adaptive_pool_static_input_validator,
)
def aten_ops_adaptive_avg_poolNd(
ctx: ConversionContext,
target: Target,
Expand Down

0 comments on commit 65bbf45

Please sign in to comment.