feat: Implement symbolic shape propagation, sym_size converter #2473

Merged Apr 11, 2024 (49 commits)

Commits
d06c74a
chore: Switch to new export apis
peri044 Oct 3, 2023
47e0997
chore: rebase with main
peri044 Oct 9, 2023
fd29fe0
Merge branch 'main' into export_2.2
peri044 Oct 11, 2023
ad3b031
feat: Add support for dynamic shapes and remove constraints API
peri044 Oct 19, 2023
1582b72
chore: add dynamic shape support for certain converters
peri044 Oct 23, 2023
4d01545
chore: minor updates
peri044 Oct 25, 2023
6731a57
chore: updates
peri044 Oct 26, 2023
a8a194b
chore: rebase with main
peri044 Nov 15, 2023
0b60aae
chore: add sym int converter
peri044 Nov 15, 2023
634612f
feat: Replace the existing shape propagation with symbolic shape prop…
peri044 Nov 16, 2023
93edba4
chore: fix imports
peri044 Nov 16, 2023
7ad9272
chore: fix imports
peri044 Nov 16, 2023
f444d54
chore: updates
peri044 Nov 21, 2023
6e5c582
chore: change device calls
peri044 Nov 28, 2023
83791f8
chore: fix metadata check
peri044 Dec 5, 2023
8375996
chore: rebase with main
peri044 Dec 15, 2023
aba91fa
Merge branch 'main' into dyn_2.2
peri044 Dec 22, 2023
16394d9
chore: minor fixes
peri044 Jan 7, 2024
b9a7ccd
chore: Add sym_size converter tests
peri044 Jan 8, 2024
15cc643
chore: Update test utilities
peri044 Jan 8, 2024
5234d74
chore: add testcase for sym_size.int
peri044 Jan 8, 2024
fd2dae1
Merge branch 'main' into dyn_2.2
peri044 Jan 26, 2024
51e8bb7
chore: revert output type change
peri044 Jan 26, 2024
19c3fad
chore: add update_metadata utility
peri044 Jan 27, 2024
ed48551
chore: change debug to warning if the graph does not have metadata
peri044 Jan 27, 2024
9aff04b
chore: gpt2 changes + linting
peri044 Feb 7, 2024
440fcd5
chore: gpt2 changes + linting
peri044 Feb 7, 2024
a2d38f3
chore: rebase with main
peri044 Feb 7, 2024
002db3c
chore: add fallback option if val is missing in metadata
peri044 Feb 7, 2024
00cd17b
chore: tmp changes
peri044 Feb 13, 2024
6ac70cd
chore: tmp changes
peri044 Feb 13, 2024
b827070
Merge branch 'main' into dyn_2.2
peri044 Feb 16, 2024
8f9bca0
Merge branch 'main' into dyn_2.2
peri044 Feb 21, 2024
cd86660
feat: Add save API for torch-trt compiled models
peri044 Mar 14, 2024
3ece71b
chore: resolve merge conflicts
peri044 Mar 15, 2024
1fa1771
Merge branch 'main' into dyn_2.2
peri044 Mar 15, 2024
febf05b
Merge branch 'save' into dyn_2.2
peri044 Mar 15, 2024
eab0dba
chore: Fix save failures
peri044 Mar 18, 2024
b191d62
chore: update to 2.3 rc build
peri044 Mar 18, 2024
5f34d4f
chore: minor fixes
peri044 Mar 19, 2024
ce606fe
chore: rebase with release/2.3 branch
peri044 Mar 19, 2024
8674a3c
chore: minor fixes
peri044 Mar 19, 2024
f4e8fe9
chore: remove duplicate bert test case
peri044 Mar 20, 2024
4ae6ab9
chore: remove comments
peri044 Mar 20, 2024
c14f28d
Merge branch 'save' into dyn_2.2
peri044 Mar 20, 2024
4188173
chore: rebase with release/2.3
peri044 Apr 2, 2024
78f7eb5
chore: updates
peri044 Apr 2, 2024
e9b649d
chore: revert changes
peri044 Apr 5, 2024
978c039
Merge branch 'release/2.3' into dyn_2.2
peri044 Apr 5, 2024
27 changes: 20 additions & 7 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -189,10 +189,10 @@ def compile(
)
gm = exported_program.module()
logger.debug("Input graph: " + str(gm.graph))

# Apply lowering on the graph module
torch_inputs = get_torch_inputs(inputs, device)
gm = apply_lowering_passes(gm, torch_inputs)

logger.debug("Lowered Input graph: " + str(gm.graph))

enabled_precisions = set(enabled_precisions)
@@ -308,6 +308,24 @@ def compile_module(
f"Detected support for {num_supported_ops} operators out of {total_ops} in subgraph."
)

def contains_metadata(gm: torch.fx.GraphModule) -> bool:
for node in gm.graph.nodes:
if node.op != "output" and (not node.meta) and "val" not in node.meta:
logger.warning(
f"Node {node.name} of op type {node.op} does not have metadata. This could sometimes lead to undefined behavior."
)
return False
return True

# Check if the module has metadata (shape, dtype). If not, run symbolic shape propagation.
if not contains_metadata(gm):
from torch._inductor.compile_fx import fake_tensor_prop

torch_inputs = get_torch_inputs(sample_inputs, settings.device)
with torch.no_grad():
# This fails if the module has data-dependent shape operators.
fake_tensor_prop(gm, torch_inputs)

# Partition module into components that can be TRT-accelerated
fast_partitioner_failed = False

@@ -366,12 +384,7 @@ def compile_module(
)

# Get the submodule inputs for min, opt, max shapes of the graph inputs
submodule_inputs = partitioning.get_submod_inputs(
partitioned_module,
submodule,
sample_inputs,
to_torch_device(settings.device),
)
submodule_inputs = partitioning.construct_submodule_inputs(submodule)

logger.debug(
"Submodule name: %s\n Input shapes: %s\n %s",
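
For context, a minimal standalone sketch of the metadata check and FakeTensor propagation fallback that compile_module now performs. The toy module, shapes, and the use of torch.fx.passes.fake_tensor_prop.FakeTensorProp here are illustrative assumptions; the PR itself calls torch._inductor.compile_fx.fake_tensor_prop on the lowered graph module.

import torch
from torch.fx.passes.fake_tensor_prop import FakeTensorProp

class Toy(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x) + 1

gm = torch.fx.symbolic_trace(Toy())

# A freshly traced graph typically has no "val" (FakeTensor) entries in node.meta,
# so a check like contains_metadata above would return False for it.
missing = [n.name for n in gm.graph.nodes if n.op != "output" and "val" not in n.meta]
print("nodes missing shape metadata:", missing)

# Propagating FakeTensors repopulates node.meta["val"] with shape/dtype information,
# which the partitioner and converters rely on downstream.
with torch.no_grad():
    FakeTensorProp(gm).propagate(torch.randn(2, 3))
print(all("val" in n.meta for n in gm.graph.nodes if n.op != "output"))
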
1 change: 0 additions & 1 deletion py/torch_tensorrt/dynamo/backend/backends.py
@@ -74,7 +74,6 @@ def _pretraced_backend(
fake_mode, "allow_non_fake_inputs", True
), fake_mode:
repair_input_aliasing(gm)

# Invoke AOTAutograd to translate operators to aten
gm = aot_export_joint_simple(
gm,
16 changes: 16 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -392,6 +392,22 @@ def aten_ops_sigmoid(
)


@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
@dynamo_tensorrt_converter(torch.ops.aten.sym_size.int)
def aten_ops_symsize_int(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.shape.shape(ctx, target, SourceIR.ATEN, name, args[0], args[1])


def index_dtype_validator(node: Node) -> bool:
index = node.args[1]
for ind in index:
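
As an illustration (an assumption, not code from this PR), here is a toy module whose forward reads a dynamic dimension. Under torch.export with a dynamic batch dimension, x.shape[0] appears in the graph as torch.ops.aten.sym_size.int, which the converter registered above lowers to a TensorRT shape layer.

import torch

class BatchFlatten(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch = x.shape[0]  # traced as torch.ops.aten.sym_size.int(x, 0) when dim 0 is dynamic
        return x.reshape(batch, -1)

batch_dim = torch.export.Dim("batch", min=1, max=32)
ep = torch.export.export(
    BatchFlatten(),
    (torch.randn(4, 3, 8),),
    dynamic_shapes={"x": {0: batch_dim}},
)
print(ep.graph)  # expect a sym_size.int node feeding the reshape
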
6 changes: 2 additions & 4 deletions py/torch_tensorrt/dynamo/conversion/impl/grid.py
@@ -1,13 +1,11 @@
from typing import Optional, Sequence
from typing import Optional

import tensorrt as trt
import torch
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
from torch_tensorrt.fx.types import TRTTensor

# nearest, linear, cubic
GridSamplerInterpolationMode = {
12 changes: 6 additions & 6 deletions py/torch_tensorrt/dynamo/conversion/impl/select.py
@@ -90,7 +90,7 @@ def index(
# is_numpy is a flag indicating whether all the indices are numpy arrays or torch Tensors.
# If any index is not, the flag is set to False.
_LOGGER.debug(
f"Determining whether aten.index constant-index optimization can be invoked"
"Determining whether aten.index constant-index optimization can be invoked"
)
is_numpy = all(
isinstance(ind, (torch.Tensor, np.ndarray)) for ind in index if ind is not None
@@ -123,7 +123,7 @@ def index(
return identity_layer.get_output(0)
elif len(tensor_indices) == 1:
indices_tensor = get_trt_tensor(
ctx, tensor_indices[0], name + f"_parameter_to_fp32_tensor"
ctx, tensor_indices[0], name + "_parameter_to_fp32_tensor"
)
index = adv_indx_indices[0]
_LOGGER.debug(f"The advanced index indices is {adv_indx_indices}")
@@ -204,7 +204,7 @@ def index(
cum_adv_index = cum_adv_index + adv_index
multiplier = multiplier * input_shape[adv_indx_indices[i]]
cum_adv_index = get_trt_tensor(
ctx, cum_adv_index, name + f"_index_sum_intermediate"
ctx, cum_adv_index, name + "_index_sum_intermediate"
)
else:
multiplier = get_trt_tensor(
@@ -263,7 +263,7 @@ def index(
adv_indx_count
== adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1
):
_LOGGER.debug(f"The indices are continuous in this case")
_LOGGER.debug("The indices are continuous in this case")
concat_tensor_reshape.append(
get_trt_tensor(ctx, -1, name + "_dynamic_concat")
)
@@ -287,7 +287,7 @@ def index(
source_ir,
)
unfold_tensor = regular_index_shuffle_layer.get_output(0)
_LOGGER.debug(f"The tensor is unfolded now")
_LOGGER.debug("The tensor is unfolded now")
_LOGGER.debug(f"The unfolded tensor shape is {unfold_tensor.shape}")

# Transpose folded advanced indexed axis to its original location.
@@ -342,7 +342,7 @@ def index(
reshape_output = unfold_advanced_shuffle_layer.get_output(0)

else:
_LOGGER.debug(f"The indices are not continuous in this case")
_LOGGER.debug("The indices are not continuous in this case")
concat_final_tensor = []
concat_final_tensor.append(cum_adv_index_shape_tensor)
for i in range(0, rank):
33 changes: 32 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/shape.py
@@ -8,14 +8,45 @@
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import to_numpy
from torch_tensorrt.dynamo.conversion.converter_utils import (
get_positive_dim,
get_trt_tensor,
to_numpy,
)
from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
convert_binary_elementwise,
)
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTTensor


def shape(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input_val: TRTTensor,
dim: int,
) -> TRTTensor:
"""
This is the general shape layer implementation in TensorRT.
sym_size.int ops map to a shape layer (ctx.net.add_shape) in TensorRT, which returns
the dynamic shape of the tensor; the dim argument selects which dimension to extract via a gather.
"""
shape_layer = ctx.net.add_shape(input_val)
input_shape = shape_layer.get_output(0)
set_layer_name(shape_layer, target, name + "_shape", source_ir)

n_dims = len(input_val.shape)
dim = get_positive_dim(dim, n_dims)
dim_tensor = get_trt_tensor(ctx, dim, name + "_dim")
gather_layer = ctx.net.add_gather(input_shape, dim_tensor, axis=0)
set_layer_name(gather_layer, target, name + "_gather", source_ir)
input_shape = gather_layer.get_output(0)

return input_shape


def get_shape_with_dynamic_shape(
ctx: ConversionContext,
target: Target,
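
A rough sketch of the TensorRT layer sequence the new shape() helper emits; the standalone builder/network setup and tensor shapes below are my assumptions (the helper itself operates on the converter's ctx.net): an IShapeLayer produces the full runtime shape, and a gather with a constant index extracts the requested dimension.

import numpy as np
import tensorrt as trt

trt_logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(trt_logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)

x = network.add_input("x", trt.float32, (-1, 3, 8))   # dim 0 is dynamic

shape_layer = network.add_shape(x)                    # runtime shape tensor, e.g. [N, 3, 8]
dim_index = network.add_constant(
    (1,), trt.Weights(np.array([0], dtype=np.int32))  # dimension to extract (here 0)
)
gather_layer = network.add_gather(
    shape_layer.get_output(0), dim_index.get_output(0), axis=0
)                                                     # 1-element tensor holding N at runtime
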
20 changes: 18 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
@@ -3,7 +3,7 @@
import torch_tensorrt.dynamo.conversion.impl as impl
from torch.fx.node import Target
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR
from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR, get_trt_tensor
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTTensor

@@ -17,7 +17,23 @@ def reshape(
shape: Sequence[int],
) -> TRTTensor:
layer = ctx.net.add_shuffle(input)
layer.reshape_dims = tuple(shape)
if all(isinstance(s, int) for s in shape):
layer.reshape_dims = tuple(shape)
else:
# Convert all the dimensions to trt Tensors.
trt_shape = []

for i, s in enumerate(shape):
if isinstance(s, TRTTensor):
trt_shape.append(s)
else:
a = get_trt_tensor(ctx, s, f"{name}_{i}")
trt_shape.append(a)
shape_layer = ctx.net.add_concatenation(inputs=trt_shape)
shape_layer.axis = 0
shape_layer.name = f"{name}_output_shape"
layer.set_input(1, shape_layer.get_output(0))

set_layer_name(layer, target, name, source_ir)
return layer.get_output(0)

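
To illustrate the dynamic branch above (a hedged sketch continuing the assumed TensorRT network from the previous example): when any reshape dimension is only known at runtime, the target shape is assembled by concatenating per-dimension shape tensors and fed to the shuffle layer's second input, instead of setting static reshape_dims.

# Reshape x to (N, 24), where N is the dynamic batch size gathered above.
batch_size = gather_layer.get_output(0)                  # dynamic dim from the previous sketch
flat_dim = network.add_constant(
    (1,), trt.Weights(np.array([24], dtype=np.int32))    # 3 * 8, known statically
)
target_shape = network.add_concatenation(
    [batch_size, flat_dim.get_output(0)]
)
target_shape.axis = 0                                    # concatenate along the single shape axis

shuffle_layer = network.add_shuffle(x)
shuffle_layer.set_input(1, target_shape.get_output(0))   # dynamic reshape dims via input 1
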
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
@@ -69,7 +69,6 @@ def expand(
) -> TRTTensor:
shape_rank = len(shape)
initial_tensor_rank = len(input_t.shape)

# If the rank of the input tensor is less than the shape's rank, pad with ones
if initial_tensor_rank < shape_rank:
input_t = prepend_ones(
@@ -99,6 +98,7 @@ def expand(
stride = tuple(
[int(i == o) for i, o in zip(input_tensor_shape, shape)]
) # stride == 1 if dimensions match, 0 otherwise

layer = ctx.net.add_slice(input_t, start=start, shape=shape, stride=stride)
set_layer_name(layer, target, name, source_ir)
return layer.get_output(0)
Expand Down
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/upsample.py
@@ -29,7 +29,7 @@ def upsample(
resize_layer.scales = [1.0, 1.0] + list(scale_factors)
else:
raise RuntimeError(
f"At least one of out_shape and scale_factors should be specified."
"At least one of out_shape and scale_factors should be specified."
)

# interpolate mode
23 changes: 22 additions & 1 deletion py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py
@@ -1,4 +1,4 @@
from typing import List
from typing import Any, List

import torch

@@ -29,3 +29,24 @@ def get_tensor_placeholders(
]

return placeholders


def get_metadata(
gm: torch.fx.GraphModule, target_op: Any
) -> List[torch._ops.OpOverload]:
"""
Return a list containing the metadata (node.meta) of every target_op node present in the graph.
"""
return [node.meta for node in gm.graph.nodes if node.target == target_op]


def set_metadata(
gm: torch.fx.GraphModule, target_op: Any, metadata: List[torch._ops.OpOverload]
) -> None:
"""
Set the metadata of every target_op node present in the graph from the provided list, in graph order.
"""
target_nodes = [node for node in gm.graph.nodes if node.target == target_op]
assert len(target_nodes) == len(metadata)
for idx, node in enumerate(target_nodes):
node.meta = metadata[idx]
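
A hypothetical usage sketch: the toy pass and the choice of ops are my own, only get_metadata and set_metadata come from this PR. The pattern is to capture metadata for the nodes about to be rewritten, mutate the graph, then copy the saved metadata onto the replacement nodes, which is the same round-trip the reworked view_to_reshape pass below performs.

import torch
from torch_tensorrt.dynamo.lowering.passes.pass_utils import get_metadata, set_metadata

def swap_relu_for_leaky_relu(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    relu = torch.ops.aten.relu.default
    leaky_relu = torch.ops.aten.leaky_relu.default

    # Save the .meta dict of every relu node before touching the graph.
    saved_metadata = get_metadata(gm, relu)

    # In-place retargeting keeps node count and order, so the lists line up.
    for node in gm.graph.nodes:
        if node.target == relu:
            node.target = leaky_relu

    gm.graph.lint()
    gm.recompile()

    # Restore shape/dtype metadata onto the retargeted nodes.
    set_metadata(gm, leaky_relu, saved_metadata)
    return gm
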
36 changes: 18 additions & 18 deletions py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py
@@ -1,9 +1,11 @@
import logging
from typing import Callable, List, Sequence, Tuple
from typing import List, Sequence

import torch
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
clean_up_graph_after_modifications,
get_metadata,
set_metadata,
)

logger = logging.getLogger(__name__)
@@ -13,27 +15,25 @@ def view_to_reshape(
gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
"""Replace aten.view with an equivalent implementation which avoids Tensor memory issues"""
orig, replacement = view_replacement()

if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
gm = clean_up_graph_after_modifications(gm)
logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")

return gm


def view_replacement() -> Tuple[
torch.fx.GraphModule,
Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
]:
"""Constructs the original and replacement functions for view"""
orig_op = torch.ops.aten.view.default
replacement_op = torch.ops.aten.reshape.default

# Original graph
def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
return torch.ops.aten.view.default(input, shape)
return orig_op(input, shape)

# Replacement graph
def replacement(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
return torch.ops.aten.reshape.default(input, shape)
return replacement_op(input, shape)

return orig, replacement
# Store metadata of the orig_op
metadata = get_metadata(gm, orig_op)

if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
gm = clean_up_graph_after_modifications(gm)
logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")

# Copy the orig_op's metadata to the replacement op
set_metadata(gm, replacement_op, metadata)

return gm
6 changes: 5 additions & 1 deletion py/torch_tensorrt/dynamo/partitioning/__init__.py
@@ -1,3 +1,7 @@
from ._adjacency_partitioner import partition as fast_partition
from ._global_partitioner import partition as global_partition
from .common import get_graph_converter_support, get_submod_inputs, run_shape_analysis
from .common import (
construct_submodule_inputs,
get_graph_converter_support,
run_shape_analysis,
)