
feat: Add dynamic shapes support for torch.compile workflow #2627
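
This PR adds dynamic-shape support to the torch.compile workflow. As a rough sketch of the kind of usage it targets (the model, shapes, and min/max bounds below are illustrative and not taken from this PR; assumes a recent PyTorch where mark_dynamic accepts min/max):

import torch
import torch_tensorrt  # importing registers the "torch_tensorrt" dynamo backend

model = MyModel().eval().cuda()  # hypothetical user model
x = torch.randn(8, 3, 224, 224).cuda()

# Mark the batch dimension as dynamic so torch.compile traces a symbolic size
# instead of specializing on 8.
torch._dynamo.mark_dynamic(x, 0, min=1, max=32)

# Compile through the Torch-TensorRT backend; later calls with other batch sizes
# in range should reuse the same engine rather than triggering a recompile.
optimized = torch.compile(model, backend="torch_tensorrt")
out = optimized(x)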

Merged: 63 commits merged on Apr 16, 2024
Changes from 25 commits
Commits (63)
d06c74a
chore: Switch to new export apis
peri044 Oct 3, 2023
47e0997
chore: rebase with main
peri044 Oct 9, 2023
fd29fe0
Merge branch 'main' into export_2.2
peri044 Oct 11, 2023
ad3b031
feat: Add support for dynamic shapes and remove constraints API
peri044 Oct 19, 2023
1582b72
chore: add dynamic shape support for certain converters
peri044 Oct 23, 2023
4d01545
chore: minor updates
peri044 Oct 25, 2023
6731a57
chore: updates
peri044 Oct 26, 2023
a8a194b
chore: rebase with main
peri044 Nov 15, 2023
0b60aae
chore: add sym int converter
peri044 Nov 15, 2023
634612f
feat: Replace the existing shape propagation with symbolic shape prop…
peri044 Nov 16, 2023
93edba4
chore: fix imports
peri044 Nov 16, 2023
7ad9272
chore: fix imports
peri044 Nov 16, 2023
f444d54
chore: updates
peri044 Nov 21, 2023
6e5c582
chore: change device calls
peri044 Nov 28, 2023
83791f8
chore: fix metadata check
peri044 Dec 5, 2023
8375996
chore: rebase with main
peri044 Dec 15, 2023
aba91fa
Merge branch 'main' into dyn_2.2
peri044 Dec 22, 2023
16394d9
chore: minor fixes
peri044 Jan 7, 2024
b9a7ccd
chore: Add sym_size converter tests
peri044 Jan 8, 2024
15cc643
chore: Update test utilities
peri044 Jan 8, 2024
5234d74
chore: add testcase for sym_size.int
peri044 Jan 8, 2024
fd2dae1
Merge branch 'main' into dyn_2.2
peri044 Jan 26, 2024
51e8bb7
chore: revert output type change
peri044 Jan 26, 2024
19c3fad
chore: add update_metadata utility
peri044 Jan 27, 2024
ed48551
chore: change debug to warning if the graph does not have metadata
peri044 Jan 27, 2024
18b7e11
feat: add lowering passes to support dynamic shapes for torch.compile
peri044 Jan 30, 2024
3a39d27
chore: add test case
peri044 Jan 30, 2024
abb2677
chore: add view test case
peri044 Feb 2, 2024
9aff04b
chore: gpt2 changes + linting
peri044 Feb 7, 2024
440fcd5
chore: gpt2 changes + linting
peri044 Feb 7, 2024
a2d38f3
chore: rebase with main
peri044 Feb 7, 2024
002db3c
chore: add fallback option if val is missing in metadata
peri044 Feb 7, 2024
00cd17b
chore: tmp changes
peri044 Feb 13, 2024
6ac70cd
chore: tmp changes
peri044 Feb 13, 2024
b827070
Merge branch 'main' into dyn_2.2
peri044 Feb 16, 2024
8f9bca0
Merge branch 'main' into dyn_2.2
peri044 Feb 21, 2024
4399d57
Merge branch 'dyn_2.2' into dyn_2.2_tc
peri044 Feb 21, 2024
39615a2
chore: fixes
peri044 Feb 26, 2024
cd86660
feat: Add save API for torch-trt compiled models
peri044 Mar 14, 2024
3ece71b
chore: resolve merge conflicts
peri044 Mar 15, 2024
1fa1771
Merge branch 'main' into dyn_2.2
peri044 Mar 15, 2024
febf05b
Merge branch 'save' into dyn_2.2
peri044 Mar 15, 2024
eab0dba
chore: Fix save failures
peri044 Mar 18, 2024
b191d62
chore: update to 2.3 rc build
peri044 Mar 18, 2024
380477b
Merge branch 'dyn_2.2' into dyn_2.2_tc
peri044 Mar 19, 2024
5f34d4f
chore: minor fixes
peri044 Mar 19, 2024
ce606fe
chore: rebase with release/2.3 branch
peri044 Mar 19, 2024
8674a3c
chore: minor fixes
peri044 Mar 19, 2024
f4e8fe9
chore: remove duplicate bert test case
peri044 Mar 20, 2024
4ae6ab9
chore: remove comments
peri044 Mar 20, 2024
c14f28d
Merge branch 'save' into dyn_2.2
peri044 Mar 20, 2024
3295c02
Merge branch 'dyn_2.2' into dyn_2.2_tc
peri044 Mar 20, 2024
4188173
chore: rebase with release/2.3
peri044 Apr 2, 2024
f6b758e
Merge branch 'dyn_2.2' into dyn_2.2_tc
peri044 Apr 2, 2024
78f7eb5
chore: updates
peri044 Apr 2, 2024
fe13c2a
chore: Update mypy type for sample_inputs
peri044 Apr 2, 2024
e9b649d
chore: revert changes
peri044 Apr 5, 2024
03ecc61
Merge branch 'dyn_2.2' into dyn_2.2_tc
peri044 Apr 5, 2024
978c039
Merge branch 'release/2.3' into dyn_2.2
peri044 Apr 5, 2024
ccb88c8
Merge branch 'dyn_2.2' into dyn_2.2_tc
peri044 Apr 5, 2024
3cccf8a
chore: rebase
peri044 Apr 15, 2024
2d24686
chore: update to use test channel
peri044 Apr 15, 2024
8e36525
chore: updates
peri044 Apr 16, 2024
30 changes: 21 additions & 9 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -5,6 +5,7 @@
from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union

import torch
import torch_tensorrt
from torch.export import ExportedProgram
from torch.fx.node import Target
from torch_tensorrt import _enums
@@ -66,8 +67,6 @@
to_torch_tensorrt_device,
)

import torch_tensorrt

logger = logging.getLogger(__name__)


@@ -185,10 +184,10 @@ def compile(
)
gm = exported_program.module()
logger.debug("Input graph: " + str(gm.graph))

# Apply lowering on the graph module
torch_inputs = get_torch_inputs(inputs, device)
gm = apply_lowering_passes(gm, torch_inputs)

logger.debug("Lowered Input graph: " + str(gm.graph))

enabled_precisions = set(enabled_precisions)
@@ -302,6 +301,24 @@ def compile_module(
f"Detected support for {num_supported_ops} operators out of {total_ops} in subgraph."
)

def contains_metadata(gm: torch.fx.GraphModule) -> bool:
for node in gm.graph.nodes:
if node.op != "output" and (not node.meta) and "val" not in node.meta:
logger.warning(
f"Node {node.name} of op type {node.op} does not have metadata. This could sometimes lead to undefined behavior."
)
return False
return True

# Check if the module has metadata (shape, dtype). If not, run symbolic shape propagation.
if not contains_metadata(gm):
from torch._inductor.compile_fx import fake_tensor_prop

torch_inputs = get_torch_inputs(sample_inputs, settings.device)
with torch.no_grad():
# This fails if the module has data-dependent shape operators.
fake_tensor_prop(gm, torch_inputs)

# Partition module into components that can be TRT-accelerated
fast_partitioner_failed = False

@@ -360,12 +377,7 @@ def compile_module(
)

# Get the submodule inputs for min, opt, max shapes of the graph inputs
submodule_inputs = partitioning.get_submod_inputs(
partitioned_module,
submodule,
sample_inputs,
to_torch_device(settings.device),
)
submodule_inputs = partitioning.construct_submodule_inputs(submodule)

logger.debug(
"Submodule name: %s\n Input shapes: %s\n %s",
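
The contains_metadata check added above relies on every non-output node carrying a fake tensor under node.meta["val"]. A minimal standalone sketch (plain FX, outside Torch-TensorRT; the traced function is illustrative) of how such metadata gets populated when it is missing:

import torch
from torch.fx.passes.fake_tensor_prop import FakeTensorProp

def f(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x) + 1

gm = torch.fx.symbolic_trace(f)
# Propagate fake tensors through the graph; afterwards node.meta["val"] holds a
# FakeTensor describing each node's shape and dtype.
FakeTensorProp(gm).propagate(torch.randn(4, 8))

for node in gm.graph.nodes:
    print(node.name, node.meta.get("val"))
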
1 change: 0 additions & 1 deletion py/torch_tensorrt/dynamo/backend/backends.py
@@ -74,7 +74,6 @@ def _pretraced_backend(
fake_mode, "allow_non_fake_inputs", True
), fake_mode:
repair_input_aliasing(gm)

# Invoke AOTAutograd to translate operators to aten
gm = aot_export_joint_simple(
gm,
12 changes: 12 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -392,6 +392,17 @@ def aten_ops_sigmoid(
)


@dynamo_tensorrt_converter(torch.ops.aten.sym_size.int)
def aten_ops_symsize_int(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.shape.shape(ctx, target, SourceIR.ATEN, name, args[0], args[1])


def index_dtype_validator(node: Node) -> bool:
index = node.args[1]
for ind in index:
@@ -402,6 +413,7 @@ def index_dtype_validator(node: Node) -> bool:
return True


@dynamo_tensorrt_converter(torch.ops.aten.index.Tensor)
@dynamo_tensorrt_converter(
torch.ops.aten.index.Tensor, capability_validator=index_dtype_validator
)
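
For context on the converter registered above (a hedged sketch, not part of the diff): size accesses on a dynamic dimension are what appear as aten.sym_size.int nodes after export, for example:

import torch

class Flattener(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x.shape[0] on a dynamic batch dimension lowers to aten.sym_size.int,
        # which the new converter maps to a TensorRT shape layer plus a gather.
        return x.reshape(x.shape[0], -1)

batch = torch.export.Dim("batch", min=1, max=32)
ep = torch.export.export(
    Flattener(),
    (torch.randn(4, 3, 8),),
    dynamic_shapes={"x": {0: batch}},
)
print(ep.graph)  # expect a sym_size.int node feeding the reshape
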
33 changes: 32 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/shape.py
@@ -8,14 +8,45 @@
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import to_numpy
from torch_tensorrt.dynamo.conversion.converter_utils import (
get_positive_dim,
get_trt_tensor,
to_numpy,
)
from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
convert_binary_elementwise,
)
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTTensor


def shape(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input_val: TRTTensor,
dim: int,
) -> TRTTensor:
"""
This is the general shape layer implementation in TensorRT.
sym_size.int ops map to addShape layer in TensorRT and returns
the dynamic shape of the tensor optionally taking in a dim argument.
"""
shape_layer = ctx.net.add_shape(input_val)
input_shape = shape_layer.get_output(0)
set_layer_name(shape_layer, target, name + "_shape", source_ir)

n_dims = len(input_val.shape)
dim = get_positive_dim(dim, n_dims)
dim_tensor = get_trt_tensor(ctx, dim, name + "_dim")
gather_layer = ctx.net.add_gather(input_shape, dim_tensor, axis=0)
set_layer_name(gather_layer, target, name + "_gather", source_ir)
input_shape = gather_layer.get_output(0)

return input_shape


def get_shape_with_dynamic_shape(
ctx: ConversionContext,
target: Target,
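
The shape() implementation above builds a shape layer followed by a gather over the requested dimension. A rough standalone sketch of the same pattern against the raw TensorRT Python API (assuming TensorRT 8.x; the network and input shape are illustrative):

import numpy as np
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
x = network.add_input("x", trt.float32, (-1, 3, 224, 224))  # dynamic batch dim

shape_layer = network.add_shape(x)  # 1-D int tensor holding the runtime shape
dim0 = network.add_constant((1,), np.array([0], dtype=np.int32))
gather_layer = network.add_gather(
    shape_layer.get_output(0), dim0.get_output(0), axis=0
)
# gather_layer.get_output(0) now carries the runtime batch size as a shape tensor.
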
20 changes: 18 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
@@ -3,7 +3,7 @@
import torch_tensorrt.dynamo.conversion.impl as impl
from torch.fx.node import Target
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR
from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR, get_trt_tensor
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTTensor

@@ -17,7 +17,23 @@ def reshape(
shape: Sequence[int],
) -> TRTTensor:
layer = ctx.net.add_shuffle(input)
layer.reshape_dims = tuple(shape)
if all(isinstance(s, int) for s in shape):
layer.reshape_dims = tuple(shape)
else:
# Convert all the dimensions to trt Tensors.
trt_shape = []

for i, s in enumerate(shape):
if isinstance(s, TRTTensor):
trt_shape.append(s)
else:
a = get_trt_tensor(ctx, s, f"{name}_{i}")
trt_shape.append(a)
shape_layer = ctx.net.add_concatenation(inputs=trt_shape)
shape_layer.axis = 0
shape_layer.name = f"{name}_output_shape"
layer.set_input(1, shape_layer.get_output(0))

set_layer_name(layer, target, name, source_ir)
return layer.get_output(0)

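
The new branch in reshape() above covers the case where some target dimensions are themselves TensorRT tensors: it concatenates per-dimension shape tensors and wires the result into the shuffle layer's second input. Continuing the raw-TensorRT sketch from the shape.py section (again only a sketch; the flattened size 3*224*224 is illustrative):

flat_size = network.add_constant((1,), np.array([3 * 224 * 224], dtype=np.int32))
shape_concat = network.add_concatenation(
    [gather_layer.get_output(0), flat_size.get_output(0)]
)
shape_concat.axis = 0

shuffle = network.add_shuffle(x)
# Supplying input 1 makes the reshape dims a runtime shape tensor: (batch, 3*224*224).
shuffle.set_input(1, shape_concat.get_output(0))
y = shuffle.get_output(0)
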
19 changes: 18 additions & 1 deletion py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py
@@ -1,4 +1,4 @@
from typing import List
from typing import Any, Dict, List

import torch

Expand Down Expand Up @@ -29,3 +29,20 @@ def get_tensor_placeholders(
]

return placeholders


def update_metadata(
gm: torch.fx.GraphModule, target_op: Any, metadata: Dict[int, torch._ops.OpOverload]
) -> None:
"""
Given a graph and a node which has target_op in the graph,
a) If the node has metadata, store it in the map
b) If the node does not have metadata, retrieve it from the map
and assign to the node.
"""
for idx, node in enumerate(gm.graph.nodes):
if node.target == target_op:
if idx not in metadata and node.meta:
metadata[idx] = node.meta
elif idx in metadata and not node.meta:
node.meta = metadata[idx]
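
A hedged sketch of the intended call pattern (the names gm, orig, and replacement refer to the view_to_reshape pass shown below): capture metadata keyed by node index before a subgraph rewrite, then copy it back onto the replacement nodes, which the rewriter leaves without meta:

from typing import Any, Dict
import torch

meta_map: Dict[int, Any] = {}
update_metadata(gm, torch.ops.aten.view.default, meta_map)      # store pass
torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement)
update_metadata(gm, torch.ops.aten.reshape.default, meta_map)   # restore pass
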
37 changes: 17 additions & 20 deletions py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py
@@ -1,9 +1,10 @@
import logging
from typing import Callable, List, Sequence, Tuple
from typing import Dict, List, Sequence

import torch
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
clean_up_graph_after_modifications,
update_metadata,
)

logger = logging.getLogger(__name__)
@@ -13,29 +14,25 @@ def view_to_reshape(
gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
"""Replace aten.view with an equivalent implementation which avoids Tensor memory issues"""
orig, replacement = view_replacement()

if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
gm = clean_up_graph_after_modifications(gm)
logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")

return gm


def view_replacement() -> (
Tuple[
torch.fx.GraphModule,
Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
]
):
"""Constructs the original and replacement functions for view"""
orig_op = torch.ops.aten.view.default
replacement_op = torch.ops.aten.reshape.default

# Original graph
def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
return torch.ops.aten.view.default(input, shape)
return orig_op(input, shape)

# Replacement graph
def replacement(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
return torch.ops.aten.reshape.default(input, shape)
return replacement_op(input, shape)

return orig, replacement
# Store metadata of the orig_op and copy it to the replacement op
meta_map: Dict[int, torch._ops.OpOverload] = {}
update_metadata(gm, orig_op, meta_map)

if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
gm = clean_up_graph_after_modifications(gm)
logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")

update_metadata(gm, replacement_op, meta_map)

return gm
7 changes: 6 additions & 1 deletion py/torch_tensorrt/dynamo/partitioning/__init__.py
@@ -1,3 +1,8 @@
from ._adjacency_partitioner import partition as fast_partition
from ._global_partitioner import partition as global_partition
from .common import get_graph_converter_support, get_submod_inputs, run_shape_analysis
from .common import (
construct_submodule_inputs,
get_graph_converter_support,
get_submod_inputs,
run_shape_analysis,
)
74 changes: 74 additions & 0 deletions py/torch_tensorrt/dynamo/partitioning/common.py
@@ -9,6 +9,80 @@
logger = logging.getLogger(__name__)


def contains_sym_int(tensor: torch.Tensor) -> bool:
"""
Returns true if the given tensor has symbolic shape.
"""
for dim in tensor:
if isinstance(dim, torch.SymInt):
return True
return False


def construct_dynamic_input(input: Any) -> Input:
"""
Constructs a torch_tensorrt.Input based on a symbolic input
Args:
input: A symbolic shape tensor (which can have a mix of SymInt nodes and static values)
Returns:
A dynamic shaped torch_tensorrt.Input which has the properties of the symbolic shaped input.
"""
input_sym_shape = input.size()
min_shape = []
opt_shape = []
max_shape = []
for dim in input_sym_shape:
if isinstance(dim, torch.SymInt):
node = dim.node
expr = node.expr
shape_env = node.shape_env
var_range = shape_env.var_to_range.get(expr, None)
var_val = shape_env.var_to_val.get(expr, None)
assert var_range, var_val
# Torchdynamo 0/1 specialization outlier
if var_range.lower == 2:
min_shape.append(1)
else:
min_shape.append(int(var_range.lower))
opt_shape.append(int(var_val))
max_shape.append(int(var_range.upper))
else:
min_shape.append(dim)
opt_shape.append(dim)
max_shape.append(dim)

return Input(
min_shape=min_shape, opt_shape=opt_shape, max_shape=max_shape, dtype=input.dtype
)


def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]:
"""
Construct torch_tensorrt Inputs based on the module inputs.
The module inputs will have meta data which has the shape and dtype info
Args:
module: Input FX GraphModule
Returns:
Sequence of torch_tensorrt.Input's representing inputs to given module
"""
torchtrt_inputs = []
module_inputs = [node for node in module.graph.nodes if node.op == "placeholder"]
for input in module_inputs:
if input.meta and "val" in input.meta:
input_meta = input.meta["val"]
input_shape = input_meta.size()
if contains_sym_int(input_shape):
torchtrt_inputs.append(construct_dynamic_input(input_meta))
else:
torchtrt_inputs.append(Input(shape=input_shape, dtype=input_meta.dtype))
else:
raise AssertionError(
f"Input {input.name} does not contain metadata. Please ensure you have exported the graph correctly"
)

return torchtrt_inputs


def run_shape_analysis(
parent_module: torch.fx.GraphModule, inputs: Sequence[Input]
) -> Tuple[Dict[Any, Sequence[Any]], Dict[Any, Sequence[Any]]]:
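
As a rough illustration of what construct_dynamic_input above produces (hypothetical numbers): for a placeholder whose fake tensor has size (s0, 3, 224, 224), with s0 ranged over [2, 32] and a traced hint of 8, the resulting torch_tensorrt.Input would look roughly like:

import torch
from torch_tensorrt import Input

# Hypothetical result: the lower bound 2 is widened to 1 to account for
# torchdynamo's 0/1 specialization, and the traced hint becomes the opt shape.
trt_input = Input(
    min_shape=(1, 3, 224, 224),
    opt_shape=(8, 3, 224, 224),
    max_shape=(32, 3, 224, 224),
    dtype=torch.float32,
)
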
3 changes: 2 additions & 1 deletion py/torch_tensorrt/dynamo/utils.py
@@ -88,7 +88,8 @@ def get_torch_inputs(
if isinstance(input, Input)
]
return [
input.torch_tensor.to(device) for input in inputs if isinstance(input, Input)
input.torch_tensor.to(device) if isinstance(input, Input) else input
for input in inputs
]

