diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index b321eabcb2..ae93bf344a 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -189,10 +189,10 @@ def compile(
     )
     gm = exported_program.module()
     logger.debug("Input graph: " + str(gm.graph))
-
     # Apply lowering on the graph module
     torch_inputs = get_torch_inputs(inputs, device)
     gm = apply_lowering_passes(gm, torch_inputs)
+
     logger.debug("Lowered Input graph: " + str(gm.graph))

     enabled_precisions = set(enabled_precisions)
@@ -308,6 +308,24 @@ def compile_module(
         f"Detected support for {num_supported_ops} operators out of {total_ops} in subgraph."
     )

+    def contains_metadata(gm: torch.fx.GraphModule) -> bool:
+        for node in gm.graph.nodes:
+            if node.op != "output" and "val" not in node.meta:
+                logger.warning(
+                    f"Node {node.name} of op type {node.op} does not have metadata. This could sometimes lead to undefined behavior."
+                )
+                return False
+        return True
+
+    # Check if the module has metadata (shape, dtype). If not, run symbolic shape propagation.
+    if not contains_metadata(gm):
+        from torch._inductor.compile_fx import fake_tensor_prop
+
+        torch_inputs = get_torch_inputs(sample_inputs, settings.device)
+        with torch.no_grad():
+            # This fails if the module has data-dependent shape operators.
+            fake_tensor_prop(gm, torch_inputs)
+
     # Partition module into components that can be TRT-accelerated
     fast_partitioner_failed = False

@@ -366,12 +384,7 @@ def compile_module(
         )

         # Get the submodule inputs for min, opt, max shapes of the graph inputs
-        submodule_inputs = partitioning.get_submod_inputs(
-            partitioned_module,
-            submodule,
-            sample_inputs,
-            to_torch_device(settings.device),
-        )
+        submodule_inputs = partitioning.construct_submodule_inputs(submodule)

         logger.debug(
             "Submodule name: %s\n Input shapes: %s\n %s",
diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py
index 1fa2806181..bade91c553 100644
--- a/py/torch_tensorrt/dynamo/backend/backends.py
+++ b/py/torch_tensorrt/dynamo/backend/backends.py
@@ -74,7 +74,6 @@ def _pretraced_backend(
             fake_mode, "allow_non_fake_inputs", True
         ), fake_mode:
             repair_input_aliasing(gm)
-
             # Invoke AOTAutograd to translate operators to aten
             gm = aot_export_joint_simple(
                 gm,
diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 45949a1c8d..72998e1917 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -392,6 +392,22 @@ def aten_ops_sigmoid(
     )


+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
+@dynamo_tensorrt_converter(torch.ops.aten.sym_size.int)
+def aten_ops_symsize_int(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.shape.shape(ctx, target, SourceIR.ATEN, name, args[0], args[1])
+
+
 def index_dtype_validator(node: Node) -> bool:
     index = node.args[1]
     for ind in index:
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/grid.py b/py/torch_tensorrt/dynamo/conversion/impl/grid.py
index 672fc97351..63ff93b0c7 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/grid.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/grid.py
@@ -1,13 +1,11 @@
-from typing import Optional, Sequence
+from typing import Optional

 import tensorrt as trt
-import torch
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
 from torch_tensorrt.fx.converters.converter_utils import set_layer_name
-from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+from torch_tensorrt.fx.types import TRTTensor

 # nearest, linear, cubic
 GridSamplerInterpolationMode = {
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py
index db586be65f..dc33129d24 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/select.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py
@@ -90,7 +90,7 @@ def index(
     # is_numpy is a flag to specify if all the indices are numpy or torchTensor.
     # If any is not this flag will be set to False
     _LOGGER.debug(
-        f"Determining whether aten.index constant-index optimization can be invoked"
+        "Determining whether aten.index constant-index optimization can be invoked"
     )
     is_numpy = all(
         isinstance(ind, (torch.Tensor, np.ndarray)) for ind in index if ind is not None
@@ -123,7 +123,7 @@ def index(
         return identity_layer.get_output(0)
     elif len(tensor_indices) == 1:
         indices_tensor = get_trt_tensor(
-            ctx, tensor_indices[0], name + f"_parameter_to_fp32_tensor"
+            ctx, tensor_indices[0], name + "_parameter_to_fp32_tensor"
         )
         index = adv_indx_indices[0]
         _LOGGER.debug(f"The advanced index indices is {adv_indx_indices}")
@@ -204,7 +204,7 @@ def index(
                 cum_adv_index = cum_adv_index + adv_index
                 multiplier = multiplier * input_shape[adv_indx_indices[i]]
             cum_adv_index = get_trt_tensor(
-                ctx, cum_adv_index, name + f"_index_sum_intermediate"
+                ctx, cum_adv_index, name + "_index_sum_intermediate"
             )
         else:
             multiplier = get_trt_tensor(
@@ -263,7 +263,7 @@ def index(
             adv_indx_count
             == adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1
         ):
-            _LOGGER.debug(f"The indices are continuous in this case")
+            _LOGGER.debug("The indices are continuous in this case")
            concat_tensor_reshape.append(
                get_trt_tensor(ctx, -1, name + "_dynamic_concat")
            )
@@ -287,7 +287,7 @@ def index(
                 source_ir,
             )
             unfold_tensor = regular_index_shuffle_layer.get_output(0)
-            _LOGGER.debug(f"The tensor is unfolded now")
+            _LOGGER.debug("The tensor is unfolded now")
             _LOGGER.debug(f"The unfolded tensor shape is {unfold_tensor.shape}")

             # Transpose folded advanced indexed axis to its original location.
@@ -342,7 +342,7 @@ def index(
             reshape_output = unfold_advanced_shuffle_layer.get_output(0)

         else:
-            _LOGGER.debug(f"The indices are not continuous in this case")
+            _LOGGER.debug("The indices are not continuous in this case")
             concat_final_tensor = []
             concat_final_tensor.append(cum_adv_index_shape_tensor)
             for i in range(0, rank):
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/shape.py b/py/torch_tensorrt/dynamo/conversion/impl/shape.py
index ef30b186c1..2d2481936b 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/shape.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/shape.py
@@ -8,7 +8,11 @@
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import to_numpy
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    get_positive_dim,
+    get_trt_tensor,
+    to_numpy,
+)
 from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
     convert_binary_elementwise,
 )
@@ -16,6 +20,33 @@
 from torch_tensorrt.fx.types import TRTTensor


+def shape(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+    dim: int,
+) -> TRTTensor:
+    """
+    General shape layer implementation for TensorRT.
+    sym_size.int ops map to a TensorRT shape layer, which returns the dynamic
+    shape of the tensor; the dim argument selects a single dimension via gather.
+    """
+    shape_layer = ctx.net.add_shape(input_val)
+    input_shape = shape_layer.get_output(0)
+    set_layer_name(shape_layer, target, name + "_shape", source_ir)
+
+    n_dims = len(input_val.shape)
+    dim = get_positive_dim(dim, n_dims)
+    dim_tensor = get_trt_tensor(ctx, dim, name + "_dim")
+    gather_layer = ctx.net.add_gather(input_shape, dim_tensor, axis=0)
+    set_layer_name(gather_layer, target, name + "_gather", source_ir)
+    input_shape = gather_layer.get_output(0)
+
+    return input_shape
+
+
 def get_shape_with_dynamic_shape(
     ctx: ConversionContext,
     target: Target,
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
index 49ddb76e2c..6d848c4be3 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
@@ -3,7 +3,7 @@
 import torch_tensorrt.dynamo.conversion.impl as impl
 from torch.fx.node import Target
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR
+from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR, get_trt_tensor
 from torch_tensorrt.fx.converters.converter_utils import set_layer_name
 from torch_tensorrt.fx.types import TRTTensor

@@ -17,7 +17,23 @@ def reshape(
     shape: Sequence[int],
 ) -> TRTTensor:
     layer = ctx.net.add_shuffle(input)
-    layer.reshape_dims = tuple(shape)
+    if all(isinstance(s, int) for s in shape):
+        layer.reshape_dims = tuple(shape)
+    else:
+        # Convert all dimensions to TRT tensors and concatenate them into a 1-D shape tensor.
+        trt_shape = []
+
+        for i, s in enumerate(shape):
+            if isinstance(s, TRTTensor):
+                trt_shape.append(s)
+            else:
+                a = get_trt_tensor(ctx, s, f"{name}_{i}")
+                trt_shape.append(a)
+        shape_layer = ctx.net.add_concatenation(inputs=trt_shape)
+        shape_layer.axis = 0
+        shape_layer.name = f"{name}_output_shape"
+        layer.set_input(1, shape_layer.get_output(0))
+
     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)

diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
index 5f1db00f33..61d71fe9a0 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
@@ -69,7 +69,6 @@ def expand(
 ) -> TRTTensor:
     shape_rank = len(shape)
     initial_tensor_rank = len(input_t.shape)
-
     # If the rank of the input tensor is less than the shape's rank, pad with ones
     if initial_tensor_rank < shape_rank:
         input_t = prepend_ones(
@@ -99,6 +98,7 @@ def expand(
     stride = tuple(
         [int(i == o) for i, o in zip(input_tensor_shape, shape)]
     )  # stride == 1 if dimensions match, 0 otherwise
+
     layer = ctx.net.add_slice(input_t, start=start, shape=shape, stride=stride)
     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/upsample.py b/py/torch_tensorrt/dynamo/conversion/impl/upsample.py
index 3313730ec3..594bb4167c 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/upsample.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/upsample.py
@@ -29,7 +29,7 @@ def upsample(
         resize_layer.scales = [1.0, 1.0] + list(scale_factors)
     else:
         raise RuntimeError(
-            f"At least one of out_shape and scale_factors should be specified."
+            "At least one of out_shape and scale_factors should be specified."
         )

     # interpolate mode
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py b/py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py
index 31a55099c2..0ffc6d3c76 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import Any, Dict, List

 import torch

@@ -29,3 +29,24 @@ def get_tensor_placeholders(
     ]

     return placeholders
+
+
+def get_metadata(
+    gm: torch.fx.GraphModule, target_op: Any
+) -> List[Dict[str, Any]]:
+    """
+    Return the list of metadata dictionaries of all target_op nodes present in the graph.
+    """
+    return [node.meta for node in gm.graph.nodes if node.target == target_op]
+
+
+def set_metadata(
+    gm: torch.fx.GraphModule, target_op: Any, metadata: List[Dict[str, Any]]
+) -> None:
+    """
+    Set the provided metadata on all target_op nodes present in the graph, in order.
+    """
+    target_nodes = [node for node in gm.graph.nodes if node.target == target_op]
+    assert len(target_nodes) == len(metadata)
+    for idx, node in enumerate(target_nodes):
+        node.meta = metadata[idx]
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py b/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py
index e2ef051f06..b2da354122 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py
@@ -1,9 +1,11 @@
 import logging
-from typing import Callable, List, Sequence, Tuple
+from typing import List, Sequence

 import torch
 from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
     clean_up_graph_after_modifications,
+    get_metadata,
+    set_metadata,
 )

 logger = logging.getLogger(__name__)
@@ -13,27 +15,25 @@ def view_to_reshape(
     gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
 ) -> torch.fx.GraphModule:
     """Replace aten.view with an equivalent implementation which avoids Tensor memory issues"""
-    orig, replacement = view_replacement()
-
-    if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
-        gm = clean_up_graph_after_modifications(gm)
-        logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")
-
-    return gm
-
-
-def view_replacement() -> Tuple[
-    torch.fx.GraphModule,
-    Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
-]:
-    """Constructs the original and replacement functions for view"""
+    orig_op = torch.ops.aten.view.default
+    replacement_op = torch.ops.aten.reshape.default

     # Original graph
     def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
-        return torch.ops.aten.view.default(input, shape)
+        return orig_op(input, shape)

     # Replacement graph
     def replacement(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
-        return torch.ops.aten.reshape.default(input, shape)
+        return replacement_op(input, shape)

-    return orig, replacement
+    # Store metadata of the orig_op nodes before the rewrite
+    metadata = get_metadata(gm, orig_op)
+
+    if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
+        gm = clean_up_graph_after_modifications(gm)
+        logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")
+
+        # Copy the orig_op's metadata to the replacement op
+        set_metadata(gm, replacement_op, metadata)
+
+    return gm
diff --git a/py/torch_tensorrt/dynamo/partitioning/__init__.py b/py/torch_tensorrt/dynamo/partitioning/__init__.py
index 1a8cc94099..25487da065 100644
--- a/py/torch_tensorrt/dynamo/partitioning/__init__.py
+++ b/py/torch_tensorrt/dynamo/partitioning/__init__.py
@@ -1,3 +1,7 @@
 from ._adjacency_partitioner import partition as fast_partition
 from ._global_partitioner import partition as global_partition
-from .common import get_graph_converter_support, get_submod_inputs, run_shape_analysis
+from .common import (
+    construct_submodule_inputs,
+    get_graph_converter_support,
+    run_shape_analysis,
+)
diff --git a/py/torch_tensorrt/dynamo/partitioning/common.py b/py/torch_tensorrt/dynamo/partitioning/common.py
index 8348738afa..270973c8c3 100644
--- a/py/torch_tensorrt/dynamo/partitioning/common.py
+++ b/py/torch_tensorrt/dynamo/partitioning/common.py
@@ -4,11 +4,99 @@
 import torch
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo._defaults import DEBUG
-from torch_tensorrt.dynamo.utils import get_torch_inputs, input_is_dynamic

 logger = logging.getLogger(__name__)


+def contains_sym_int(tensor: torch.Size) -> bool:
+    """
+    Returns True if the given shape contains one or more symbolic (torch.SymInt) dimensions.
+    """
+    for dim in tensor:
+        if isinstance(dim, torch.SymInt):
+            return True
+    return False
+
+
+def construct_dynamic_input(input_shape: torch.Size, input_dtype: torch.dtype) -> Input:
+    """
+    Constructs a torch_tensorrt.Input based on a symbolic input
+    Args:
+        input_shape: A symbolic shape / regular shape of a tensor (which can have a mix of SymInt nodes and static values)
+    Returns:
+        A dynamic shaped torch_tensorrt.Input which has the properties of the symbolic shaped input.
+    """
+    min_shape = []
+    opt_shape = []
+    max_shape = []
+    for dim in input_shape:
+        if isinstance(dim, torch.SymInt):
+            node = dim.node
+            expr = node.expr
+            shape_env = node.shape_env
+            var_range = shape_env.var_to_range.get(expr, None)
+            var_val = shape_env.var_to_val.get(expr, None)
+            assert var_range is not None and var_val is not None
+            # Torchdynamo 0/1 specialization outlier: a recorded lower bound of 2 really means 1
+            if var_range.lower == 2:
+                min_shape.append(1)
+            else:
+                min_shape.append(int(var_range.lower))
+            opt_shape.append(int(var_val))
+            max_shape.append(int(var_range.upper))
+        else:
+            min_shape.append(dim)
+            opt_shape.append(dim)
+            max_shape.append(dim)
+
+    return Input(
+        min_shape=min_shape, opt_shape=opt_shape, max_shape=max_shape, dtype=input_dtype
+    )
+
+
+def get_input(input_shape: torch.Size, input_dtype: torch.dtype) -> Input:
+    """
+    Based on the type of dimensions in the input_shape, construct a regular or dynamic shaped input
+    """
+    if contains_sym_int(input_shape):
+        return construct_dynamic_input(input_shape, input_dtype)
+    else:
+        return Input(shape=input_shape, dtype=input_dtype)
+
+
+def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]:
+    """
+    Construct torch_tensorrt Inputs based on the module inputs.
+    The module inputs will have metadata which carries the shape and dtype info.
+    Args:
+        module: Input FX GraphModule
+    Returns:
+        Sequence of torch_tensorrt.Input's representing inputs to the given module
+    """
+    torchtrt_inputs = []
+    module_inputs = [node for node in module.graph.nodes if node.op == "placeholder"]
+    for input in module_inputs:
+        if input.meta:
+            if "val" in input.meta:
+                input_meta = input.meta["val"]
+                input_shape = input_meta.size()
+                torchtrt_inputs.append(get_input(input_shape, input_meta.dtype))
+            elif "tensor_meta" in input.meta:
+                input_meta = input.meta["tensor_meta"]
+                input_shape = input_meta.shape
+                torchtrt_inputs.append(get_input(input_shape, input_meta.dtype))
+            else:
+                raise AssertionError(
+                    f"Input {input.name} does not contain val and tensor_meta fields in the metadata. Please ensure you have exported the graph correctly"
+                )
+        else:
+            raise AssertionError(
+                f"Input {input.name} does not contain metadata. Please ensure you have exported the graph correctly"
+            )
+
+    return torchtrt_inputs
+
+
 def run_shape_analysis(
     parent_module: torch.fx.GraphModule, inputs: Sequence[Input]
 ) -> Tuple[Dict[Any, Sequence[Any]], Dict[Any, Sequence[Any]]]:
@@ -46,80 +134,6 @@ def get_submodule_io(

     return submod_inputs_shape_map, submod_outputs_shape_map


-def get_submod_inputs(
-    mod: torch.fx.GraphModule,
-    submod: torch.fx.GraphModule,
-    inputs: Sequence[Input],
-    device: torch.device,
-) -> Optional[Sequence[torch.Tensor]]:
-    """Helper function to get inputs to a Torch submodule
-
-    Args:
-        mod: Parent FX GraphModule
-        submod: Child FX GraphModule
-        inputs: Sample inputs to parent module
-    Returns:
-        Sequence of Tensors representing inputs to child module
-    """
-    acc_inputs: Any = None
-
-    def get_input(self: Any, inputs: Sequence[torch.Tensor]) -> None:
-        nonlocal acc_inputs
-        acc_inputs = inputs
-        return
-
-    # Register a hook to capture submodule input
-    handle = submod.register_forward_pre_hook(get_input)
-    # Iterate over min, opt, max shapes for dynamic inputs
-    inputs_map = {}
-
-    if input_is_dynamic(inputs):
-        for mode in ["min_shape", "opt_shape", "max_shape"]:
-            torch_inputs = get_torch_inputs(inputs, device, mode)
-            mod(*torch_inputs)
-            inputs_map[mode] = acc_inputs
-        handle.remove()
-    else:
-        torch_inputs = get_torch_inputs(inputs, device)
-        mod(*torch_inputs)
-        handle.remove()
-        assert isinstance(acc_inputs, tuple)
-        return [
-            Input(shape=acc_input.shape, dtype=acc_input.dtype)
-            for acc_input in acc_inputs
-        ]
-
-    num_submodule_inputs = (
-        len(inputs_map["min_shape"]) if inputs_map["min_shape"] else 0
-    )
-    submodule_inputs = []
-    for idx in range(num_submodule_inputs):
-        if not isinstance(inputs_map["min_shape"][idx], torch.Tensor):
-            input_val = torch.tensor(inputs_map["opt_shape"][idx], dtype=torch.int32)
-            logger.warning(
-                "Detected a zero-dimensional input. This might be a shape tensor input which is not currently supported. This might result in undefined behavior"
-            )
-            submodule_inputs.append(
-                Input(
-                    shape=[1],
-                    torch_tensor=input_val,
-                    dtype=input_val.dtype,
-                )
-            )
-        else:
-            submodule_inputs.append(
-                Input(
-                    min_shape=inputs_map["min_shape"][idx].shape,
-                    opt_shape=inputs_map["opt_shape"][idx].shape,
-                    max_shape=inputs_map["max_shape"][idx].shape,
-                    torch_tensor=inputs_map["opt_shape"][idx],
-                    dtype=inputs_map["opt_shape"][idx].dtype,
-                )
-            )
-
-    return submodule_inputs
-
-
 def get_graph_converter_support(
     graph_module: torch.fx.GraphModule,
     verbose: bool = DEBUG,
diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py
index 22590fe73d..549636b3c7 100644
--- a/py/torch_tensorrt/dynamo/utils.py
+++ b/py/torch_tensorrt/dynamo/utils.py
@@ -88,7 +88,8 @@ def get_torch_inputs(
             if isinstance(input, Input)
         ]
     return [
-        input.torch_tensor.to(device) for input in inputs if isinstance(input, Input)
+        input.torch_tensor.to(device) if isinstance(input, Input) else input
+        for input in inputs
     ]

diff --git a/tests/py/dynamo/conversion/test_sym_size.py b/tests/py/dynamo/conversion/test_sym_size.py
new file mode 100644
index 0000000000..35bf75a509
--- /dev/null
+++ b/tests/py/dynamo/conversion/test_sym_size.py
@@ -0,0 +1,44 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestSymSizeConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ((3, 2, 4),),
+        ]
+    )
+    def test_sym_size_batch(self, input_shape):
+        class BatchDim(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.sym_size.int(x, 0)
+
+        inputs = [torch.randn(*input_shape)]
+        self.run_test(
+            BatchDim(),
+            inputs,
+        )
+
+    @parameterized.expand(
+        [
+            ((3, 2, 4),),
+        ]
+    )
+    def test_sym_size_non_batch(self, input_shape):
+        class NonBatchDim(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.sym_size.int(x, 1)
+
+        inputs = [torch.randn(*input_shape)]
+        self.run_test(
+            NonBatchDim(),
+            inputs,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/tests/py/dynamo/models/test_dyn_models.py b/tests/py/dynamo/models/test_dyn_models.py
index ceb4a6dd2c..822ee468a9 100644
--- a/tests/py/dynamo/models/test_dyn_models.py
+++ b/tests/py/dynamo/models/test_dyn_models.py
@@ -3,9 +3,8 @@
 import pytest
 import timm
 import torch
-from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
-
 import torch_tensorrt as torchtrt
+from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity

 assertions = unittest.TestCase()

@@ -65,7 +64,7 @@ def forward(self, x):
 @pytest.mark.unit
 def test_base_dynamic_fallback(ir):
     """
-    Tests the model (which is fully convertible) with dynamic shapes
+    Tests the model with dynamic shapes where torch.abs op is forced to run in PyTorch
     """

     class MyModule(torch.nn.Module):
@@ -114,3 +113,53 @@ def forward(self, x):

     with torch.no_grad():
         torch.cuda.empty_cache()
+
+
+@pytest.mark.unit
+def test_view(ir):
+    """
+    Tests a fully convertible model that calls aten.view with a dynamic batch dimension
+    """
+
+    class MyModule(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            input_shape = x.size()
+            y = x.view(input_shape[0], -1)
+            return y
+
+    model = MyModule().eval().cuda()
+    input = torch.randn((6, 3, 4)).to("cuda")
+
+    compile_spec = {
+        "inputs": [
+            torchtrt.Input(
+                min_shape=(1, 3, 4),
+                opt_shape=(4, 3, 4),
+                max_shape=(8, 3, 4),
+                dtype=torch.float32,
+                name="x",
+            )
+        ],
+        "device": torchtrt.Device("cuda:0"),
+        "enabled_precisions": {torch.float},
+        "ir": ir,
+        "pass_through_build_failures": True,
+        "optimization_level": 1,
+        "min_block_size": 1,
+    }
+
+    trt_mod = torchtrt.compile(model, **compile_spec)
+    cos_sim = cosine_similarity(model(input), trt_mod(input))
+    assertions.assertTrue(
+        cos_sim > COSINE_THRESHOLD,
+        msg=f"test_view model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+    )
+
+    # Clean up model env
+    torch._dynamo.reset()
+
+    with torch.no_grad():
+        torch.cuda.empty_cache()
diff --git a/versions.py b/versions.py
index 772737aab7..db418a06d2 100644
--- a/versions.py
+++ b/versions.py
@@ -1,11 +1,10 @@
-import yaml
-import re
 import os
+import re
 import subprocess
-
 from datetime import datetime
 from pathlib import Path
-from typing import List
+
+import yaml

 __version__ = "0.0.0"
 __cuda_version__ = "0.0"
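
The sketch below is not part of the patch; it is only a hedged illustration of where construct_dynamic_input() (py/torch_tensorrt/dynamo/partitioning/common.py above) gets its min/opt/max values. Every dimension marked dynamic at export time is recorded as a torch.SymInt whose ShapeEnv stores a value range (var_to_range) and an example value (var_to_val); those three numbers become the min/opt/max range of the torch_tensorrt.Input. ToyModel and the bounds used here are made up for the example, and it assumes a PyTorch build with the torch.export.Dim API.

import torch


class ToyModel(torch.nn.Module):  # hypothetical model, used only for illustration
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x)


model = ToyModel().eval()
example = torch.randn(4, 3)

# Mark dim 0 as dynamic so the exported graph records a SymInt for it.
batch = torch.export.Dim("batch", min=1, max=8)
exported = torch.export.export(model, (example,), dynamic_shapes=({0: batch},))

placeholder = next(n for n in exported.graph.nodes if n.op == "placeholder")
fake_val = placeholder.meta["val"]  # FakeTensor carrying symbolic sizes
sym_dim = fake_val.size()[0]        # torch.SymInt for the batch dimension

node = sym_dim.node
shape_env = node.shape_env
var_range = shape_env.var_to_range.get(node.expr)  # bounds -> min_shape / max_shape
var_val = shape_env.var_to_val.get(node.expr)      # example value (4 here) -> opt_shape
print(var_range.lower, var_val, var_range.upper)
# A lower bound of 2 here is Torchdynamo's 0/1 specialization; the patch maps it back to 1.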