From a5079ad5d7ebf71b6e2897ce3b758c78bd6f004c Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Mon, 15 Apr 2024 19:50:05 -0700
Subject: [PATCH] feat: Add dynamic shapes support for torch.compile workflow
 (#2627)

Signed-off-by: Dheeraj Peri
---
 .github/workflows/build-test.yml              |   2 +
 py/torch_tensorrt/dynamo/_compiler.py         |  12 +-
 py/torch_tensorrt/dynamo/backend/backends.py  |  20 ++-
 .../dynamo/conversion/impl/shape.py           |   2 +-
 py/torch_tensorrt/dynamo/lowering/__init__.py |   2 +-
 py/torch_tensorrt/dynamo/lowering/_fusers.py  |  82 -----------
 .../dynamo/lowering/_remove_sym_nodes.py      |  30 ++++
 tests/py/dynamo/models/test_dyn_models.py     | 138 +++++++++++++-----
 8 files changed, 156 insertions(+), 132 deletions(-)
 delete mode 100644 py/torch_tensorrt/dynamo/lowering/_fusers.py
 create mode 100644 py/torch_tensorrt/dynamo/lowering/_remove_sym_nodes.py

diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml
index c19b4b4b2e..f4d39bd056 100644
--- a/.github/workflows/build-test.yml
+++ b/.github/workflows/build-test.yml
@@ -21,6 +21,7 @@ jobs:
       os: linux
       test-infra-repository: pytorch/test-infra
       test-infra-ref: main
+      channel: test
       with-rocm: false
       with-cpu: false
 
@@ -208,6 +209,7 @@ jobs:
         ${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
         ${CONDA_RUN} python -m pytest -n 10 --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_compile_be_test_results.xml backend/
         ${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_comple_be_e2e_test_results.xml --ir torch_compile models/test_models.py
+        ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_compile_dyn_models_export.xml --ir torch_compile models/test_dyn_models.py
         popd
 
   tests-py-dynamo-core:
diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index 41fab3293e..7acebf6c6f 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -299,14 +299,12 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
                 return False
         return True
 
-    # Check if the module has metadata (shape, dtype). If not, run symbolic shape propagation.
+    # Check if the module has metadata (shape, dtype).
    if not contains_metadata(gm):
-        from torch._inductor.compile_fx import fake_tensor_prop
-
-        torch_inputs = get_torch_inputs(sample_inputs, settings.device)
-        with torch.no_grad():
-            # This fails if the module has data-dependent shape operators.
-            fake_tensor_prop(gm, torch_inputs)
+        # TODO: Explore cases where nodes lack metadata and whether fake_tensor_prop can resolve them.
+        logger.warning(
+            "Some nodes do not have metadata (shape and dtype information). This may cause issues if the graph contains both PyTorch and TensorRT segments."
+        )
 
     # Partition module into components that can be TRT-accelerated
     fast_partitioner_failed = False
diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py
index bade91c553..66a9729cc0 100644
--- a/py/torch_tensorrt/dynamo/backend/backends.py
+++ b/py/torch_tensorrt/dynamo/backend/backends.py
@@ -13,6 +13,7 @@
 from torch_tensorrt.dynamo.lowering import (
     apply_lowering_passes,
     get_decompositions,
+    remove_sym_nodes,
     repair_input_aliasing,
 )
 from torch_tensorrt.dynamo.utils import (
@@ -27,7 +28,7 @@
 @td.register_backend(name="tensorrt")  # type: ignore[misc]
 @td.register_backend(name="torch_tensorrt")  # type: ignore[misc]
 def torch_tensorrt_backend(
-    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs: Any
+    gm: torch.fx.GraphModule, sample_inputs: Sequence[Any], **kwargs: Any
 ) -> torch.nn.Module:
     # Set log level at the top of compilation (torch_tensorrt.dynamo)
     if (
@@ -44,7 +45,7 @@ def torch_tensorrt_backend(
 
 @td.register_backend(name="aot_torch_tensorrt_aten")  # type: ignore[misc]
 def aot_torch_tensorrt_aten_backend(
-    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs: Any
+    gm: torch.fx.GraphModule, sample_inputs: Sequence[Any], **kwargs: Any
 ) -> torch.nn.Module:
     settings = parse_dynamo_kwargs(kwargs)
     return _pretraced_backend(gm, sample_inputs, settings)
@@ -52,7 +53,7 @@ def aot_torch_tensorrt_aten_backend(
 
 def _pretraced_backend(
     gm: torch.fx.GraphModule,
-    sample_inputs: Sequence[torch.Tensor],
+    sample_inputs: Sequence[Any],
     settings: CompilationSettings = CompilationSettings(),
 ) -> torch.fx.GraphModule | Callable[..., Any]:
     """Helper function to manage translation of traced FX module to TRT engines
@@ -74,10 +75,17 @@ def _pretraced_backend(
             fake_mode, "allow_non_fake_inputs", True
         ), fake_mode:
             repair_input_aliasing(gm)
+
+            # Remove SymInt placeholders and their corresponding inputs
+            remove_sym_nodes(gm)
+            torch_inputs = [
+                input for input in sample_inputs if isinstance(input, torch.Tensor)
+            ]
+
             # Invoke AOTAutograd to translate operators to aten
             gm = aot_export_joint_simple(
                 gm,
-                sample_inputs,
+                torch_inputs,
                 trace_joint=False,
                 decompositions=get_decompositions(
                     settings.enable_experimental_decompositions
@@ -86,10 +94,10 @@ def _pretraced_backend(
 
             logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))
 
-            gm = apply_lowering_passes(gm, sample_inputs)
+            gm = apply_lowering_passes(gm, torch_inputs)
 
             torchtrt_inputs = prepare_inputs(
-                sample_inputs, disable_memory_format_check=True
+                torch_inputs, disable_memory_format_check=True
             )
             trt_compiled = compile_module(
                 gm,
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/shape.py b/py/torch_tensorrt/dynamo/conversion/impl/shape.py
index eea8f76411..61c1fb99d7 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/shape.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/shape.py
@@ -11,7 +11,6 @@
 from torch_tensorrt.dynamo.conversion.converter_utils import (
     get_positive_dim,
     get_trt_tensor,
-    to_numpy,
 )
 from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
     convert_binary_elementwise,
@@ -98,6 +97,7 @@ def get_shape_with_dynamic_shape(
     scale_res = scale_layer.get_output(0)
 
     length = input_shape.shape[0]
+
     zero_layer = ctx.net.add_constant(
         input_shape.shape, np.zeros((length), dtype=np.int32)
     )
diff --git a/py/torch_tensorrt/dynamo/lowering/__init__.py b/py/torch_tensorrt/dynamo/lowering/__init__.py
index 7c4e9fdd2d..a89780ded4 100644
--- a/py/torch_tensorrt/dynamo/lowering/__init__.py
+++ b/py/torch_tensorrt/dynamo/lowering/__init__.py
@@ -3,6 +3,6 @@
     torch_enabled_decompositions,
 )
 from ._decompositions import get_decompositions  # noqa: F401
-from ._fusers import *  # noqa: F401
+from ._remove_sym_nodes import remove_sym_nodes
 from ._repair_input_aliasing import repair_input_aliasing
 from .passes import apply_lowering_passes
diff --git a/py/torch_tensorrt/dynamo/lowering/_fusers.py b/py/torch_tensorrt/dynamo/lowering/_fusers.py
deleted file mode 100644
index 720e4ab030..0000000000
--- a/py/torch_tensorrt/dynamo/lowering/_fusers.py
+++ /dev/null
@@ -1,82 +0,0 @@
-import torch
-from torch_tensorrt.fx.tracer.acc_tracer import acc_ops
-
-
-def check_permute(node: torch.fx.Node) -> bool:
-    ranks = len(node.meta["tensor_meta"].shape)
-    permutation = [i % ranks for i in node.kwargs["permutation"]]
-    allowed_permutation = list(range(ranks))
-    allowed_permutation[-1] = ranks - 2
-    allowed_permutation[-2] = ranks - 1
-    return permutation == allowed_permutation
-
-
-def trt_transposed_matmul(
-    lhs: torch.Tensor, rhs: torch.Tensor, lhs_transposed: bool, rhs_transposed: bool
-) -> torch.Tensor:
-    if lhs_transposed:
-        lhs = lhs.transpose(-1, -2)
-    if rhs_transposed:
-        rhs = rhs.transpose(-1, -2)
-    return torch.matmul(lhs, rhs)
-
-
-def fuse_permute_matmul(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
-    """
-    Fuse pattern like permute + matmul if permute is transposing the last two dimension.
-    """
-    for node in gm.graph.nodes:
-        if node.target == acc_ops.matmul:
-            lhs, rhs = node.kwargs["input"], node.kwargs["other"]
-            lhs_transposed = rhs_tranposed = False
-            skip = False
-
-            if lhs.target == acc_ops.permute and check_permute(lhs):
-                lhs_transposed = True
-                lhs = lhs.kwargs["input"]
-
-            if rhs.target == acc_ops.permute and check_permute(rhs):
-                rhs_tranposed = True
-                rhs = rhs.kwargs["input"]
-
-            if (not skip) and (lhs_transposed or rhs_tranposed):
-                with gm.graph.inserting_before(node):
-                    fused_node = gm.graph.call_function(
-                        trt_transposed_matmul,
-                        args=(lhs, rhs, lhs_transposed, rhs_tranposed),
-                    )
-                node.replace_all_uses_with(fused_node)
-
-    gm.graph.eliminate_dead_code()
-    gm.graph.lint()
-    gm.recompile()
-    return gm
-
-
-def trt_transposed_linear(
-    input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
-) -> torch.Tensor:
-    return torch.matmul(input.transpose(-1, -2), weight.t()) + bias
-
-
-def fuse_permute_linear(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
-    """
-    Fuse pattern like permute + linear if permute is transposing the last two dimension.
-    """
-    for node in gm.graph.nodes:
-        if node.target == acc_ops.linear:
-            inp = node.kwargs["input"]
-            if inp.target == acc_ops.permute and check_permute(inp):
-                inp = inp.kwargs["input"]
-                weight = node.kwargs["weight"]
-                bias = node.kwargs["bias"]
-                with gm.graph.inserting_before(node):
-                    fused_node = gm.graph.call_function(
-                        trt_transposed_linear, args=(inp, weight, bias)
-                    )
-                node.replace_all_uses_with(fused_node)
-
-    gm.graph.eliminate_dead_code()
-    gm.graph.lint()
-    gm.recompile()
-    return gm
diff --git a/py/torch_tensorrt/dynamo/lowering/_remove_sym_nodes.py b/py/torch_tensorrt/dynamo/lowering/_remove_sym_nodes.py
new file mode 100644
index 0000000000..e85117a423
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/lowering/_remove_sym_nodes.py
@@ -0,0 +1,30 @@
+import logging
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+def remove_sym_nodes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    """Remove SymInt placeholders, which are inserted by torch.compile's
+    dynamic=True behavior
+    """
+    # Collect SymInt placeholder nodes
+    placeholders = [
+        node
+        for node in gm.graph.nodes
+        if (
+            node.op == "placeholder"
+            and isinstance(node.type, type)
+            and issubclass(node.type, torch.SymInt)
+        )
+    ]
+
+    for node in placeholders:
+        gm.graph.erase_node(node)
+
+    gm.graph.lint()
+    gm.recompile()
+    logger.debug(f"Removed SymInt placeholders:\n{gm.graph}")
+
+    return gm
diff --git a/tests/py/dynamo/models/test_dyn_models.py b/tests/py/dynamo/models/test_dyn_models.py
index 822ee468a9..5a33285fdb 100644
--- a/tests/py/dynamo/models/test_dyn_models.py
+++ b/tests/py/dynamo/models/test_dyn_models.py
@@ -27,10 +27,22 @@ def forward(self, x):
             return out
 
     model = MyModule().eval().cuda()
-    input = torch.randn((1, 3, 224, 224)).to("cuda")
 
     compile_spec = {
-        "inputs": [
+        "device": torchtrt.Device("cuda:0"),
+        "enabled_precisions": {torch.float},
+        "ir": ir,
+        "pass_through_build_failures": True,
+        "min_block_size": 1,
+    }
+    if ir == "torch_compile":
+        input_bs4 = torch.randn((4, 3, 224, 224)).to("cuda")
+        torch._dynamo.mark_dynamic(input_bs4, 0, min=1, max=8)
+        # Compile the model
+        trt_model = torch.compile(model, backend="tensorrt", options=compile_spec)
+        trt_model(input_bs4)
+    elif ir == "dynamo":
+        compile_spec["inputs"] = [
             torchtrt.Input(
                 min_shape=(1, 3, 224, 224),
                 opt_shape=(4, 3, 224, 224),
@@ -38,22 +50,15 @@ def forward(self, x):
                 dtype=torch.float32,
                 name="x",
             )
-        ],
-        "device": torchtrt.Device("cuda:0"),
-        "enabled_precisions": {torch.float},
-        "ir": ir,
-        "pass_through_build_failures": True,
-        "optimization_level": 1,
-        "min_block_size": 1,
-    }
+        ]
+        trt_model = torchtrt.compile(model, **compile_spec)
 
-    trt_mod = torchtrt.compile(model, **compile_spec)
-    cos_sim = cosine_similarity(model(input), trt_mod(input)[0])
+    input_bs6 = torch.randn((6, 3, 224, 224)).to("cuda")
+    cos_sim = cosine_similarity(model(input_bs6), trt_model(input_bs6))
     assertions.assertTrue(
         cos_sim > COSINE_THRESHOLD,
-        msg=f"test_base_dynamic model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+        msg=f"test_dyn_full_compile model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
     )
 
-
     # Clean up model env
     torch._dynamo.reset()
@@ -80,32 +85,39 @@ def forward(self, x):
             return out
 
     model = MyModule().eval().cuda()
-    input = torch.randn((1, 3, 224, 224)).to("cuda")
 
     compile_spec = {
-        "inputs": [
-            torchtrt.Input(
-                min_shape=(1, 3, 224, 224),
-                opt_shape=(4, 3, 224, 224),
-                max_shape=(8, 3, 224, 224),
-                dtype=torch.float32,
-                name="x",
-            )
-        ],
         "device": torchtrt.Device("cuda:0"),
         "enabled_precisions": {torch.float},
         "ir": ir,
         "pass_through_build_failures": True,
-        "optimization_level": 1,
         "torch_executed_ops": {"torch.ops.aten.abs.default"},
         "min_block_size": 1,
     }
 
-    trt_mod = torchtrt.compile(model, **compile_spec)
-    cos_sim = cosine_similarity(model(input), trt_mod(input)[0])
+    if ir == "torch_compile":
+        input_bs4 = torch.randn((4, 3, 224, 224)).to("cuda")
+        torch._dynamo.mark_dynamic(input_bs4, 0, min=1, max=8)
+        # Compile the model
+        trt_model = torch.compile(model, backend="tensorrt", options=compile_spec)
+        trt_model(input_bs4)
+    elif ir == "dynamo":
+        compile_spec["inputs"] = [
+            torchtrt.Input(
+                min_shape=(1, 3, 224, 224),
+                opt_shape=(4, 3, 224, 224),
+                max_shape=(8, 3, 224, 224),
+                dtype=torch.float32,
+                name="x",
+            )
+        ]
+        trt_model = torchtrt.compile(model, **compile_spec)
+
+    input_bs6 = torch.randn((6, 3, 224, 224)).to("cuda")
+    cos_sim = cosine_similarity(model(input_bs6), trt_model(input_bs6))
     assertions.assertTrue(
         cos_sim > COSINE_THRESHOLD,
-        msg=f"test_base_dynamic model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+        msg=f"test_base_dynamic_fallback model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
     )
 
     # Clean up model env
@@ -131,10 +143,23 @@ def forward(self, x):
             return y
 
     model = MyModule().eval().cuda()
-    input = torch.randn((6, 3, 4)).to("cuda")
 
     compile_spec = {
-        "inputs": [
+        "device": torchtrt.Device("cuda:0"),
+        "enabled_precisions": {torch.float},
+        "ir": ir,
+        "pass_through_build_failures": True,
+        "min_block_size": 1,
+    }
+
+    if ir == "torch_compile":
+        input_bs4 = torch.randn((4, 3, 4)).to("cuda")
+        torch._dynamo.mark_dynamic(input_bs4, 0, min=1, max=8)
+        # Compile the model
+        trt_model = torch.compile(model, backend="tensorrt", options=compile_spec)
+        trt_model(input_bs4)
+    elif ir == "dynamo":
+        compile_spec["inputs"] = [
             torchtrt.Input(
                 min_shape=(1, 3, 4),
                 opt_shape=(4, 3, 4),
@@ -142,20 +167,63 @@ def forward(self, x):
                 dtype=torch.float32,
                 name="x",
             )
-        ],
+        ]
+        trt_model = torchtrt.compile(model, **compile_spec)
+
+    input_bs6 = torch.randn((6, 3, 4)).to("cuda")
+    cos_sim = cosine_similarity(model(input_bs6), trt_model(input_bs6))
+    assertions.assertTrue(
+        cos_sim > COSINE_THRESHOLD,
+        msg=f"test_view model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+    )
+
+    # Clean up model env
+    torch._dynamo.reset()
+
+    with torch.no_grad():
+        torch.cuda.empty_cache()
+
+
+@pytest.mark.unit
+def test_resnet_dynamic(ir):
+    """
+    Tests the Resnet18 model (which is fully convertible) with dynamic shapes
+    """
+    import torchvision.models as models
+
+    model = models.resnet18(pretrained=True).eval().to("cuda")
+
+    compile_spec = {
         "device": torchtrt.Device("cuda:0"),
         "enabled_precisions": {torch.float},
         "ir": ir,
         "pass_through_build_failures": True,
-        "optimization_level": 1,
         "min_block_size": 1,
     }
 
-    trt_mod = torchtrt.compile(model, **compile_spec)
-    cos_sim = cosine_similarity(model(input), trt_mod(input))
+    if ir == "torch_compile":
+        input_bs2 = torch.randn((2, 3, 224, 224)).to("cuda")
+        torch._dynamo.mark_dynamic(input_bs2, 0, min=1, max=8)
+        # Compile the model
+        trt_model = torch.compile(model, backend="tensorrt", options=compile_spec)
+        trt_model(input_bs2)
+    elif ir == "dynamo":
+        compile_spec["inputs"] = [
+            torchtrt.Input(
+                min_shape=(1, 3, 224, 224),
+                opt_shape=(4, 3, 224, 224),
+                max_shape=(8, 3, 224, 224),
+                dtype=torch.float32,
+                name="x",
+            )
+        ]
+        trt_model = torchtrt.compile(model, **compile_spec)
+
+    input_bs6 = torch.randn((6, 3, 224, 224)).to("cuda")
+    cos_sim = cosine_similarity(model(input_bs6), trt_model(input_bs6))
     assertions.assertTrue(
         cos_sim > COSINE_THRESHOLD,
-        msg=f"test_base_dynamic model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+        msg=f"test_resnet_dynamic model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
     )
 
     # Clean up model env
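
For readers who want to try the new path end to end, here is a minimal sketch of the torch.compile dynamic-shape workflow this patch enables, mirroring the torch_compile branch of the new tests. The toy module, shapes, and option values below are illustrative, not part of the patch:

```python
import torch
import torch_tensorrt  # noqa: F401  # importing registers the "tensorrt" backend

# Illustrative toy model; the tests use custom modules and ResNet18 instead.
model = (
    torch.nn.Sequential(torch.nn.Conv2d(3, 16, 3, padding=1), torch.nn.ReLU())
    .eval()
    .cuda()
)

# Mark the batch dimension as dynamic over [1, 8], as the tests do.
input_bs4 = torch.randn((4, 3, 224, 224)).cuda()
torch._dynamo.mark_dynamic(input_bs4, 0, min=1, max=8)

trt_model = torch.compile(
    model,
    backend="tensorrt",
    options={"min_block_size": 1, "pass_through_build_failures": True},
)
trt_model(input_bs4)  # first call compiles with a dynamic batch dimension

# Other batch sizes within [1, 8] should then run on the same compiled module.
out = trt_model(torch.randn((6, 3, 224, 224)).cuda())
```

The ir="dynamo" path is unchanged in spirit: it still takes explicit torchtrt.Input(min_shape=..., opt_shape=..., max_shape=...) specs, as the elif ir == "dynamo" branches of the tests show.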
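The new lowering pass is easiest to see on a toy graph. Below is a small sketch of what remove_sym_nodes does; the hand-built torch.fx graph stands in for the extra symbolic-size input that torch.compile threads through when shapes are dynamic (in the real flow the SymInt placeholder comes from the tracer, not from user code):

```python
import torch
from torch_tensorrt.dynamo.lowering import remove_sym_nodes

# Build a graph with one SymInt placeholder and one Tensor placeholder.
graph = torch.fx.Graph()
s0 = graph.placeholder("s0", type_expr=torch.SymInt)  # symbolic batch size
x = graph.placeholder("x", type_expr=torch.Tensor)
graph.output(x)
gm = torch.fx.GraphModule(torch.nn.Module(), graph)

print([n.name for n in gm.graph.nodes if n.op == "placeholder"])  # ['s0', 'x']
remove_sym_nodes(gm)
print([n.name for n in gm.graph.nodes if n.op == "placeholder"])  # ['x']
```

This matches the filtering added to _pretraced_backend, which likewise drops non-Tensor entries from sample_inputs so the graph and its inputs stay aligned before aot_export_joint_simple.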