feat: Implement symbolic shape propagation, sym_size converter (#2473)
Signed-off-by: Dheeraj Peri <[email protected]>
peri044 committed Apr 16, 2024
1 parent b76024d commit 7120d75
Showing 16 changed files with 325 additions and 120 deletions.
27 changes: 20 additions & 7 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -192,10 +192,10 @@ def compile(
    )
    gm = exported_program.module()
    logger.debug("Input graph: " + str(gm.graph))

    # Apply lowering on the graph module
    torch_inputs = get_torch_inputs(inputs, device)
    gm = apply_lowering_passes(gm, torch_inputs)

    logger.debug("Lowered Input graph: " + str(gm.graph))

    enabled_precisions = set(enabled_precisions)
@@ -313,6 +313,24 @@ def compile_module(
f"Detected support for {num_supported_ops} operators out of {total_ops} in subgraph."
)

    def contains_metadata(gm: torch.fx.GraphModule) -> bool:
        for node in gm.graph.nodes:
            if node.op != "output" and (not node.meta or "val" not in node.meta):
                logger.warning(
                    f"Node {node.name} of op type {node.op} does not have metadata. This could sometimes lead to undefined behavior."
                )
                return False
        return True

    # Check if the module has metadata (shape, dtype). If not, run symbolic shape propagation.
    if not contains_metadata(gm):
        from torch._inductor.compile_fx import fake_tensor_prop

        torch_inputs = get_torch_inputs(sample_inputs, settings.device)
        with torch.no_grad():
            # This fails if the module has data-dependent shape operators.
            fake_tensor_prop(gm, torch_inputs)

# Partition module into components that can be TRT-accelerated
fast_partitioner_failed = False

@@ -371,12 +389,7 @@ def compile_module(
        )

        # Get the submodule inputs for min, opt, max shapes of the graph inputs
        submodule_inputs = partitioning.get_submod_inputs(
            partitioned_module,
            submodule,
            sample_inputs,
            to_torch_device(settings.device),
        )
        submodule_inputs = partitioning.construct_submodule_inputs(submodule)

        logger.debug(
            "Submodule name: %s\n Input shapes: %s\n %s",
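As an aside, a minimal sketch of what the symbolic shape propagation above produces (the toy module and input shapes are illustrative, not from this PR; fake_tensor_prop is the same internal PyTorch API imported in the diff):

import torch
from torch._inductor.compile_fx import fake_tensor_prop

# Hypothetical toy module used only to illustrate metadata propagation.
class Flatten(torch.nn.Module):
    def forward(self, x):
        return x.reshape(x.size(0), -1)

gm = torch.fx.symbolic_trace(Flatten())
with torch.no_grad():
    fake_tensor_prop(gm, [torch.randn(4, 3, 8)])

# Every non-output node should now carry its inferred shape/dtype as a
# FakeTensor under node.meta["val"], which is what contains_metadata checks.
for node in gm.graph.nodes:
    if node.op != "output":
        print(node.name, node.meta.get("val"))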
1 change: 0 additions & 1 deletion py/torch_tensorrt/dynamo/backend/backends.py
@@ -74,7 +74,6 @@ def _pretraced_backend(
            fake_mode, "allow_non_fake_inputs", True
        ), fake_mode:
            repair_input_aliasing(gm)

            # Invoke AOTAutograd to translate operators to aten
            gm = aot_export_joint_simple(
                gm,
16 changes: 16 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -392,6 +392,22 @@ def aten_ops_sigmoid(
    )


@enforce_tensor_types(
    {
        0: (TRTTensor,),
    }
)
@dynamo_tensorrt_converter(torch.ops.aten.sym_size.int)
def aten_ops_symsize_int(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.shape.shape(ctx, target, SourceIR.ATEN, name, args[0], args[1])


def index_dtype_validator(node: Node) -> bool:
    index = node.args[1]
    for ind in index:
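For context, a hedged sketch of the kind of graph that reaches this converter: reading a dimension that torch.export treats as dynamic lowers to torch.ops.aten.sym_size.int (the module and bounds below are hypothetical, not from this PR):

import torch

class DynamicFlatten(torch.nn.Module):
    def forward(self, x):
        # On a dynamic dim, x.shape[0] becomes an aten.sym_size.int node.
        return x.reshape(x.shape[0], -1)

batch = torch.export.Dim("batch", min=1, max=32)
ep = torch.export.export(
    DynamicFlatten(),
    (torch.randn(8, 3, 4),),
    dynamic_shapes={"x": {0: batch}},
)
print(ep.graph)  # expect a sym_size.int node feeding the reshape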
6 changes: 2 additions & 4 deletions py/torch_tensorrt/dynamo/conversion/impl/grid.py
@@ -1,13 +1,11 @@
from typing import Optional, Sequence
from typing import Optional

import tensorrt as trt
import torch
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
from torch_tensorrt.fx.types import TRTTensor

# nearest, linear, cubic
GridSamplerInterpolationMode = {
33 changes: 32 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/shape.py
@@ -8,14 +8,45 @@
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import to_numpy
from torch_tensorrt.dynamo.conversion.converter_utils import (
    get_positive_dim,
    get_trt_tensor,
    to_numpy,
)
from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
    convert_binary_elementwise,
)
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTTensor


def shape(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input_val: TRTTensor,
    dim: int,
) -> TRTTensor:
    """
    General shape layer implementation in TensorRT.
    sym_size.int ops map to a shape layer (IShapeLayer) in TensorRT and
    return the runtime size of one dimension of the input tensor: the full
    dynamic shape is materialized, then the requested dim is gathered out.
    """
    shape_layer = ctx.net.add_shape(input_val)
    input_shape = shape_layer.get_output(0)
    set_layer_name(shape_layer, target, name + "_shape", source_ir)

    n_dims = len(input_val.shape)
    dim = get_positive_dim(dim, n_dims)
    dim_tensor = get_trt_tensor(ctx, dim, name + "_dim")
    gather_layer = ctx.net.add_gather(input_shape, dim_tensor, axis=0)
    set_layer_name(gather_layer, target, name + "_gather", source_ir)
    input_shape = gather_layer.get_output(0)

    return input_shape


def get_shape_with_dynamic_shape(
    ctx: ConversionContext,
    target: Target,
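For readers less familiar with the TensorRT Python API, here is a standalone sketch of the same shape -> gather pattern built directly on a network (all names here are assumed; this is not code from the PR):

import numpy as np
import tensorrt as trt

builder = trt.Builder(trt.Logger(trt.Logger.WARNING))
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
x = network.add_input("x", trt.float32, (-1, 3, 8))  # dim 0 is dynamic

shape_layer = network.add_shape(x)  # 1-D int tensor: runtime shape of x
dim_index = network.add_constant((1,), np.array([0], dtype=np.int32))
gather_layer = network.add_gather(
    shape_layer.get_output(0), dim_index.get_output(0), axis=0
)
# gather_layer.get_output(0) holds the runtime size of dimension 0,
# which is exactly what the sym_size.int converter returns.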
20 changes: 18 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
@@ -3,7 +3,7 @@
import torch_tensorrt.dynamo.conversion.impl as impl
from torch.fx.node import Target
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR
from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR, get_trt_tensor
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTTensor

@@ -17,7 +17,23 @@ def reshape(
    shape: Sequence[int],
) -> TRTTensor:
    layer = ctx.net.add_shuffle(input)
    layer.reshape_dims = tuple(shape)
    if all(isinstance(s, int) for s in shape):
        layer.reshape_dims = tuple(shape)
    else:
        # Convert all the dimensions to TRT tensors and concatenate them
        # into a single 1-D shape tensor.
        trt_shape = []

        for i, s in enumerate(shape):
            if isinstance(s, TRTTensor):
                trt_shape.append(s)
            else:
                a = get_trt_tensor(ctx, s, f"{name}_{i}")
                trt_shape.append(a)
        shape_layer = ctx.net.add_concatenation(inputs=trt_shape)
        shape_layer.axis = 0
        shape_layer.name = f"{name}_output_shape"
        layer.set_input(1, shape_layer.get_output(0))

    set_layer_name(layer, target, name, source_ir)
    return layer.get_output(0)

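The set_input(1, ...) call above is how TensorRT accepts reshape dimensions that are only known at runtime: the second input of an IShuffleLayer, when set, overrides reshape_dims with a shape tensor. Continuing the assumed network/x/gather_layer names from the earlier sketch:

# Reshape x to (batch, -1), where batch is taken from the runtime shape.
minus_one = network.add_constant((1,), np.array([-1], dtype=np.int32))
new_shape = network.add_concatenation(
    [gather_layer.get_output(0), minus_one.get_output(0)]
)
new_shape.axis = 0  # concatenate along the single shape-tensor axis

shuffle = network.add_shuffle(x)
shuffle.set_input(1, new_shape.get_output(0))  # dynamic reshape dims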
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
@@ -69,7 +69,6 @@ def expand(
) -> TRTTensor:
    shape_rank = len(shape)
    initial_tensor_rank = len(input_t.shape)

    # If the rank of the input tensor is less than the shape's rank, pad with ones
    if initial_tensor_rank < shape_rank:
        input_t = prepend_ones(
@@ -99,6 +98,7 @@
    stride = tuple(
        [int(i == o) for i, o in zip(input_tensor_shape, shape)]
    )  # stride == 1 if dimensions match, 0 otherwise

    layer = ctx.net.add_slice(input_t, start=start, shape=shape, stride=stride)
    set_layer_name(layer, target, name, source_ir)
    return layer.get_output(0)
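The stride comment above is the heart of the broadcast trick; a quick illustration with made-up shapes:

# Expanding (1, 4) -> (3, 4): the size-1 dim gets stride 0, so the slice
# re-reads index 0 three times (a broadcast); matching dims get stride 1.
input_tensor_shape = (1, 4)
shape = (3, 4)
stride = tuple(int(i == o) for i, o in zip(input_tensor_shape, shape))
assert stride == (0, 1)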
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/upsample.py
@@ -29,7 +29,7 @@ def upsample(
        resize_layer.scales = [1.0, 1.0] + list(scale_factors)
    else:
        raise RuntimeError(
            f"At least one of out_shape and scale_factors should be specified."
            "At least one of out_shape and scale_factors should be specified."
        )

    # interpolate mode
23 changes: 22 additions & 1 deletion py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py
@@ -1,4 +1,4 @@
from typing import List
from typing import Any, List

import torch

@@ -29,3 +29,24 @@ def get_tensor_placeholders(
    ]

    return placeholders


def get_metadata(
    gm: torch.fx.GraphModule, target_op: Any
) -> List[Any]:
    """
    Return a list with the metadata of all target_op nodes present in the graph.
    """
    return [node.meta for node in gm.graph.nodes if node.target == target_op]


def set_metadata(
    gm: torch.fx.GraphModule, target_op: Any, metadata: List[Any]
) -> None:
    """
    Set the metadata of all target_op nodes present in the graph to the given list, in node order.
    """
    target_nodes = [node for node in gm.graph.nodes if node.target == target_op]
    assert len(target_nodes) == len(metadata)
    for idx, node in enumerate(target_nodes):
        node.meta = metadata[idx]
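A short sketch of the save/rewrite/restore pattern these helpers enable (this is what the view_to_reshape pass below does; gm, orig, and replacement are assumed to be in scope):

# subgraph_rewriter drops node metadata, so save it before rewriting.
metadata = get_metadata(gm, torch.ops.aten.view.default)

if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
    # Re-attach the saved metadata to the replacement op's nodes.
    set_metadata(gm, torch.ops.aten.reshape.default, metadata)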
36 changes: 18 additions & 18 deletions py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py
@@ -1,9 +1,11 @@
import logging
from typing import Callable, List, Sequence, Tuple
from typing import List, Sequence

import torch
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
    clean_up_graph_after_modifications,
    get_metadata,
    set_metadata,
)

logger = logging.getLogger(__name__)
@@ -13,27 +15,25 @@ def view_to_reshape(
    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
    """Replace aten.view with an equivalent implementation which avoids Tensor memory issues"""
    orig, replacement = view_replacement()

    if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
        gm = clean_up_graph_after_modifications(gm)
        logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")

    return gm


def view_replacement() -> Tuple[
    torch.fx.GraphModule,
    Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
]:
    """Constructs the original and replacement functions for view"""
    orig_op = torch.ops.aten.view.default
    replacement_op = torch.ops.aten.reshape.default

    # Original graph
    def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
        return torch.ops.aten.view.default(input, shape)
        return orig_op(input, shape)

    # Replacement graph
    def replacement(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
        return torch.ops.aten.reshape.default(input, shape)
        return replacement_op(input, shape)

    return orig, replacement
    # Store metadata of the orig_op
    metadata = get_metadata(gm, orig_op)

    if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
        gm = clean_up_graph_after_modifications(gm)
        logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")

        # Copy the orig_op's metadata to the replacement op
        set_metadata(gm, replacement_op, metadata)

    return gm
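A hedged end-to-end sketch of running this pass (the module and shapes are illustrative; assumes torch.export emits aten.view.default nodes for .view calls):

import torch
from torch_tensorrt.dynamo.lowering.passes.view_to_reshape import view_to_reshape

class ViewModel(torch.nn.Module):
    def forward(self, x):
        return x.view(2, -1)

inputs = (torch.randn(4, 3),)
gm = torch.export.export(ViewModel(), inputs).module()
gm = view_to_reshape(gm, list(inputs))
print(gm.graph)  # view nodes replaced by reshape, metadata preserved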
6 changes: 5 additions & 1 deletion py/torch_tensorrt/dynamo/partitioning/__init__.py
@@ -1,3 +1,7 @@
from ._adjacency_partitioner import partition as fast_partition
from ._global_partitioner import partition as global_partition
from .common import get_graph_converter_support, get_submod_inputs, run_shape_analysis
from .common import (
    construct_submodule_inputs,
    get_graph_converter_support,
    run_shape_analysis,
)