diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index 88e66b0f3c..a89d7bbd2c 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -765,7 +765,9 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
                 min_block_size=settings.min_block_size,
                 torch_executed_ops=settings.torch_executed_ops,
                 require_full_compilation=settings.require_full_compilation,
+                skip_fusion=(num_supported_ops == total_ops),
             )
+
         except torch.fx.passes.splitter_base.FxNetSplitterInternalError:
             logger.error(
                 "Partitioning failed on the subgraph with fast partition. See trace above. "
diff --git a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py
index 0e9077cdcb..429de3ffbb 100644
--- a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py
+++ b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py
@@ -111,6 +111,7 @@ def __init__(
         min_block_size: int = MIN_BLOCK_SIZE,
         require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
         return_tuple: bool = False,
+        skip_fusion: bool = False,
     ):
         """
         Preprocesses graph before splitting:
@@ -127,6 +128,7 @@
         self.settings = _SplitterSettingBase(
             min_acc_module_size=min_block_size,
             allow_non_tensor=True,
+            skip_fusion=skip_fusion,
         )
         self.operator_support = operator_support
 
@@ -252,6 +254,7 @@ def partition(
     min_block_size: int = MIN_BLOCK_SIZE,
     torch_executed_ops: Collection[Target] = set(),
     require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
+    skip_fusion: bool = False,
 ) -> Tuple[torch.fx.GraphModule, OpSupportTester]:
     """Partition an FX GraphModule with aten ops into TRT engines
     Partitioning is based on converter operator support
@@ -262,6 +265,7 @@
         min_block_size: Minimum number of operators per TRT-Engine Block
         torch_executed_ops: Collection of operations to run in Torch, regardless of converter coverage
         require_full_compilation: Require that all computational operators be run in TRT
+        skip_fusion: Skip fusions found by FxNetAccFusionsFinder
     Returns:
         torch.fx.GraphModule, OpSupportTester
     """
@@ -277,6 +281,7 @@
         supported_ops,
         min_block_size=min_block_size,
         require_full_compilation=require_full_compilation,
+        skip_fusion=skip_fusion,
     )
 
     partitioned_graph = partitioner.partition_graph()
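
For context, the change threads a new `skip_fusion` flag from the fast-partition call site in `_compiler.py` down to `_SplitterSettingBase`, so the `FxNetAccFusionsFinder` pass can be skipped when every operator in the graph is supported by a converter. Below is a minimal sketch of the call-site logic this diff introduces; the wrapper function and its parameters are illustrative assumptions, and only the `fast_partition(...)` keyword arguments mirror the diff itself.

```python
# Hypothetical sketch of the call-site behavior added in _compiler.py.
# try_fast_partition and its parameters are illustrative; only the
# fast_partition(...) keyword arguments mirror the actual diff.
from torch_tensorrt.dynamo import partitioning


def try_fast_partition(gm, settings, num_supported_ops, total_ops):
    # Fusion groups only need to be tracked when some ops will fall back
    # to Torch; if every op is convertible, the FxNetAccFusionsFinder
    # pass can be skipped, which is what skip_fusion toggles.
    partitioned_module, supported_ops = partitioning.fast_partition(
        gm,
        min_block_size=settings.min_block_size,
        torch_executed_ops=settings.torch_executed_ops,
        require_full_compilation=settings.require_full_compilation,
        skip_fusion=(num_supported_ops == total_ops),
    )
    return partitioned_module, supported_ops
```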