
feat: Add dynamic shapes support for torch.compile workflow (#2627)
Signed-off-by: Dheeraj Peri <[email protected]>
peri044 authored and zewenli98 committed Apr 26, 2024
1 parent d7e47be commit a5079ad
Showing 8 changed files with 156 additions and 132 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/build-test.yml
@@ -21,6 +21,7 @@ jobs:
os: linux
test-infra-repository: pytorch/test-infra
test-infra-ref: main
+ channel: test
with-rocm: false
with-cpu: false

@@ -208,6 +209,7 @@ jobs:
${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
${CONDA_RUN} python -m pytest -n 10 --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_compile_be_test_results.xml backend/
${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_comple_be_e2e_test_results.xml --ir torch_compile models/test_models.py
+ ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_compile_dyn_models_export.xml --ir torch_compile models/test_dyn_models.py
popd
tests-py-dynamo-core:
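The new pytest invocation above runs models/test_dyn_models.py through the torch_compile frontend. As orientation for what that workflow exercises, here is a minimal dynamic-shape sketch; the toy model, feature sizes, and batch sizes are illustrative assumptions rather than the contents of the actual test file, and it assumes a CUDA machine with TensorRT and torch_tensorrt installed.

# Illustrative only: a dynamic-shape run through the torch.compile + TensorRT path.
import torch
import torch_tensorrt  # noqa: F401  (importing registers the "torch_tensorrt" backend)

class ToyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(16, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(x))

model = ToyModel().eval().cuda()

# dynamic=True makes Dynamo trace with symbolic (SymInt) sizes,
# which is the workflow this commit adds support for.
compiled = torch.compile(model, backend="torch_tensorrt", dynamic=True)

# Varying the batch size should not force a recompile for every new shape.
for batch in (2, 4, 8):
    out = compiled(torch.randn(batch, 16, device="cuda"))
    print(batch, tuple(out.shape))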
12 changes: 5 additions & 7 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -299,14 +299,12 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
return False
return True

- # Check if the module has metadata (shape, dtype). If not, run symbolic shape propagation.
+ # Check if the module has metadata (shape, dtype).
if not contains_metadata(gm):
-     from torch._inductor.compile_fx import fake_tensor_prop
-
-     torch_inputs = get_torch_inputs(sample_inputs, settings.device)
-     with torch.no_grad():
-         # This fails if the module has data-dependent shape operators.
-         fake_tensor_prop(gm, torch_inputs)
+     # TODO: For future, explore when nodes don't have metadata and if fake_tensor_prop can resolve this.
+     logger.warning(
+         "Some nodes do not have metadata (shape and dtype information). This could lead to problems sometimes if the graph has PyTorch and TensorRT segments."
+     )

# Partition module into components that can be TRT-accelerated
fast_partitioner_failed = False
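For reference, the contains_metadata check in this hunk inspects per-node shape/dtype metadata; with this change, a graph missing such metadata only triggers a warning instead of running fake_tensor_prop. Below is a rough sketch of that kind of check; the real helper lives in py/torch_tensorrt/dynamo/_compiler.py and may key off different node.meta fields than assumed here.

# Sketch only: checking whether FX nodes carry shape/dtype metadata.
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp

def has_shape_metadata(gm: torch.fx.GraphModule) -> bool:
    """True if every op node carries 'val' or 'tensor_meta' in node.meta."""
    for node in gm.graph.nodes:
        if node.op in ("call_function", "call_method", "call_module"):
            if "val" not in node.meta and "tensor_meta" not in node.meta:
                return False
    return True

gm = symbolic_trace(lambda x: torch.relu(x) + 1.0)
print(has_shape_metadata(gm))   # False: plain symbolic_trace records no shapes

ShapeProp(gm).propagate(torch.randn(4))
print(has_shape_metadata(gm))   # True: ShapeProp fills node.meta["tensor_meta"]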
20 changes: 14 additions & 6 deletions py/torch_tensorrt/dynamo/backend/backends.py
@@ -13,6 +13,7 @@
from torch_tensorrt.dynamo.lowering import (
apply_lowering_passes,
get_decompositions,
+ remove_sym_nodes,
repair_input_aliasing,
)
from torch_tensorrt.dynamo.utils import (
@@ -27,7 +28,7 @@
@td.register_backend(name="tensorrt") # type: ignore[misc]
@td.register_backend(name="torch_tensorrt") # type: ignore[misc]
def torch_tensorrt_backend(
- gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs: Any
+ gm: torch.fx.GraphModule, sample_inputs: Sequence[Any], **kwargs: Any
) -> torch.nn.Module:
# Set log level at the top of compilation (torch_tensorrt.dynamo)
if (
@@ -44,15 +45,15 @@ def torch_tensorrt_backend(

@td.register_backend(name="aot_torch_tensorrt_aten") # type: ignore[misc]
def aot_torch_tensorrt_aten_backend(
- gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs: Any
+ gm: torch.fx.GraphModule, sample_inputs: Sequence[Any], **kwargs: Any
) -> torch.nn.Module:
settings = parse_dynamo_kwargs(kwargs)
return _pretraced_backend(gm, sample_inputs, settings)


def _pretraced_backend(
gm: torch.fx.GraphModule,
- sample_inputs: Sequence[torch.Tensor],
+ sample_inputs: Sequence[Any],
settings: CompilationSettings = CompilationSettings(),
) -> torch.fx.GraphModule | Callable[..., Any]:
"""Helper function to manage translation of traced FX module to TRT engines
@@ -74,10 +75,17 @@ def _pretraced_backend(
fake_mode, "allow_non_fake_inputs", True
), fake_mode:
repair_input_aliasing(gm)

+ # Remove sym_int placeholders and inputs
+ remove_sym_nodes(gm)
+ torch_inputs = [
+     input for input in sample_inputs if isinstance(input, torch.Tensor)
+ ]

# Invoke AOTAutograd to translate operators to aten
gm = aot_export_joint_simple(
gm,
- sample_inputs,
+ torch_inputs,
trace_joint=False,
decompositions=get_decompositions(
settings.enable_experimental_decompositions
@@ -86,10 +94,10 @@

logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))

- gm = apply_lowering_passes(gm, sample_inputs)
+ gm = apply_lowering_passes(gm, torch_inputs)

torchtrt_inputs = prepare_inputs(
- sample_inputs, disable_memory_format_check=True
+ torch_inputs, disable_memory_format_check=True
)
trt_compiled = compile_module(
gm,
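Taken together, these backend changes handle the fact that with torch.compile(..., dynamic=True) the backend receives SymInt values mixed in with tensors, both as graph placeholders and in sample_inputs; the backend now strips the SymInt placeholders and forwards only the tensor inputs to AOTAutograd, the lowering passes, and prepare_inputs. A small probe backend (illustrative, not part of this change) shows what typically arrives at a Dynamo backend in that mode:

# Illustrative probe: inspect what a Dynamo backend receives under dynamic=True.
import torch

def probe_backend(gm: torch.fx.GraphModule, example_inputs):
    # With dynamic shapes enabled, example_inputs usually mixes SymInt values
    # with tensors, and the graph carries SymInt-typed placeholder nodes.
    print([type(i).__name__ for i in example_inputs])
    print([(n.name, n.type) for n in gm.graph.nodes if n.op == "placeholder"])
    return gm  # run the captured graph unchanged

fn = torch.compile(lambda x: x + 1.0, backend=probe_backend, dynamic=True)
fn(torch.randn(8))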
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/shape.py
@@ -11,7 +11,6 @@
from torch_tensorrt.dynamo.conversion.converter_utils import (
get_positive_dim,
get_trt_tensor,
- to_numpy,
)
from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
convert_binary_elementwise,
@@ -98,6 +97,7 @@ def get_shape_with_dynamic_shape(
scale_res = scale_layer.get_output(0)

length = input_shape.shape[0]

zero_layer = ctx.net.add_constant(
input_shape.shape, np.zeros((length), dtype=np.int32)
)
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/lowering/__init__.py
@@ -3,6 +3,6 @@
torch_enabled_decompositions,
)
from ._decompositions import get_decompositions # noqa: F401
- from ._fusers import * # noqa: F401
+ from ._remove_sym_nodes import remove_sym_nodes
from ._repair_input_aliasing import repair_input_aliasing
from .passes import apply_lowering_passes
82 changes: 0 additions & 82 deletions py/torch_tensorrt/dynamo/lowering/_fusers.py

This file was deleted.

30 changes: 30 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/_remove_sym_nodes.py
@@ -0,0 +1,30 @@
import logging

import torch

logger = logging.getLogger(__name__)


def remove_sym_nodes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
"""Remove sym_int placeholders which get inserted due to torch.compile's
dynamic=True behavior
"""
# Extract SymInt placeholder Tensors
placeholders = [
node
for node in gm.graph.nodes
if (
node.op == "placeholder"
and isinstance(node.type, type)
and issubclass(node.type, torch.SymInt)
)
]

for node in placeholders:
gm.graph.erase_node(node)

gm.graph.lint()
gm.recompile()
logger.debug(f"Removed SymInt placeholders:\n{gm.graph}")

return gm
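A small self-contained check of the new pass: the hand-built graph below is an illustrative stand-in for what Dynamo produces with dynamic=True, where symbolic sizes such as s0 appear as SymInt-typed placeholders alongside the tensor inputs.

# Illustrative: exercise remove_sym_nodes on a hand-built graph.
import torch
from torch.fx import Graph, GraphModule
from torch_tensorrt.dynamo.lowering import remove_sym_nodes

graph = Graph()
graph.placeholder("s0", type_expr=torch.SymInt)   # stand-in for a symbolic size input
x = graph.placeholder("x", type_expr=torch.Tensor)
graph.output(x)
gm = GraphModule(torch.nn.Module(), graph)

remove_sym_nodes(gm)
print([n.name for n in gm.graph.nodes if n.op == "placeholder"])  # ['x']
print(gm(torch.randn(2)).shape)  # forward now takes only the tensor input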
