Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Set return type of compilation to ExportedProgram [release/2.2] #2607

Merged
merged 16 commits into from
Jan 31, 2024

Conversation

peri044
Copy link
Collaborator

@peri044 peri044 commented Jan 18, 2024

Description

Adds changes from following PRs

  1. chore: Set default return type to ExportedProgram #2575
  2. Clean up AWS credentials #2592
  3. Grant write permission to token #2591

Fixes # (issue)

Type of change

Please delete options that are not relevant and/or add your own.

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR in so that relevant reviewers are notified

chore: Add output_format flag

chore: updates

chore: additional fixes

chore: add break
@github-actions github-actions bot added documentation Improvements or additions to documentation component: tests Issues re: Tests component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Jan 19, 2024
Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/examples/int8/training/vgg16/vgg16.py	2024-01-29 23:29:29.597558+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/int8/training/vgg16/vgg16.py	2024-01-29 23:31:24.953996+00:00
@@ -1,10 +1,11 @@
"""
# Reference
- [Very Deep Convolutional Networks for Large-Scale Image Recognition](
    https://arxiv.org/abs/1409.1556) (ICLR 2015)
"""
+
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import reduce

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Device.py	2024-01-29 23:29:29.605558+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Device.py	2024-01-29 23:31:25.049031+00:00
@@ -30,16 +30,18 @@
        gpu_id (int): Device ID for target GPU
        dla_core (int): Core ID for target DLA core
        allow_gpu_fallback (bool): Whether falling back to GPU if DLA cannot support an op should be allowed
    """

-    device_type: Optional[
-        trt.DeviceType
-    ] = None  #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
+    device_type: Optional[trt.DeviceType] = (
+        None  #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
+    )
    gpu_id: int = -1  #: Device ID for target GPU
    dla_core: int = -1  #: Core ID for target DLA core
-    allow_gpu_fallback: bool = False  #: Whether falling back to GPU if DLA cannot support an op should be allowed
+    allow_gpu_fallback: bool = (
+        False  #: Whether falling back to GPU if DLA cannot support an op should be allowed
+    )

    def __init__(self, *args: Any, **kwargs: Any):
        """__init__ Method for torch_tensorrt.Device

        Device accepts one of a few construction patterns
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py	2024-01-29 23:29:29.605558+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py	2024-01-29 23:31:25.238752+00:00
@@ -202,13 +202,13 @@
        "precision": precision,
        "debug": debug,
        "device": device,
        "workspace_size": workspace_size,
        "min_block_size": min_block_size,
-        "torch_executed_ops": torch_executed_ops
-        if torch_executed_ops is not None
-        else set(),
+        "torch_executed_ops": (
+            torch_executed_ops if torch_executed_ops is not None else set()
+        ),
        "pass_through_build_failures": pass_through_build_failures,
        "max_aux_streams": max_aux_streams,
        "version_compatible": version_compatible,
        "optimization_level": optimization_level,
        "use_python_runtime": use_python_runtime,
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Input.py	2024-01-29 23:29:29.605558+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Input.py	2024-01-29 23:31:25.267665+00:00
@@ -26,16 +26,16 @@

    class _ShapeMode(Enum):
        STATIC = 0
        DYNAMIC = 1

-    shape_mode: Optional[
-        _ShapeMode
-    ] = None  #: Is input statically or dynamically shaped
-    shape: Optional[
-        Tuple[int, ...] | Dict[str, Tuple[int, ...]]
-    ] = None  #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
+    shape_mode: Optional[_ShapeMode] = (
+        None  #: Is input statically or dynamically shaped
+    )
+    shape: Optional[Tuple[int, ...] | Dict[str, Tuple[int, ...]]] = (
+        None  #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
+    )
    dtype: _enums.dtype = (
        _enums.dtype.unknown
    )  #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
    _explicit_set_dtype: bool = False
    format: _enums.TensorFormat = (
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py	2024-01-29 23:29:29.605558+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py	2024-01-29 23:31:25.483707+00:00
@@ -26,13 +26,13 @@

from packaging import version

_LOGGER: logging.Logger = logging.getLogger(__name__)

-TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[
-    Callable[[torch.fx.GraphModule], None]
-] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = (
+    Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+)


class UnsupportedOperatorException(RuntimeError):
    pass

@@ -90,13 +90,13 @@
        self.input_specs_iter = 0
        self._cur_node_name: Optional[str] = None
        self._cur_node: Optional[torch.fx.Node] = None
        self._input_names: List[str] = []
        self._output_names: List[str] = []
-        self._itensor_to_tensor_meta: Dict[
-            trt.tensorrt.ITensor, TensorMetadata
-        ] = dict()
+        self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = (
+            dict()
+        )
        self.compilation_settings = compilation_settings

        # Data types for TRT Module output Tensors
        self.output_dtypes = output_dtypes

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py	2024-01-29 23:29:29.609558+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py	2024-01-29 23:31:25.565457+00:00
@@ -322,17 +322,15 @@
    else:
        raise AssertionError(f"Cannot convert {input_val} to TRT constant")


@overload
-def get_positive_dim(dim: int, dim_size: int) -> int:
-    ...
+def get_positive_dim(dim: int, dim_size: int) -> int: ...


@overload
-def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]:
-    ...
+def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]: ...


def get_positive_dim(
    dim: Union[int, Sequence[int]], dim_size: int
) -> Union[int, Tuple[int, ...]]:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py	2024-01-29 23:29:29.609558+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py	2024-01-29 23:31:25.946781+00:00
@@ -20,16 +20,14 @@
        logger.debug(f"Graph after lowering linear:\n{gm.graph}")

    return gm


-def linear_replacement() -> (
-    Tuple[
-        torch.fx.GraphModule,
-        Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
-    ]
-):
+def linear_replacement() -> Tuple[
+    torch.fx.GraphModule,
+    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
+]:
    """Constructs the original and replacement functions for linear"""

    # Original graph
    def orig(
        input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py	2024-01-29 23:29:29.609558+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py	2024-01-29 23:31:25.953508+00:00
@@ -5,13 +5,13 @@
from torch._decomp import get_decompositions as get_torch_decompositions
from torch._ops import OpOverload, OpOverloadPacket

aten = torch.ops.aten

-_core_aten_decompositions: Dict[
-    OpOverload, Callable[[Any], Any]
-] = core_aten_decompositions()
+_core_aten_decompositions: Dict[OpOverload, Callable[[Any], Any]] = (
+    core_aten_decompositions()
+)
torch_enabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
    aten._adaptive_avg_pool2d_backward,
    aten.addcdiv,
    aten.addcdiv_,
    aten.addcmul,
@@ -178,13 +178,13 @@
torch_disabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
    aten._softmax.default,
}


-ENABLED_TORCH_DECOMPOSITIONS: Dict[
-    OpOverload, Callable[[Any], Any]
-] = get_torch_decompositions(torch_enabled_decompositions)
+ENABLED_TORCH_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = (
+    get_torch_decompositions(torch_enabled_decompositions)
+)
TORCH_TRT_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = {}


def check_decomp_set_invariants() -> None:
    """Validates no overlap between enabled and disabled decomposition sets"""
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_efficient_attention.py	2024-01-29 23:29:29.609558+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_efficient_attention.py	2024-01-29 23:31:25.955652+00:00
@@ -25,16 +25,14 @@
        )

    return gm


-def efficient_attention_replacement() -> (
-    Tuple[
-        torch.fx.GraphModule,
-        Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
-    ]
-):
+def efficient_attention_replacement() -> Tuple[
+    torch.fx.GraphModule,
+    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
+]:
    """Constructs the original and replacement functions for efficient attention"""

    # Original graph
    def orig(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default(
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py	2024-01-29 23:29:29.609558+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py	2024-01-29 23:31:25.986916+00:00
@@ -20,16 +20,14 @@
        logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")

    return gm


-def view_replacement() -> (
-    Tuple[
-        torch.fx.GraphModule,
-        Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
-    ]
-):
+def view_replacement() -> Tuple[
+    torch.fx.GraphModule,
+    Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
+]:
    """Constructs the original and replacement functions for view"""

    # Original graph
    def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
        return torch.ops.aten.view.default(input, shape)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py	2024-01-29 23:29:29.609558+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py	2024-01-29 23:31:26.228007+00:00
@@ -99,25 +99,29 @@
                self.engine.get_binding_dtype(idx), Frameworks.TORCH
            )
            for idx in self.output_binding_indices_in_order
        ]
        self.output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
            for idx in self.output_binding_indices_in_order
        ]
        self.hidden_output_dtypes = [
            unified_dtype_converter(
                self.engine.get_binding_dtype(idx), Frameworks.TORCH
            )
            for idx in self.hidden_output_binding_indices_in_order
        ]
        self.hidden_output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
            for idx in self.hidden_output_binding_indices_in_order
        ]

    def _check_initialized(self) -> None:
        if not self.initialized:
@@ -165,13 +169,15 @@
        self.__dict__.update(state)
        if self.engine:
            self.context = self.engine.create_execution_context()

    def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
-        with torch.autograd.profiler.record_function(
-            "PythonTorchTensorRTModule:Forward"
-        ) if self.profiling_enabled else nullcontext():
+        with (
+            torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward")
+            if self.profiling_enabled
+            else nullcontext()
+        ):
            self._check_initialized()

            # If in safe mode, check at each iteration for for whether a switch is required
            if (
                torch_tensorrt.runtime.multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE
@@ -198,13 +204,17 @@
                    torch.cuda.set_device(device_id)

                    inputs = tuple([tensor.to(device) for tensor in inputs])
                    logger.warning(f"Moved all input Tensors to cuda:{device_id}")

-            with torch.autograd.profiler.record_function(
-                "PythonTorchTensorRTModule:ProcessInputs"
-            ) if self.profiling_enabled else nullcontext():
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:ProcessInputs"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                assert len(inputs) == len(
                    self.input_names
                ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."

                contiguous_inputs: List[torch.Tensor] = [i.contiguous() for i in inputs]
@@ -237,13 +247,17 @@

                    self.context.set_binding_shape(
                        idx, tuple(contiguous_inputs[i].shape)
                    )

-            with torch.autograd.profiler.record_function(
-                "PythonTorchTensorRTModule:ProcessOutputs"
-            ) if self.profiling_enabled else nullcontext():
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:ProcessOutputs"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                # create output tensors
                outputs: List[torch.Tensor] = []

                for i, idx in enumerate(self.output_binding_indices_in_order):
                    shape = tuple(self.context.get_binding_shape(idx))
@@ -264,13 +278,17 @@
                        dtype=self.hidden_output_dtypes[i],
                        device=torch.cuda.current_device(),
                    )
                    bindings[idx] = output.data_ptr()

-            with torch.autograd.profiler.record_function(
-                "PythonTorchTensorRTModule:TensorRTRuntime"
-            ) if self.profiling_enabled else nullcontext():
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:TensorRTRuntime"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                self.context.execute_async_v2(
                    bindings, torch.cuda.current_stream().cuda_stream
                )

            if len(outputs) == 1:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/converters/aten_ops_converters.py	2024-01-29 23:29:29.613558+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/converters/aten_ops_converters.py	2024-01-29 23:31:26.650682+00:00
@@ -315,25 +315,21 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    kwargs_new = {
        "input": args[0],
        "kernel_size": args[1],
-        "stride": args[2]
-        if len(args) > 2
-        else (None, None)
-        if len(args[1]) == 2
-        else (None, None, None),
-        "padding": args[3]
-        if len(args) > 3
-        else (0, 0)
-        if len(args[1]) == 2
-        else (0, 0, 0),
-        "dilation": args[4]
-        if len(args) > 4
-        else (1, 1)
-        if len(args[1]) == 2
-        else (1, 1, 1),
+        "stride": (
+            args[2]
+            if len(args) > 2
+            else (None, None) if len(args[1]) == 2 else (None, None, None)
+        ),
+        "padding": (
+            args[3] if len(args) > 3 else (0, 0) if len(args[1]) == 2 else (0, 0, 0)
+        ),
+        "dilation": (
+            args[4] if len(args) > 4 else (1, 1) if len(args[1]) == 2 else (1, 1, 1)
+        ),
        "ceil_mode": args[5] if len(args) > 5 else False,
    }
    return acc_ops_converters.acc_ops_max_poolnd(
        network, target, None, kwargs_new, name
    )
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/lower.py	2024-01-29 23:29:29.613558+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/lower.py	2024-01-29 23:31:26.690231+00:00
@@ -124,25 +124,29 @@
        interpreter = TRTInterpreter(
            mod,
            input_specs=self.lower_setting.input_specs,
            explicit_batch_dimension=self.lower_setting.explicit_batch_dimension,
            explicit_precision=self.lower_setting.explicit_precision,
-            logger_level=trt.Logger.VERBOSE
-            if self.lower_setting.verbose_log
-            else trt.Logger.WARNING,
+            logger_level=(
+                trt.Logger.VERBOSE
+                if self.lower_setting.verbose_log
+                else trt.Logger.WARNING
+            ),
        )

        interp_result: TRTInterpreterResult = interpreter.run(
            max_batch_size=self.lower_setting.max_batch_size,
            max_workspace_size=self.lower_setting.max_workspace_size,
            lower_precision=self.lower_setting.lower_precision,
            strict_type_constraints=self.lower_setting.strict_type_constraints,
            algorithm_selector=algo_selector,
            timing_cache=cache_data,
-            profiling_verbosity=trt.ProfilingVerbosity.DETAILED
-            if self.lower_setting.verbose_profile
-            else trt.ProfilingVerbosity.LAYER_NAMES_ONLY,
+            profiling_verbosity=(
+                trt.ProfilingVerbosity.DETAILED
+                if self.lower_setting.verbose_profile
+                else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
+            ),
            tactic_sources=self.lower_setting.tactic_sources,
        )

        # Update timing cache file if needed
        timing_cache = interp_result.serialized_cache
@@ -295,14 +299,12 @@
                module.half()
                # A custom conversion function can be passed to the lowerer to
                # handle inputs with custom types. By default, just handle
                # tensors and NoneType.
                if fp16_conversion_fn is None:
-                    conversion_fn = (
-                        lambda x: x.half()
-                        if x is not None and x.dtype == torch.float32
-                        else x
+                    conversion_fn = lambda x: (
+                        x.half() if x is not None and x.dtype == torch.float32 else x
                    )
                else:
                    conversion_fn = fp16_conversion_fn

                inputs = tuple(conversion_fn(x) for x in inputs)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/fx2trt.py	2024-01-29 23:29:29.613558+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/fx2trt.py	2024-01-29 23:31:26.712075+00:00
@@ -19,13 +19,13 @@
from .observer import Observer
from .utils import get_dynamic_dims, LowerPrecision, unified_dtype_converter, Frameworks

_LOGGER: logging.Logger = logging.getLogger(__name__)

-TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[
-    Callable[[torch.fx.GraphModule], None]
-] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = (
+    Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+)


class TRTInterpreterResult(NamedTuple):
    engine: Any
    input_names: Sequence[str]
@@ -73,13 +73,13 @@
        self.input_specs_iter = 0
        self.validate_input_specs()
        self._cur_node_name: Optional[str] = None
        self._input_names: List[str] = []
        self._output_names: List[str] = []
-        self._itensor_to_tensor_meta: Dict[
-            trt.tensorrt.ITensor, TensorMetadata
-        ] = dict()
+        self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = (
+            dict()
+        )

    def validate_input_specs(self):
        for shape, _, _, shape_ranges, has_batch_dim in self.input_specs:
            if not self.network.has_implicit_batch_dimension:
                assert (
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py	2024-01-29 23:29:29.613558+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py	2024-01-29 23:31:27.024505+00:00
@@ -194,13 +194,15 @@
                    lowering_start_time = datetime.datetime.now()

                    self.lower_setting.input_specs = generate_input_specs(
                        submod_inputs,
                        self.lower_setting,
-                        additional_submodule_inputs[submod_name]
-                        if additional_submodule_inputs
-                        else None,
+                        (
+                            additional_submodule_inputs[submod_name]
+                            if additional_submodule_inputs
+                            else None
+                        ),
                    )
                    lowered_module = self._lower_func(
                        submod, submod_inputs, self.lower_setting, submod_name
                    )
                    setattr(split_result.split_module, submod_name, lowered_module)
@@ -234,13 +236,15 @@
                if not submod_name.startswith(split_result.non_acc_submodule_prefix):
                    _LOGGER.info(f"ACC submodule graph: {submod.graph}")
                    lowering_start_time = datetime.datetime.now()

                    self.lower_setting.additional_inputs = (
-                        additional_submodule_inputs[submod_name]
-                        if additional_submodule_inputs
-                        else None,
+                        (
+                            additional_submodule_inputs[submod_name]
+                            if additional_submodule_inputs
+                            else None
+                        ),
                    )

                    lowered_module = self._lower_func(
                        submod, submod_inputs, self.lower_setting, submod_name
                    )
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py	2024-01-29 23:29:29.613558+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py	2024-01-29 23:31:27.267947+00:00
@@ -193,13 +193,11 @@
                kwargs2 = {"equal_nan": True}
                if rtol:
                    kwargs2["rtol"] = rtol
                if atol:
                    kwargs2["atol"] = atol
-                kwargs2[
-                    "msg"
-                ] = (
+                kwargs2["msg"] = (
                    lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}"
                )
                # If tensors are on different devices, make sure to compare
                # their copies that are on the same device.
                if x.get_device() != y.get_device():
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass.py	2024-01-29 23:29:29.613558+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass.py	2024-01-29 23:31:27.286915+00:00
@@ -536,13 +536,13 @@
        reshape_batch_size: Optional[fx.Node] = get_reshape_batch_size_as_node(
            maybe_reshape
        )
        if not reshape_batch_size:
            continue
-        reshape_batch_size_inferred_source: Optional[
-            fx.Node
-        ] = get_reshape_batch_size_inferred_source(reshape_batch_size)
+        reshape_batch_size_inferred_source: Optional[fx.Node] = (
+            get_reshape_batch_size_inferred_source(reshape_batch_size)
+        )
        if not reshape_batch_size_inferred_source:
            continue

        reshape_input: fx.Node = maybe_reshape.kwargs["input"]
        if reshape_input == reshape_batch_size_inferred_source:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py	2024-01-29 23:29:29.617558+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py	2024-01-29 23:31:27.711234+00:00
@@ -21,13 +21,15 @@
        inputs = [torch.randn(1, 10)]
        self.run_test(
            Split(),
            inputs,
            expected_ops={
-                acc_ops.split
-                if isinstance(split_size_or_sections, int)
-                else acc_ops.slice_tensor
+                (
+                    acc_ops.split
+                    if isinstance(split_size_or_sections, int)
+                    else acc_ops.slice_tensor
+                )
            },
            test_explicit_batch_dim=False,
        )

    @parameterized.expand(
@@ -68,13 +70,15 @@
        ]
        self.run_test_with_dynamic_shape(
            Split(),
            input_specs,
            expected_ops={
-                acc_ops.split
-                if isinstance(split_size_or_sections, int)
-                else acc_ops.slice_tensor
+                (
+                    acc_ops.split
+                    if isinstance(split_size_or_sections, int)
+                    else acc_ops.slice_tensor
+                )
            },
        )

    # Testing with (-1, -1, -1) results into following error:
    # AssertionError: Can't chunk on dynamic shape dimension!
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/tools/common_fx2trt.py	2024-01-29 23:29:29.621558+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/tools/common_fx2trt.py	2024-01-29 23:31:28.669277+00:00
@@ -152,13 +152,13 @@
            mod.eval()
            if len(expected_ops):
                self.assert_has_op(mod, expected_ops)

            interpreter_result = interpreter.run(
-                lower_precision=LowerPrecision.FP16
-                if fp16_mode
-                else LowerPrecision.FP32
+                lower_precision=(
+                    LowerPrecision.FP16 if fp16_mode else LowerPrecision.FP32
+                )
            )
            trt_mod = TRTModule(
                interpreter_result.engine,
                interpreter_result.input_names,
                interpreter_result.output_names,
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/trt_module.py	2024-01-29 23:29:29.621558+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/trt_module.py	2024-01-29 23:31:29.083167+00:00
@@ -67,25 +67,29 @@
                self.engine.get_binding_dtype(idx), Frameworks.TORCH
            )
            for idx in self.output_binding_indices_in_order
        ]
        self.output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
            for idx in self.output_binding_indices_in_order
        ]
        self.hidden_output_dtypes: Sequence[torch.dtype] = [
            unified_dtype_converter(
                self.engine.get_binding_dtype(idx), Frameworks.TORCH
            )
            for idx in self.hidden_output_binding_indices_in_order
        ]
        self.hidden_output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
            for idx in self.hidden_output_binding_indices_in_order
        ]

    def _check_initialized(self):
        if not self.initialized:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/ts/_compile_spec.py	2024-01-29 23:29:29.621558+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/ts/_compile_spec.py	2024-01-29 23:31:29.399662+00:00
@@ -404,13 +404,13 @@
        "inputs": inputs if inputs is not None else [],
        # "input_signature": input_signature,
        "device": device,
        "disable_tf32": disable_tf32,  # Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas
        "sparse_weights": sparse_weights,  # Enable sparsity for convolution and fully connected layers.
-        "enabled_precisions": enabled_precisions
-        if enabled_precisions is not None
-        else set(),  # Enabling FP16 kernels
+        "enabled_precisions": (
+            enabled_precisions if enabled_precisions is not None else set()
+        ),  # Enabling FP16 kernels
        "refit": refit,  # enable refit
        "debug": debug,  # enable debuggable engine
        "capability": capability,  # Restrict kernel selection to safe gpu kernels or safe dla kernels
        "num_avg_timing_iters": num_avg_timing_iters,  # Number of averaging timing iterations used to select kernels
        "workspace_size": workspace_size,  # Maximum size of workspace given to TensorRT

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/examples/int8/training/vgg16/vgg16.py	2024-01-30 01:40:37.564865+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/int8/training/vgg16/vgg16.py	2024-01-30 01:42:24.853158+00:00
@@ -1,10 +1,11 @@
"""
# Reference
- [Very Deep Convolutional Networks for Large-Scale Image Recognition](
    https://arxiv.org/abs/1409.1556) (ICLR 2015)
"""
+
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import reduce

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Device.py	2024-01-30 01:40:37.572865+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Device.py	2024-01-30 01:42:24.950085+00:00
@@ -30,16 +30,18 @@
        gpu_id (int): Device ID for target GPU
        dla_core (int): Core ID for target DLA core
        allow_gpu_fallback (bool): Whether falling back to GPU if DLA cannot support an op should be allowed
    """

-    device_type: Optional[
-        trt.DeviceType
-    ] = None  #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
+    device_type: Optional[trt.DeviceType] = (
+        None  #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
+    )
    gpu_id: int = -1  #: Device ID for target GPU
    dla_core: int = -1  #: Core ID for target DLA core
-    allow_gpu_fallback: bool = False  #: Whether falling back to GPU if DLA cannot support an op should be allowed
+    allow_gpu_fallback: bool = (
+        False  #: Whether falling back to GPU if DLA cannot support an op should be allowed
+    )

    def __init__(self, *args: Any, **kwargs: Any):
        """__init__ Method for torch_tensorrt.Device

        Device accepts one of a few construction patterns
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py	2024-01-30 01:40:37.572865+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py	2024-01-30 01:42:25.141211+00:00
@@ -202,13 +202,13 @@
        "precision": precision,
        "debug": debug,
        "device": device,
        "workspace_size": workspace_size,
        "min_block_size": min_block_size,
-        "torch_executed_ops": torch_executed_ops
-        if torch_executed_ops is not None
-        else set(),
+        "torch_executed_ops": (
+            torch_executed_ops if torch_executed_ops is not None else set()
+        ),
        "pass_through_build_failures": pass_through_build_failures,
        "max_aux_streams": max_aux_streams,
        "version_compatible": version_compatible,
        "optimization_level": optimization_level,
        "use_python_runtime": use_python_runtime,
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Input.py	2024-01-30 01:40:37.572865+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Input.py	2024-01-30 01:42:25.155183+00:00
@@ -26,16 +26,16 @@

    class _ShapeMode(Enum):
        STATIC = 0
        DYNAMIC = 1

-    shape_mode: Optional[
-        _ShapeMode
-    ] = None  #: Is input statically or dynamically shaped
-    shape: Optional[
-        Tuple[int, ...] | Dict[str, Tuple[int, ...]]
-    ] = None  #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
+    shape_mode: Optional[_ShapeMode] = (
+        None  #: Is input statically or dynamically shaped
+    )
+    shape: Optional[Tuple[int, ...] | Dict[str, Tuple[int, ...]]] = (
+        None  #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
+    )
    dtype: _enums.dtype = (
        _enums.dtype.unknown
    )  #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
    _explicit_set_dtype: bool = False
    format: _enums.TensorFormat = (
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py	2024-01-30 01:40:37.572865+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py	2024-01-30 01:42:25.376670+00:00
@@ -26,13 +26,13 @@

from packaging import version

_LOGGER: logging.Logger = logging.getLogger(__name__)

-TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[
-    Callable[[torch.fx.GraphModule], None]
-] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = (
+    Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+)


class UnsupportedOperatorException(RuntimeError):
    pass

@@ -90,13 +90,13 @@
        self.input_specs_iter = 0
        self._cur_node_name: Optional[str] = None
        self._cur_node: Optional[torch.fx.Node] = None
        self._input_names: List[str] = []
        self._output_names: List[str] = []
-        self._itensor_to_tensor_meta: Dict[
-            trt.tensorrt.ITensor, TensorMetadata
-        ] = dict()
+        self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = (
+            dict()
+        )
        self.compilation_settings = compilation_settings

        # Data types for TRT Module output Tensors
        self.output_dtypes = output_dtypes

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py	2024-01-30 01:40:37.572865+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py	2024-01-30 01:42:25.465586+00:00
@@ -322,17 +322,15 @@
    else:
        raise AssertionError(f"Cannot convert {input_val} to TRT constant")


@overload
-def get_positive_dim(dim: int, dim_size: int) -> int:
-    ...
+def get_positive_dim(dim: int, dim_size: int) -> int: ...


@overload
-def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]:
-    ...
+def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]: ...


def get_positive_dim(
    dim: Union[int, Sequence[int]], dim_size: int
) -> Union[int, Tuple[int, ...]]:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py	2024-01-30 01:40:37.576865+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py	2024-01-30 01:42:25.803478+00:00
@@ -5,13 +5,13 @@
from torch._decomp import get_decompositions as get_torch_decompositions
from torch._ops import OpOverload, OpOverloadPacket

aten = torch.ops.aten

-_core_aten_decompositions: Dict[
-    OpOverload, Callable[[Any], Any]
-] = core_aten_decompositions()
+_core_aten_decompositions: Dict[OpOverload, Callable[[Any], Any]] = (
+    core_aten_decompositions()
+)
torch_enabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
    aten._adaptive_avg_pool2d_backward,
    aten.addcdiv,
    aten.addcdiv_,
    aten.addcmul,
@@ -178,13 +178,13 @@
torch_disabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
    aten._softmax.default,
}


-ENABLED_TORCH_DECOMPOSITIONS: Dict[
-    OpOverload, Callable[[Any], Any]
-] = get_torch_decompositions(torch_enabled_decompositions)
+ENABLED_TORCH_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = (
+    get_torch_decompositions(torch_enabled_decompositions)
+)
TORCH_TRT_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = {}


def check_decomp_set_invariants() -> None:
    """Validates no overlap between enabled and disabled decomposition sets"""
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_efficient_attention.py	2024-01-30 01:40:37.576865+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_efficient_attention.py	2024-01-30 01:42:25.825177+00:00
@@ -25,16 +25,14 @@
        )

    return gm


-def efficient_attention_replacement() -> (
-    Tuple[
-        torch.fx.GraphModule,
-        Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
-    ]
-):
+def efficient_attention_replacement() -> Tuple[
+    torch.fx.GraphModule,
+    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
+]:
    """Constructs the original and replacement functions for efficient attention"""

    # Original graph
    def orig(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default(
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py	2024-01-30 01:40:37.576865+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py	2024-01-30 01:42:25.832747+00:00
@@ -20,16 +20,14 @@
        logger.debug(f"Graph after lowering linear:\n{gm.graph}")

    return gm


-def linear_replacement() -> (
-    Tuple[
-        torch.fx.GraphModule,
-        Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
-    ]
-):
+def linear_replacement() -> Tuple[
+    torch.fx.GraphModule,
+    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
+]:
    """Constructs the original and replacement functions for linear"""

    # Original graph
    def orig(
        input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py	2024-01-30 01:40:37.576865+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py	2024-01-30 01:42:25.860729+00:00
@@ -20,16 +20,14 @@
        logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")

    return gm


-def view_replacement() -> (
-    Tuple[
-        torch.fx.GraphModule,
-        Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
-    ]
-):
+def view_replacement() -> Tuple[
+    torch.fx.GraphModule,
+    Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
+]:
    """Constructs the original and replacement functions for view"""

    # Original graph
    def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
        return torch.ops.aten.view.default(input, shape)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py	2024-01-30 01:40:37.576865+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py	2024-01-30 01:42:26.085561+00:00
@@ -99,25 +99,29 @@
                self.engine.get_binding_dtype(idx), Frameworks.TORCH
            )
            for idx in self.output_binding_indices_in_order
        ]
        self.output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
            for idx in self.output_binding_indices_in_order
        ]
        self.hidden_output_dtypes = [
            unified_dtype_converter(
                self.engine.get_binding_dtype(idx), Frameworks.TORCH
            )
            for idx in self.hidden_output_binding_indices_in_order
        ]
        self.hidden_output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
            for idx in self.hidden_output_binding_indices_in_order
        ]

    def _check_initialized(self) -> None:
        if not self.initialized:
@@ -165,13 +169,15 @@
        self.__dict__.update(state)
        if self.engine:
            self.context = self.engine.create_execution_context()

    def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
-        with torch.autograd.profiler.record_function(
-            "PythonTorchTensorRTModule:Forward"
-        ) if self.profiling_enabled else nullcontext():
+        with (
+            torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward")
+            if self.profiling_enabled
+            else nullcontext()
+        ):
            self._check_initialized()

            # If in safe mode, check at each iteration for for whether a switch is required
            if (
                torch_tensorrt.runtime.multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE
@@ -198,13 +204,17 @@
                    torch.cuda.set_device(device_id)

                    inputs = tuple([tensor.to(device) for tensor in inputs])
                    logger.warning(f"Moved all input Tensors to cuda:{device_id}")

-            with torch.autograd.profiler.record_function(
-                "PythonTorchTensorRTModule:ProcessInputs"
-            ) if self.profiling_enabled else nullcontext():
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:ProcessInputs"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                assert len(inputs) == len(
                    self.input_names
                ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."

                contiguous_inputs: List[torch.Tensor] = [i.contiguous() for i in inputs]
@@ -237,13 +247,17 @@

                    self.context.set_binding_shape(
                        idx, tuple(contiguous_inputs[i].shape)
                    )

-            with torch.autograd.profiler.record_function(
-                "PythonTorchTensorRTModule:ProcessOutputs"
-            ) if self.profiling_enabled else nullcontext():
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:ProcessOutputs"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                # create output tensors
                outputs: List[torch.Tensor] = []

                for i, idx in enumerate(self.output_binding_indices_in_order):
                    shape = tuple(self.context.get_binding_shape(idx))
@@ -264,13 +278,17 @@
                        dtype=self.hidden_output_dtypes[i],
                        device=torch.cuda.current_device(),
                    )
                    bindings[idx] = output.data_ptr()

-            with torch.autograd.profiler.record_function(
-                "PythonTorchTensorRTModule:TensorRTRuntime"
-            ) if self.profiling_enabled else nullcontext():
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:TensorRTRuntime"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                self.context.execute_async_v2(
                    bindings, torch.cuda.current_stream().cuda_stream
                )

            if len(outputs) == 1:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/converters/aten_ops_converters.py	2024-01-30 01:40:37.580865+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/converters/aten_ops_converters.py	2024-01-30 01:42:26.475816+00:00
@@ -315,25 +315,21 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    kwargs_new = {
        "input": args[0],
        "kernel_size": args[1],
-        "stride": args[2]
-        if len(args) > 2
-        else (None, None)
-        if len(args[1]) == 2
-        else (None, None, None),
-        "padding": args[3]
-        if len(args) > 3
-        else (0, 0)
-        if len(args[1]) == 2
-        else (0, 0, 0),
-        "dilation": args[4]
-        if len(args) > 4
-        else (1, 1)
-        if len(args[1]) == 2
-        else (1, 1, 1),
+        "stride": (
+            args[2]
+            if len(args) > 2
+            else (None, None) if len(args[1]) == 2 else (None, None, None)
+        ),
+        "padding": (
+            args[3] if len(args) > 3 else (0, 0) if len(args[1]) == 2 else (0, 0, 0)
+        ),
+        "dilation": (
+            args[4] if len(args) > 4 else (1, 1) if len(args[1]) == 2 else (1, 1, 1)
+        ),
        "ceil_mode": args[5] if len(args) > 5 else False,
    }
    return acc_ops_converters.acc_ops_max_poolnd(
        network, target, None, kwargs_new, name
    )
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/fx2trt.py	2024-01-30 01:40:37.580865+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/fx2trt.py	2024-01-30 01:42:26.540057+00:00
@@ -19,13 +19,13 @@
from .observer import Observer
from .utils import get_dynamic_dims, LowerPrecision, unified_dtype_converter, Frameworks

_LOGGER: logging.Logger = logging.getLogger(__name__)

-TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[
-    Callable[[torch.fx.GraphModule], None]
-] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = (
+    Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+)


class TRTInterpreterResult(NamedTuple):
    engine: Any
    input_names: Sequence[str]
@@ -73,13 +73,13 @@
        self.input_specs_iter = 0
        self.validate_input_specs()
        self._cur_node_name: Optional[str] = None
        self._input_names: List[str] = []
        self._output_names: List[str] = []
-        self._itensor_to_tensor_meta: Dict[
-            trt.tensorrt.ITensor, TensorMetadata
-        ] = dict()
+        self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = (
+            dict()
+        )

    def validate_input_specs(self):
        for shape, _, _, shape_ranges, has_batch_dim in self.input_specs:
            if not self.network.has_implicit_batch_dimension:
                assert (
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/lower.py	2024-01-30 01:40:37.580865+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/lower.py	2024-01-30 01:42:26.558505+00:00
@@ -124,25 +124,29 @@
        interpreter = TRTInterpreter(
            mod,
            input_specs=self.lower_setting.input_specs,
            explicit_batch_dimension=self.lower_setting.explicit_batch_dimension,
            explicit_precision=self.lower_setting.explicit_precision,
-            logger_level=trt.Logger.VERBOSE
-            if self.lower_setting.verbose_log
-            else trt.Logger.WARNING,
+            logger_level=(
+                trt.Logger.VERBOSE
+                if self.lower_setting.verbose_log
+                else trt.Logger.WARNING
+            ),
        )

        interp_result: TRTInterpreterResult = interpreter.run(
            max_batch_size=self.lower_setting.max_batch_size,
            max_workspace_size=self.lower_setting.max_workspace_size,
            lower_precision=self.lower_setting.lower_precision,
            strict_type_constraints=self.lower_setting.strict_type_constraints,
            algorithm_selector=algo_selector,
            timing_cache=cache_data,
-            profiling_verbosity=trt.ProfilingVerbosity.DETAILED
-            if self.lower_setting.verbose_profile
-            else trt.ProfilingVerbosity.LAYER_NAMES_ONLY,
+            profiling_verbosity=(
+                trt.ProfilingVerbosity.DETAILED
+                if self.lower_setting.verbose_profile
+                else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
+            ),
            tactic_sources=self.lower_setting.tactic_sources,
        )

        # Update timing cache file if needed
        timing_cache = interp_result.serialized_cache
@@ -295,14 +299,12 @@
                module.half()
                # A custom conversion function can be passed to the lowerer to
                # handle inputs with custom types. By default, just handle
                # tensors and NoneType.
                if fp16_conversion_fn is None:
-                    conversion_fn = (
-                        lambda x: x.half()
-                        if x is not None and x.dtype == torch.float32
-                        else x
+                    conversion_fn = lambda x: (
+                        x.half() if x is not None and x.dtype == torch.float32 else x
                    )
                else:
                    conversion_fn = fp16_conversion_fn

                inputs = tuple(conversion_fn(x) for x in inputs)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py	2024-01-30 01:40:37.580865+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py	2024-01-30 01:42:26.869388+00:00
@@ -194,13 +194,15 @@
                    lowering_start_time = datetime.datetime.now()

                    self.lower_setting.input_specs = generate_input_specs(
                        submod_inputs,
                        self.lower_setting,
-                        additional_submodule_inputs[submod_name]
-                        if additional_submodule_inputs
-                        else None,
+                        (
+                            additional_submodule_inputs[submod_name]
+                            if additional_submodule_inputs
+                            else None
+                        ),
                    )
                    lowered_module = self._lower_func(
                        submod, submod_inputs, self.lower_setting, submod_name
                    )
                    setattr(split_result.split_module, submod_name, lowered_module)
@@ -234,13 +236,15 @@
                if not submod_name.startswith(split_result.non_acc_submodule_prefix):
                    _LOGGER.info(f"ACC submodule graph: {submod.graph}")
                    lowering_start_time = datetime.datetime.now()

                    self.lower_setting.additional_inputs = (
-                        additional_submodule_inputs[submod_name]
-                        if additional_submodule_inputs
-                        else None,
+                        (
+                            additional_submodule_inputs[submod_name]
+                            if additional_submodule_inputs
+                            else None
+                        ),
                    )

                    lowered_module = self._lower_func(
                        submod, submod_inputs, self.lower_setting, submod_name
                    )
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass.py	2024-01-30 01:40:37.580865+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass.py	2024-01-30 01:42:27.076018+00:00
@@ -536,13 +536,13 @@
        reshape_batch_size: Optional[fx.Node] = get_reshape_batch_size_as_node(
            maybe_reshape
        )
        if not reshape_batch_size:
            continue
-        reshape_batch_size_inferred_source: Optional[
-            fx.Node
-        ] = get_reshape_batch_size_inferred_source(reshape_batch_size)
+        reshape_batch_size_inferred_source: Optional[fx.Node] = (
+            get_reshape_batch_size_inferred_source(reshape_batch_size)
+        )
        if not reshape_batch_size_inferred_source:
            continue

        reshape_input: fx.Node = maybe_reshape.kwargs["input"]
        if reshape_input == reshape_batch_size_inferred_source:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py	2024-01-30 01:40:37.580865+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py	2024-01-30 01:42:27.101330+00:00
@@ -193,13 +193,11 @@
                kwargs2 = {"equal_nan": True}
                if rtol:
                    kwargs2["rtol"] = rtol
                if atol:
                    kwargs2["atol"] = atol
-                kwargs2[
-                    "msg"
-                ] = (
+                kwargs2["msg"] = (
                    lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}"
                )
                # If tensors are on different devices, make sure to compare
                # their copies that are on the same device.
                if x.get_device() != y.get_device():
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py	2024-01-30 01:40:37.584865+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py	2024-01-30 01:42:27.514044+00:00
@@ -21,13 +21,15 @@
        inputs = [torch.randn(1, 10)]
        self.run_test(
            Split(),
            inputs,
            expected_ops={
-                acc_ops.split
-                if isinstance(split_size_or_sections, int)
-                else acc_ops.slice_tensor
+                (
+                    acc_ops.split
+                    if isinstance(split_size_or_sections, int)
+                    else acc_ops.slice_tensor
+                )
            },
            test_explicit_batch_dim=False,
        )

    @parameterized.expand(
@@ -68,13 +70,15 @@
        ]
        self.run_test_with_dynamic_shape(
            Split(),
            input_specs,
            expected_ops={
-                acc_ops.split
-                if isinstance(split_size_or_sections, int)
-                else acc_ops.slice_tensor
+                (
+                    acc_ops.split
+                    if isinstance(split_size_or_sections, int)
+                    else acc_ops.slice_tensor
+                )
            },
        )

    # Testing with (-1, -1, -1) results into following error:
    # AssertionError: Can't chunk on dynamic shape dimension!
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/tools/common_fx2trt.py	2024-01-30 01:40:37.588865+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/tools/common_fx2trt.py	2024-01-30 01:42:28.416155+00:00
@@ -152,13 +152,13 @@
            mod.eval()
            if len(expected_ops):
                self.assert_has_op(mod, expected_ops)

            interpreter_result = interpreter.run(
-                lower_precision=LowerPrecision.FP16
-                if fp16_mode
-                else LowerPrecision.FP32
+                lower_precision=(
+                    LowerPrecision.FP16 if fp16_mode else LowerPrecision.FP32
+                )
            )
            trt_mod = TRTModule(
                interpreter_result.engine,
                interpreter_result.input_names,
                interpreter_result.output_names,
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/trt_module.py	2024-01-30 01:40:37.588865+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/trt_module.py	2024-01-30 01:42:28.723118+00:00
@@ -67,25 +67,29 @@
                self.engine.get_binding_dtype(idx), Frameworks.TORCH
            )
            for idx in self.output_binding_indices_in_order
        ]
        self.output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
            for idx in self.output_binding_indices_in_order
        ]
        self.hidden_output_dtypes: Sequence[torch.dtype] = [
            unified_dtype_converter(
                self.engine.get_binding_dtype(idx), Frameworks.TORCH
            )
            for idx in self.hidden_output_binding_indices_in_order
        ]
        self.hidden_output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
            for idx in self.hidden_output_binding_indices_in_order
        ]

    def _check_initialized(self):
        if not self.initialized:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/ts/_compile_spec.py	2024-01-30 01:40:37.588865+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/ts/_compile_spec.py	2024-01-30 01:42:29.254450+00:00
@@ -404,13 +404,13 @@
        "inputs": inputs if inputs is not None else [],
        # "input_signature": input_signature,
        "device": device,
        "disable_tf32": disable_tf32,  # Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas
        "sparse_weights": sparse_weights,  # Enable sparsity for convolution and fully connected layers.
-        "enabled_precisions": enabled_precisions
-        if enabled_precisions is not None
-        else set(),  # Enabling FP16 kernels
+        "enabled_precisions": (
+            enabled_precisions if enabled_precisions is not None else set()
+        ),  # Enabling FP16 kernels
        "refit": refit,  # enable refit
        "debug": debug,  # enable debuggable engine
        "capability": capability,  # Restrict kernel selection to safe gpu kernels or safe dla kernels
        "num_avg_timing_iters": num_avg_timing_iters,  # Number of averaging timing iterations used to select kernels
        "workspace_size": workspace_size,  # Maximum size of workspace given to TensorRT

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/examples/int8/training/vgg16/vgg16.py	2024-01-30 21:43:45.752044+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/int8/training/vgg16/vgg16.py	2024-01-30 21:45:36.002970+00:00
@@ -1,10 +1,11 @@
"""
# Reference
- [Very Deep Convolutional Networks for Large-Scale Image Recognition](
    https://arxiv.org/abs/1409.1556) (ICLR 2015)
"""
+
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import reduce

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Device.py	2024-01-30 21:43:45.760044+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Device.py	2024-01-30 21:45:36.101949+00:00
@@ -30,16 +30,18 @@
        gpu_id (int): Device ID for target GPU
        dla_core (int): Core ID for target DLA core
        allow_gpu_fallback (bool): Whether falling back to GPU if DLA cannot support an op should be allowed
    """

-    device_type: Optional[
-        trt.DeviceType
-    ] = None  #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
+    device_type: Optional[trt.DeviceType] = (
+        None  #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
+    )
    gpu_id: int = -1  #: Device ID for target GPU
    dla_core: int = -1  #: Core ID for target DLA core
-    allow_gpu_fallback: bool = False  #: Whether falling back to GPU if DLA cannot support an op should be allowed
+    allow_gpu_fallback: bool = (
+        False  #: Whether falling back to GPU if DLA cannot support an op should be allowed
+    )

    def __init__(self, *args: Any, **kwargs: Any):
        """__init__ Method for torch_tensorrt.Device

        Device accepts one of a few construction patterns
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py	2024-01-30 21:43:45.760044+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py	2024-01-30 21:45:36.307640+00:00
@@ -202,13 +202,13 @@
        "precision": precision,
        "debug": debug,
        "device": device,
        "workspace_size": workspace_size,
        "min_block_size": min_block_size,
-        "torch_executed_ops": torch_executed_ops
-        if torch_executed_ops is not None
-        else set(),
+        "torch_executed_ops": (
+            torch_executed_ops if torch_executed_ops is not None else set()
+        ),
        "pass_through_build_failures": pass_through_build_failures,
        "max_aux_streams": max_aux_streams,
        "version_compatible": version_compatible,
        "optimization_level": optimization_level,
        "use_python_runtime": use_python_runtime,
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Input.py	2024-01-30 21:43:45.760044+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Input.py	2024-01-30 21:45:36.307817+00:00
@@ -26,16 +26,16 @@

    class _ShapeMode(Enum):
        STATIC = 0
        DYNAMIC = 1

-    shape_mode: Optional[
-        _ShapeMode
-    ] = None  #: Is input statically or dynamically shaped
-    shape: Optional[
-        Tuple[int, ...] | Dict[str, Tuple[int, ...]]
-    ] = None  #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
+    shape_mode: Optional[_ShapeMode] = (
+        None  #: Is input statically or dynamically shaped
+    )
+    shape: Optional[Tuple[int, ...] | Dict[str, Tuple[int, ...]]] = (
+        None  #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
+    )
    dtype: _enums.dtype = (
        _enums.dtype.unknown
    )  #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
    _explicit_set_dtype: bool = False
    format: _enums.TensorFormat = (
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py	2024-01-30 21:43:45.760044+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py	2024-01-30 21:45:36.528883+00:00
@@ -26,13 +26,13 @@

from packaging import version

_LOGGER: logging.Logger = logging.getLogger(__name__)

-TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[
-    Callable[[torch.fx.GraphModule], None]
-] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = (
+    Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+)


class UnsupportedOperatorException(RuntimeError):
    pass

@@ -90,13 +90,13 @@
        self.input_specs_iter = 0
        self._cur_node_name: Optional[str] = None
        self._cur_node: Optional[torch.fx.Node] = None
        self._input_names: List[str] = []
        self._output_names: List[str] = []
-        self._itensor_to_tensor_meta: Dict[
-            trt.tensorrt.ITensor, TensorMetadata
-        ] = dict()
+        self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = (
+            dict()
+        )
        self.compilation_settings = compilation_settings

        # Data types for TRT Module output Tensors
        self.output_dtypes = output_dtypes

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py	2024-01-30 21:43:45.760044+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py	2024-01-30 21:45:36.623742+00:00
@@ -322,17 +322,15 @@
    else:
        raise AssertionError(f"Cannot convert {input_val} to TRT constant")


@overload
-def get_positive_dim(dim: int, dim_size: int) -> int:
-    ...
+def get_positive_dim(dim: int, dim_size: int) -> int: ...


@overload
-def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]:
-    ...
+def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]: ...


def get_positive_dim(
    dim: Union[int, Sequence[int]], dim_size: int
) -> Union[int, Tuple[int, ...]]:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_efficient_attention.py	2024-01-30 21:43:45.764044+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_efficient_attention.py	2024-01-30 21:45:36.985733+00:00
@@ -25,16 +25,14 @@
        )

    return gm


-def efficient_attention_replacement() -> (
-    Tuple[
-        torch.fx.GraphModule,
-        Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
-    ]
-):
+def efficient_attention_replacement() -> Tuple[
+    torch.fx.GraphModule,
+    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
+]:
    """Constructs the original and replacement functions for efficient attention"""

    # Original graph
    def orig(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default(
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py	2024-01-30 21:43:45.764044+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py	2024-01-30 21:45:36.985989+00:00
@@ -20,16 +20,14 @@
        logger.debug(f"Graph after lowering linear:\n{gm.graph}")

    return gm


-def linear_replacement() -> (
-    Tuple[
-        torch.fx.GraphModule,
-        Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
-    ]
-):
+def linear_replacement() -> Tuple[
+    torch.fx.GraphModule,
+    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
+]:
    """Constructs the original and replacement functions for linear"""

    # Original graph
    def orig(
        input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py	2024-01-30 21:43:45.764044+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py	2024-01-30 21:45:36.992887+00:00
@@ -5,13 +5,13 @@
from torch._decomp import get_decompositions as get_torch_decompositions
from torch._ops import OpOverload, OpOverloadPacket

aten = torch.ops.aten

-_core_aten_decompositions: Dict[
-    OpOverload, Callable[[Any], Any]
-] = core_aten_decompositions()
+_core_aten_decompositions: Dict[OpOverload, Callable[[Any], Any]] = (
+    core_aten_decompositions()
+)
torch_enabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
    aten._adaptive_avg_pool2d_backward,
    aten.addcdiv,
    aten.addcdiv_,
    aten.addcmul,
@@ -178,13 +178,13 @@
torch_disabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
    aten._softmax.default,
}


-ENABLED_TORCH_DECOMPOSITIONS: Dict[
-    OpOverload, Callable[[Any], Any]
-] = get_torch_decompositions(torch_enabled_decompositions)
+ENABLED_TORCH_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = (
+    get_torch_decompositions(torch_enabled_decompositions)
+)
TORCH_TRT_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = {}


def check_decomp_set_invariants() -> None:
    """Validates no overlap between enabled and disabled decomposition sets"""
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py	2024-01-30 21:43:45.764044+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py	2024-01-30 21:45:37.023377+00:00
@@ -20,16 +20,14 @@
        logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")

    return gm


-def view_replacement() -> (
-    Tuple[
-        torch.fx.GraphModule,
-        Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
-    ]
-):
+def view_replacement() -> Tuple[
+    torch.fx.GraphModule,
+    Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
+]:
    """Constructs the original and replacement functions for view"""

    # Original graph
    def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
        return torch.ops.aten.view.default(input, shape)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py	2024-01-30 21:43:45.764044+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py	2024-01-30 21:45:37.248197+00:00
@@ -99,25 +99,29 @@
                self.engine.get_binding_dtype(idx), Frameworks.TORCH
            )
            for idx in self.output_binding_indices_in_order
        ]
        self.output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
            for idx in self.output_binding_indices_in_order
        ]
        self.hidden_output_dtypes = [
            unified_dtype_converter(
                self.engine.get_binding_dtype(idx), Frameworks.TORCH
            )
            for idx in self.hidden_output_binding_indices_in_order
        ]
        self.hidden_output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
            for idx in self.hidden_output_binding_indices_in_order
        ]

    def _check_initialized(self) -> None:
        if not self.initialized:
@@ -165,13 +169,15 @@
        self.__dict__.update(state)
        if self.engine:
            self.context = self.engine.create_execution_context()

    def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
-        with torch.autograd.profiler.record_function(
-            "PythonTorchTensorRTModule:Forward"
-        ) if self.profiling_enabled else nullcontext():
+        with (
+            torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward")
+            if self.profiling_enabled
+            else nullcontext()
+        ):
            self._check_initialized()

            # If in safe mode, check at each iteration for for whether a switch is required
            if (
                torch_tensorrt.runtime.multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE
@@ -198,13 +204,17 @@
                    torch.cuda.set_device(device_id)

                    inputs = tuple([tensor.to(device) for tensor in inputs])
                    logger.warning(f"Moved all input Tensors to cuda:{device_id}")

-            with torch.autograd.profiler.record_function(
-                "PythonTorchTensorRTModule:ProcessInputs"
-            ) if self.profiling_enabled else nullcontext():
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:ProcessInputs"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                assert len(inputs) == len(
                    self.input_names
                ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."

                contiguous_inputs: List[torch.Tensor] = [i.contiguous() for i in inputs]
@@ -237,13 +247,17 @@

                    self.context.set_binding_shape(
                        idx, tuple(contiguous_inputs[i].shape)
                    )

-            with torch.autograd.profiler.record_function(
-                "PythonTorchTensorRTModule:ProcessOutputs"
-            ) if self.profiling_enabled else nullcontext():
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:ProcessOutputs"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                # create output tensors
                outputs: List[torch.Tensor] = []

                for i, idx in enumerate(self.output_binding_indices_in_order):
                    shape = tuple(self.context.get_binding_shape(idx))
@@ -264,13 +278,17 @@
                        dtype=self.hidden_output_dtypes[i],
                        device=torch.cuda.current_device(),
                    )
                    bindings[idx] = output.data_ptr()

-            with torch.autograd.profiler.record_function(
-                "PythonTorchTensorRTModule:TensorRTRuntime"
-            ) if self.profiling_enabled else nullcontext():
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:TensorRTRuntime"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                self.context.execute_async_v2(
                    bindings, torch.cuda.current_stream().cuda_stream
                )

            if len(outputs) == 1:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/converters/aten_ops_converters.py	2024-01-30 21:43:45.764044+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/converters/aten_ops_converters.py	2024-01-30 21:45:37.667856+00:00
@@ -315,25 +315,21 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    kwargs_new = {
        "input": args[0],
        "kernel_size": args[1],
-        "stride": args[2]
-        if len(args) > 2
-        else (None, None)
-        if len(args[1]) == 2
-        else (None, None, None),
-        "padding": args[3]
-        if len(args) > 3
-        else (0, 0)
-        if len(args[1]) == 2
-        else (0, 0, 0),
-        "dilation": args[4]
-        if len(args) > 4
-        else (1, 1)
-        if len(args[1]) == 2
-        else (1, 1, 1),
+        "stride": (
+            args[2]
+            if len(args) > 2
+            else (None, None) if len(args[1]) == 2 else (None, None, None)
+        ),
+        "padding": (
+            args[3] if len(args) > 3 else (0, 0) if len(args[1]) == 2 else (0, 0, 0)
+        ),
+        "dilation": (
+            args[4] if len(args) > 4 else (1, 1) if len(args[1]) == 2 else (1, 1, 1)
+        ),
        "ceil_mode": args[5] if len(args) > 5 else False,
    }
    return acc_ops_converters.acc_ops_max_poolnd(
        network, target, None, kwargs_new, name
    )
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/lower.py	2024-01-30 21:43:45.768044+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/lower.py	2024-01-30 21:45:37.710480+00:00
@@ -124,25 +124,29 @@
        interpreter = TRTInterpreter(
            mod,
            input_specs=self.lower_setting.input_specs,
            explicit_batch_dimension=self.lower_setting.explicit_batch_dimension,
            explicit_precision=self.lower_setting.explicit_precision,
-            logger_level=trt.Logger.VERBOSE
-            if self.lower_setting.verbose_log
-            else trt.Logger.WARNING,
+            logger_level=(
+                trt.Logger.VERBOSE
+                if self.lower_setting.verbose_log
+                else trt.Logger.WARNING
+            ),
        )

        interp_result: TRTInterpreterResult = interpreter.run(
            max_batch_size=self.lower_setting.max_batch_size,
            max_workspace_size=self.lower_setting.max_workspace_size,
            lower_precision=self.lower_setting.lower_precision,
            strict_type_constraints=self.lower_setting.strict_type_constraints,
            algorithm_selector=algo_selector,
            timing_cache=cache_data,
-            profiling_verbosity=trt.ProfilingVerbosity.DETAILED
-            if self.lower_setting.verbose_profile
-            else trt.ProfilingVerbosity.LAYER_NAMES_ONLY,
+            profiling_verbosity=(
+                trt.ProfilingVerbosity.DETAILED
+                if self.lower_setting.verbose_profile
+                else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
+            ),
            tactic_sources=self.lower_setting.tactic_sources,
        )

        # Update timing cache file if needed
        timing_cache = interp_result.serialized_cache
@@ -295,14 +299,12 @@
                module.half()
                # A custom conversion function can be passed to the lowerer to
                # handle inputs with custom types. By default, just handle
                # tensors and NoneType.
                if fp16_conversion_fn is None:
-                    conversion_fn = (
-                        lambda x: x.half()
-                        if x is not None and x.dtype == torch.float32
-                        else x
+                    conversion_fn = lambda x: (
+                        x.half() if x is not None and x.dtype == torch.float32 else x
                    )
                else:
                    conversion_fn = fp16_conversion_fn

                inputs = tuple(conversion_fn(x) for x in inputs)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/fx2trt.py	2024-01-30 21:43:45.768044+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/fx2trt.py	2024-01-30 21:45:37.735459+00:00
@@ -19,13 +19,13 @@
from .observer import Observer
from .utils import get_dynamic_dims, LowerPrecision, unified_dtype_converter, Frameworks

_LOGGER: logging.Logger = logging.getLogger(__name__)

-TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[
-    Callable[[torch.fx.GraphModule], None]
-] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = (
+    Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+)


class TRTInterpreterResult(NamedTuple):
    engine: Any
    input_names: Sequence[str]
@@ -73,13 +73,13 @@
        self.input_specs_iter = 0
        self.validate_input_specs()
        self._cur_node_name: Optional[str] = None
        self._input_names: List[str] = []
        self._output_names: List[str] = []
-        self._itensor_to_tensor_meta: Dict[
-            trt.tensorrt.ITensor, TensorMetadata
-        ] = dict()
+        self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = (
+            dict()
+        )

    def validate_input_specs(self):
        for shape, _, _, shape_ranges, has_batch_dim in self.input_specs:
            if not self.network.has_implicit_batch_dimension:
                assert (
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py	2024-01-30 21:43:45.768044+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py	2024-01-30 21:45:38.063137+00:00
@@ -194,13 +194,15 @@
                    lowering_start_time = datetime.datetime.now()

                    self.lower_setting.input_specs = generate_input_specs(
                        submod_inputs,
                        self.lower_setting,
-                        additional_submodule_inputs[submod_name]
-                        if additional_submodule_inputs
-                        else None,
+                        (
+                            additional_submodule_inputs[submod_name]
+                            if additional_submodule_inputs
+                            else None
+                        ),
                    )
                    lowered_module = self._lower_func(
                        submod, submod_inputs, self.lower_setting, submod_name
                    )
                    setattr(split_result.split_module, submod_name, lowered_module)
@@ -234,13 +236,15 @@
                if not submod_name.startswith(split_result.non_acc_submodule_prefix):
                    _LOGGER.info(f"ACC submodule graph: {submod.graph}")
                    lowering_start_time = datetime.datetime.now()

                    self.lower_setting.additional_inputs = (
-                        additional_submodule_inputs[submod_name]
-                        if additional_submodule_inputs
-                        else None,
+                        (
+                            additional_submodule_inputs[submod_name]
+                            if additional_submodule_inputs
+                            else None
+                        ),
                    )

                    lowered_module = self._lower_func(
                        submod, submod_inputs, self.lower_setting, submod_name
                    )
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py	2024-01-30 21:43:45.768044+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py	2024-01-30 21:45:38.261736+00:00
@@ -193,13 +193,11 @@
                kwargs2 = {"equal_nan": True}
                if rtol:
                    kwargs2["rtol"] = rtol
                if atol:
                    kwargs2["atol"] = atol
-                kwargs2[
-                    "msg"
-                ] = (
+                kwargs2["msg"] = (
                    lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}"
                )
                # If tensors are on different devices, make sure to compare
                # their copies that are on the same device.
                if x.get_device() != y.get_device():
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass.py	2024-01-30 21:43:45.768044+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass.py	2024-01-30 21:45:38.300289+00:00
@@ -536,13 +536,13 @@
        reshape_batch_size: Optional[fx.Node] = get_reshape_batch_size_as_node(
            maybe_reshape
        )
        if not reshape_batch_size:
            continue
-        reshape_batch_size_inferred_source: Optional[
-            fx.Node
-        ] = get_reshape_batch_size_inferred_source(reshape_batch_size)
+        reshape_batch_size_inferred_source: Optional[fx.Node] = (
+            get_reshape_batch_size_inferred_source(reshape_batch_size)
+        )
        if not reshape_batch_size_inferred_source:
            continue

        reshape_input: fx.Node = maybe_reshape.kwargs["input"]
        if reshape_input == reshape_batch_size_inferred_source:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py	2024-01-30 21:43:45.768044+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py	2024-01-30 21:45:38.864859+00:00
@@ -21,13 +21,15 @@
        inputs = [torch.randn(1, 10)]
        self.run_test(
            Split(),
            inputs,
            expected_ops={
-                acc_ops.split
-                if isinstance(split_size_or_sections, int)
-                else acc_ops.slice_tensor
+                (
+                    acc_ops.split
+                    if isinstance(split_size_or_sections, int)
+                    else acc_ops.slice_tensor
+                )
            },
            test_explicit_batch_dim=False,
        )

    @parameterized.expand(
@@ -68,13 +70,15 @@
        ]
        self.run_test_with_dynamic_shape(
            Split(),
            input_specs,
            expected_ops={
-                acc_ops.split
-                if isinstance(split_size_or_sections, int)
-                else acc_ops.slice_tensor
+                (
+                    acc_ops.split
+                    if isinstance(split_size_or_sections, int)
+                    else acc_ops.slice_tensor
+                )
            },
        )

    # Testing with (-1, -1, -1) results into following error:
    # AssertionError: Can't chunk on dynamic shape dimension!
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/tools/common_fx2trt.py	2024-01-30 21:43:45.772044+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/tools/common_fx2trt.py	2024-01-30 21:45:39.631277+00:00
@@ -152,13 +152,13 @@
            mod.eval()
            if len(expected_ops):
                self.assert_has_op(mod, expected_ops)

            interpreter_result = interpreter.run(
-                lower_precision=LowerPrecision.FP16
-                if fp16_mode
-                else LowerPrecision.FP32
+                lower_precision=(
+                    LowerPrecision.FP16 if fp16_mode else LowerPrecision.FP32
+                )
            )
            trt_mod = TRTModule(
                interpreter_result.engine,
                interpreter_result.input_names,
                interpreter_result.output_names,
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/trt_module.py	2024-01-30 21:43:45.776044+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/trt_module.py	2024-01-30 21:45:39.923873+00:00
@@ -67,25 +67,29 @@
                self.engine.get_binding_dtype(idx), Frameworks.TORCH
            )
            for idx in self.output_binding_indices_in_order
        ]
        self.output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
            for idx in self.output_binding_indices_in_order
        ]
        self.hidden_output_dtypes: Sequence[torch.dtype] = [
            unified_dtype_converter(
                self.engine.get_binding_dtype(idx), Frameworks.TORCH
            )
            for idx in self.hidden_output_binding_indices_in_order
        ]
        self.hidden_output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
            for idx in self.hidden_output_binding_indices_in_order
        ]

    def _check_initialized(self):
        if not self.initialized:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/ts/_compile_spec.py	2024-01-30 21:43:45.776044+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/ts/_compile_spec.py	2024-01-30 21:45:40.501228+00:00
@@ -404,13 +404,13 @@
        "inputs": inputs if inputs is not None else [],
        # "input_signature": input_signature,
        "device": device,
        "disable_tf32": disable_tf32,  # Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas
        "sparse_weights": sparse_weights,  # Enable sparsity for convolution and fully connected layers.
-        "enabled_precisions": enabled_precisions
-        if enabled_precisions is not None
-        else set(),  # Enabling FP16 kernels
+        "enabled_precisions": (
+            enabled_precisions if enabled_precisions is not None else set()
+        ),  # Enabling FP16 kernels
        "refit": refit,  # enable refit
        "debug": debug,  # enable debuggable engine
        "capability": capability,  # Restrict kernel selection to safe gpu kernels or safe dla kernels
        "num_avg_timing_iters": num_avg_timing_iters,  # Number of averaging timing iterations used to select kernels
        "workspace_size": workspace_size,  # Maximum size of workspace given to TensorRT

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/examples/int8/training/vgg16/vgg16.py	2024-01-30 23:15:12.840472+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/int8/training/vgg16/vgg16.py	2024-01-30 23:17:05.765941+00:00
@@ -1,10 +1,11 @@
"""
# Reference
- [Very Deep Convolutional Networks for Large-Scale Image Recognition](
    https://arxiv.org/abs/1409.1556) (ICLR 2015)
"""
+
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import reduce

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Device.py	2024-01-30 23:15:12.844472+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Device.py	2024-01-30 23:17:05.861080+00:00
@@ -30,16 +30,18 @@
        gpu_id (int): Device ID for target GPU
        dla_core (int): Core ID for target DLA core
        allow_gpu_fallback (bool): Whether falling back to GPU if DLA cannot support an op should be allowed
    """

-    device_type: Optional[
-        trt.DeviceType
-    ] = None  #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
+    device_type: Optional[trt.DeviceType] = (
+        None  #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
+    )
    gpu_id: int = -1  #: Device ID for target GPU
    dla_core: int = -1  #: Core ID for target DLA core
-    allow_gpu_fallback: bool = False  #: Whether falling back to GPU if DLA cannot support an op should be allowed
+    allow_gpu_fallback: bool = (
+        False  #: Whether falling back to GPU if DLA cannot support an op should be allowed
+    )

    def __init__(self, *args: Any, **kwargs: Any):
        """__init__ Method for torch_tensorrt.Device

        Device accepts one of a few construction patterns
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py	2024-01-30 23:15:12.848472+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py	2024-01-30 23:17:06.068463+00:00
@@ -202,13 +202,13 @@
        "precision": precision,
        "debug": debug,
        "device": device,
        "workspace_size": workspace_size,
        "min_block_size": min_block_size,
-        "torch_executed_ops": torch_executed_ops
-        if torch_executed_ops is not None
-        else set(),
+        "torch_executed_ops": (
+            torch_executed_ops if torch_executed_ops is not None else set()
+        ),
        "pass_through_build_failures": pass_through_build_failures,
        "max_aux_streams": max_aux_streams,
        "version_compatible": version_compatible,
        "optimization_level": optimization_level,
        "use_python_runtime": use_python_runtime,
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Input.py	2024-01-30 23:15:12.844472+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Input.py	2024-01-30 23:17:06.089259+00:00
@@ -26,16 +26,16 @@

    class _ShapeMode(Enum):
        STATIC = 0
        DYNAMIC = 1

-    shape_mode: Optional[
-        _ShapeMode
-    ] = None  #: Is input statically or dynamically shaped
-    shape: Optional[
-        Tuple[int, ...] | Dict[str, Tuple[int, ...]]
-    ] = None  #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
+    shape_mode: Optional[_ShapeMode] = (
+        None  #: Is input statically or dynamically shaped
+    )
+    shape: Optional[Tuple[int, ...] | Dict[str, Tuple[int, ...]]] = (
+        None  #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
+    )
    dtype: _enums.dtype = (
        _enums.dtype.unknown
    )  #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
    _explicit_set_dtype: bool = False
    format: _enums.TensorFormat = (
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py	2024-01-30 23:15:12.848472+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py	2024-01-30 23:17:06.318414+00:00
@@ -26,13 +26,13 @@

from packaging import version

_LOGGER: logging.Logger = logging.getLogger(__name__)

-TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[
-    Callable[[torch.fx.GraphModule], None]
-] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = (
+    Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+)


class UnsupportedOperatorException(RuntimeError):
    pass

@@ -90,13 +90,13 @@
        self.input_specs_iter = 0
        self._cur_node_name: Optional[str] = None
        self._cur_node: Optional[torch.fx.Node] = None
        self._input_names: List[str] = []
        self._output_names: List[str] = []
-        self._itensor_to_tensor_meta: Dict[
-            trt.tensorrt.ITensor, TensorMetadata
-        ] = dict()
+        self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = (
+            dict()
+        )
        self.compilation_settings = compilation_settings

        # Data types for TRT Module output Tensors
        self.output_dtypes = output_dtypes

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py	2024-01-30 23:15:12.848472+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py	2024-01-30 23:17:06.396066+00:00
@@ -322,17 +322,15 @@
    else:
        raise AssertionError(f"Cannot convert {input_val} to TRT constant")


@overload
-def get_positive_dim(dim: int, dim_size: int) -> int:
-    ...
+def get_positive_dim(dim: int, dim_size: int) -> int: ...


@overload
-def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]:
-    ...
+def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]: ...


def get_positive_dim(
    dim: Union[int, Sequence[int]], dim_size: int
) -> Union[int, Tuple[int, ...]]:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_efficient_attention.py	2024-01-30 23:15:12.852472+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_efficient_attention.py	2024-01-30 23:17:06.789169+00:00
@@ -25,16 +25,14 @@
        )

    return gm


-def efficient_attention_replacement() -> (
-    Tuple[
-        torch.fx.GraphModule,
-        Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
-    ]
-):
+def efficient_attention_replacement() -> Tuple[
+    torch.fx.GraphModule,
+    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
+]:
    """Constructs the original and replacement functions for efficient attention"""

    # Original graph
    def orig(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default(
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py	2024-01-30 23:15:12.852472+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py	2024-01-30 23:17:06.791324+00:00
@@ -20,16 +20,14 @@
        logger.debug(f"Graph after lowering linear:\n{gm.graph}")

    return gm


-def linear_replacement() -> (
-    Tuple[
-        torch.fx.GraphModule,
-        Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
-    ]
-):
+def linear_replacement() -> Tuple[
+    torch.fx.GraphModule,
+    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
+]:
    """Constructs the original and replacement functions for linear"""

    # Original graph
    def orig(
        input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py	2024-01-30 23:15:12.852472+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py	2024-01-30 23:17:06.795276+00:00
@@ -5,13 +5,13 @@
from torch._decomp import get_decompositions as get_torch_decompositions
from torch._ops import OpOverload, OpOverloadPacket

aten = torch.ops.aten

-_core_aten_decompositions: Dict[
-    OpOverload, Callable[[Any], Any]
-] = core_aten_decompositions()
+_core_aten_decompositions: Dict[OpOverload, Callable[[Any], Any]] = (
+    core_aten_decompositions()
+)
torch_enabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
    aten._adaptive_avg_pool2d_backward,
    aten.addcdiv,
    aten.addcdiv_,
    aten.addcmul,
@@ -178,13 +178,13 @@
torch_disabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
    aten._softmax.default,
}


-ENABLED_TORCH_DECOMPOSITIONS: Dict[
-    OpOverload, Callable[[Any], Any]
-] = get_torch_decompositions(torch_enabled_decompositions)
+ENABLED_TORCH_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = (
+    get_torch_decompositions(torch_enabled_decompositions)
+)
TORCH_TRT_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = {}


def check_decomp_set_invariants() -> None:
    """Validates no overlap between enabled and disabled decomposition sets"""
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py	2024-01-30 23:15:12.852472+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py	2024-01-30 23:17:06.829382+00:00
@@ -20,16 +20,14 @@
        logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")

    return gm


-def view_replacement() -> (
-    Tuple[
-        torch.fx.GraphModule,
-        Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
-    ]
-):
+def view_replacement() -> Tuple[
+    torch.fx.GraphModule,
+    Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
+]:
    """Constructs the original and replacement functions for view"""

    # Original graph
    def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
        return torch.ops.aten.view.default(input, shape)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py	2024-01-30 23:15:12.852472+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py	2024-01-30 23:17:07.061823+00:00
@@ -99,25 +99,29 @@
                self.engine.get_binding_dtype(idx), Frameworks.TORCH
            )
            for idx in self.output_binding_indices_in_order
        ]
        self.output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
            for idx in self.output_binding_indices_in_order
        ]
        self.hidden_output_dtypes = [
            unified_dtype_converter(
                self.engine.get_binding_dtype(idx), Frameworks.TORCH
            )
            for idx in self.hidden_output_binding_indices_in_order
        ]
        self.hidden_output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
            for idx in self.hidden_output_binding_indices_in_order
        ]

    def _check_initialized(self) -> None:
        if not self.initialized:
@@ -165,13 +169,15 @@
        self.__dict__.update(state)
        if self.engine:
            self.context = self.engine.create_execution_context()

    def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
-        with torch.autograd.profiler.record_function(
-            "PythonTorchTensorRTModule:Forward"
-        ) if self.profiling_enabled else nullcontext():
+        with (
+            torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward")
+            if self.profiling_enabled
+            else nullcontext()
+        ):
            self._check_initialized()

            # If in safe mode, check at each iteration for for whether a switch is required
            if (
                torch_tensorrt.runtime.multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE
@@ -198,13 +204,17 @@
                    torch.cuda.set_device(device_id)

                    inputs = tuple([tensor.to(device) for tensor in inputs])
                    logger.warning(f"Moved all input Tensors to cuda:{device_id}")

-            with torch.autograd.profiler.record_function(
-                "PythonTorchTensorRTModule:ProcessInputs"
-            ) if self.profiling_enabled else nullcontext():
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:ProcessInputs"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                assert len(inputs) == len(
                    self.input_names
                ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."

                contiguous_inputs: List[torch.Tensor] = [i.contiguous() for i in inputs]
@@ -237,13 +247,17 @@

                    self.context.set_binding_shape(
                        idx, tuple(contiguous_inputs[i].shape)
                    )

-            with torch.autograd.profiler.record_function(
-                "PythonTorchTensorRTModule:ProcessOutputs"
-            ) if self.profiling_enabled else nullcontext():
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:ProcessOutputs"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                # create output tensors
                outputs: List[torch.Tensor] = []

                for i, idx in enumerate(self.output_binding_indices_in_order):
                    shape = tuple(self.context.get_binding_shape(idx))
@@ -264,13 +278,17 @@
                        dtype=self.hidden_output_dtypes[i],
                        device=torch.cuda.current_device(),
                    )
                    bindings[idx] = output.data_ptr()

-            with torch.autograd.profiler.record_function(
-                "PythonTorchTensorRTModule:TensorRTRuntime"
-            ) if self.profiling_enabled else nullcontext():
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:TensorRTRuntime"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                self.context.execute_async_v2(
                    bindings, torch.cuda.current_stream().cuda_stream
                )

            if len(outputs) == 1:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/converters/aten_ops_converters.py	2024-01-30 23:15:12.852472+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/converters/aten_ops_converters.py	2024-01-30 23:17:07.469900+00:00
@@ -315,25 +315,21 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    kwargs_new = {
        "input": args[0],
        "kernel_size": args[1],
-        "stride": args[2]
-        if len(args) > 2
-        else (None, None)
-        if len(args[1]) == 2
-        else (None, None, None),
-        "padding": args[3]
-        if len(args) > 3
-        else (0, 0)
-        if len(args[1]) == 2
-        else (0, 0, 0),
-        "dilation": args[4]
-        if len(args) > 4
-        else (1, 1)
-        if len(args[1]) == 2
-        else (1, 1, 1),
+        "stride": (
+            args[2]
+            if len(args) > 2
+            else (None, None) if len(args[1]) == 2 else (None, None, None)
+        ),
+        "padding": (
+            args[3] if len(args) > 3 else (0, 0) if len(args[1]) == 2 else (0, 0, 0)
+        ),
+        "dilation": (
+            args[4] if len(args) > 4 else (1, 1) if len(args[1]) == 2 else (1, 1, 1)
+        ),
        "ceil_mode": args[5] if len(args) > 5 else False,
    }
    return acc_ops_converters.acc_ops_max_poolnd(
        network, target, None, kwargs_new, name
    )
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/lower.py	2024-01-30 23:15:12.852472+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/lower.py	2024-01-30 23:17:07.536811+00:00
@@ -124,25 +124,29 @@
        interpreter = TRTInterpreter(
            mod,
            input_specs=self.lower_setting.input_specs,
            explicit_batch_dimension=self.lower_setting.explicit_batch_dimension,
            explicit_precision=self.lower_setting.explicit_precision,
-            logger_level=trt.Logger.VERBOSE
-            if self.lower_setting.verbose_log
-            else trt.Logger.WARNING,
+            logger_level=(
+                trt.Logger.VERBOSE
+                if self.lower_setting.verbose_log
+                else trt.Logger.WARNING
+            ),
        )

        interp_result: TRTInterpreterResult = interpreter.run(
            max_batch_size=self.lower_setting.max_batch_size,
            max_workspace_size=self.lower_setting.max_workspace_size,
            lower_precision=self.lower_setting.lower_precision,
            strict_type_constraints=self.lower_setting.strict_type_constraints,
            algorithm_selector=algo_selector,
            timing_cache=cache_data,
-            profiling_verbosity=trt.ProfilingVerbosity.DETAILED
-            if self.lower_setting.verbose_profile
-            else trt.ProfilingVerbosity.LAYER_NAMES_ONLY,
+            profiling_verbosity=(
+                trt.ProfilingVerbosity.DETAILED
+                if self.lower_setting.verbose_profile
+                else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
+            ),
            tactic_sources=self.lower_setting.tactic_sources,
        )

        # Update timing cache file if needed
        timing_cache = interp_result.serialized_cache
@@ -295,14 +299,12 @@
                module.half()
                # A custom conversion function can be passed to the lowerer to
                # handle inputs with custom types. By default, just handle
                # tensors and NoneType.
                if fp16_conversion_fn is None:
-                    conversion_fn = (
-                        lambda x: x.half()
-                        if x is not None and x.dtype == torch.float32
-                        else x
+                    conversion_fn = lambda x: (
+                        x.half() if x is not None and x.dtype == torch.float32 else x
                    )
                else:
                    conversion_fn = fp16_conversion_fn

                inputs = tuple(conversion_fn(x) for x in inputs)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/fx2trt.py	2024-01-30 23:15:12.852472+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/fx2trt.py	2024-01-30 23:17:07.580966+00:00
@@ -19,13 +19,13 @@
from .observer import Observer
from .utils import get_dynamic_dims, LowerPrecision, unified_dtype_converter, Frameworks

_LOGGER: logging.Logger = logging.getLogger(__name__)

-TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[
-    Callable[[torch.fx.GraphModule], None]
-] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = (
+    Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+)


class TRTInterpreterResult(NamedTuple):
    engine: Any
    input_names: Sequence[str]
@@ -73,13 +73,13 @@
        self.input_specs_iter = 0
        self.validate_input_specs()
        self._cur_node_name: Optional[str] = None
        self._input_names: List[str] = []
        self._output_names: List[str] = []
-        self._itensor_to_tensor_meta: Dict[
-            trt.tensorrt.ITensor, TensorMetadata
-        ] = dict()
+        self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = (
+            dict()
+        )

    def validate_input_specs(self):
        for shape, _, _, shape_ranges, has_batch_dim in self.input_specs:
            if not self.network.has_implicit_batch_dimension:
                assert (
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py	2024-01-30 23:15:12.852472+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py	2024-01-30 23:17:07.924007+00:00
@@ -194,13 +194,15 @@
                    lowering_start_time = datetime.datetime.now()

                    self.lower_setting.input_specs = generate_input_specs(
                        submod_inputs,
                        self.lower_setting,
-                        additional_submodule_inputs[submod_name]
-                        if additional_submodule_inputs
-                        else None,
+                        (
+                            additional_submodule_inputs[submod_name]
+                            if additional_submodule_inputs
+                            else None
+                        ),
                    )
                    lowered_module = self._lower_func(
                        submod, submod_inputs, self.lower_setting, submod_name
                    )
                    setattr(split_result.split_module, submod_name, lowered_module)
@@ -234,13 +236,15 @@
                if not submod_name.startswith(split_result.non_acc_submodule_prefix):
                    _LOGGER.info(f"ACC submodule graph: {submod.graph}")
                    lowering_start_time = datetime.datetime.now()

                    self.lower_setting.additional_inputs = (
-                        additional_submodule_inputs[submod_name]
-                        if additional_submodule_inputs
-                        else None,
+                        (
+                            additional_submodule_inputs[submod_name]
+                            if additional_submodule_inputs
+                            else None
+                        ),
                    )

                    lowered_module = self._lower_func(
                        submod, submod_inputs, self.lower_setting, submod_name
                    )
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass.py	2024-01-30 23:15:12.852472+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass.py	2024-01-30 23:17:08.103445+00:00
@@ -536,13 +536,13 @@
        reshape_batch_size: Optional[fx.Node] = get_reshape_batch_size_as_node(
            maybe_reshape
        )
        if not reshape_batch_size:
            continue
-        reshape_batch_size_inferred_source: Optional[
-            fx.Node
-        ] = get_reshape_batch_size_inferred_source(reshape_batch_size)
+        reshape_batch_size_inferred_source: Optional[fx.Node] = (
+            get_reshape_batch_size_inferred_source(reshape_batch_size)
+        )
        if not reshape_batch_size_inferred_source:
            continue

        reshape_input: fx.Node = maybe_reshape.kwargs["input"]
        if reshape_input == reshape_batch_size_inferred_source:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py	2024-01-30 23:15:12.856472+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py	2024-01-30 23:17:08.118812+00:00
@@ -193,13 +193,11 @@
                kwargs2 = {"equal_nan": True}
                if rtol:
                    kwargs2["rtol"] = rtol
                if atol:
                    kwargs2["atol"] = atol
-                kwargs2[
-                    "msg"
-                ] = (
+                kwargs2["msg"] = (
                    lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}"
                )
                # If tensors are on different devices, make sure to compare
                # their copies that are on the same device.
                if x.get_device() != y.get_device():
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py	2024-01-30 23:15:12.856472+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py	2024-01-30 23:17:08.578702+00:00
@@ -21,13 +21,15 @@
        inputs = [torch.randn(1, 10)]
        self.run_test(
            Split(),
            inputs,
            expected_ops={
-                acc_ops.split
-                if isinstance(split_size_or_sections, int)
-                else acc_ops.slice_tensor
+                (
+                    acc_ops.split
+                    if isinstance(split_size_or_sections, int)
+                    else acc_ops.slice_tensor
+                )
            },
            test_explicit_batch_dim=False,
        )

    @parameterized.expand(
@@ -68,13 +70,15 @@
        ]
        self.run_test_with_dynamic_shape(
            Split(),
            input_specs,
            expected_ops={
-                acc_ops.split
-                if isinstance(split_size_or_sections, int)
-                else acc_ops.slice_tensor
+                (
+                    acc_ops.split
+                    if isinstance(split_size_or_sections, int)
+                    else acc_ops.slice_tensor
+                )
            },
        )

    # Testing with (-1, -1, -1) results into following error:
    # AssertionError: Can't chunk on dynamic shape dimension!
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/tools/common_fx2trt.py	2024-01-30 23:15:12.860472+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/tools/common_fx2trt.py	2024-01-30 23:17:09.555866+00:00
@@ -152,13 +152,13 @@
            mod.eval()
            if len(expected_ops):
                self.assert_has_op(mod, expected_ops)

            interpreter_result = interpreter.run(
-                lower_precision=LowerPrecision.FP16
-                if fp16_mode
-                else LowerPrecision.FP32
+                lower_precision=(
+                    LowerPrecision.FP16 if fp16_mode else LowerPrecision.FP32
+                )
            )
            trt_mod = TRTModule(
                interpreter_result.engine,
                interpreter_result.input_names,
                interpreter_result.output_names,
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/trt_module.py	2024-01-30 23:15:12.860472+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/trt_module.py	2024-01-30 23:17:09.832349+00:00
@@ -67,25 +67,29 @@
                self.engine.get_binding_dtype(idx), Frameworks.TORCH
            )
            for idx in self.output_binding_indices_in_order
        ]
        self.output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
            for idx in self.output_binding_indices_in_order
        ]
        self.hidden_output_dtypes: Sequence[torch.dtype] = [
            unified_dtype_converter(
                self.engine.get_binding_dtype(idx), Frameworks.TORCH
            )
            for idx in self.hidden_output_binding_indices_in_order
        ]
        self.hidden_output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
            for idx in self.hidden_output_binding_indices_in_order
        ]

    def _check_initialized(self):
        if not self.initialized:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/ts/_compile_spec.py	2024-01-30 23:15:12.860472+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/ts/_compile_spec.py	2024-01-30 23:17:10.362729+00:00
@@ -404,13 +404,13 @@
        "inputs": inputs if inputs is not None else [],
        # "input_signature": input_signature,
        "device": device,
        "disable_tf32": disable_tf32,  # Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas
        "sparse_weights": sparse_weights,  # Enable sparsity for convolution and fully connected layers.
-        "enabled_precisions": enabled_precisions
-        if enabled_precisions is not None
-        else set(),  # Enabling FP16 kernels
+        "enabled_precisions": (
+            enabled_precisions if enabled_precisions is not None else set()
+        ),  # Enabling FP16 kernels
        "refit": refit,  # enable refit
        "debug": debug,  # enable debuggable engine
        "capability": capability,  # Restrict kernel selection to safe gpu kernels or safe dla kernels
        "num_avg_timing_iters": num_avg_timing_iters,  # Number of averaging timing iterations used to select kernels
        "workspace_size": workspace_size,  # Maximum size of workspace given to TensorRT

@github-actions github-actions bot added the component: build system Issues re: Build system label Jan 31, 2024
Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/examples/int8/training/vgg16/vgg16.py	2024-01-31 00:09:31.678685+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/int8/training/vgg16/vgg16.py	2024-01-31 00:11:25.909053+00:00
@@ -1,10 +1,11 @@
"""
# Reference
- [Very Deep Convolutional Networks for Large-Scale Image Recognition](
    https://arxiv.org/abs/1409.1556) (ICLR 2015)
"""
+
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import reduce

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Device.py	2024-01-31 00:09:31.682685+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Device.py	2024-01-31 00:11:26.003114+00:00
@@ -30,16 +30,18 @@
        gpu_id (int): Device ID for target GPU
        dla_core (int): Core ID for target DLA core
        allow_gpu_fallback (bool): Whether falling back to GPU if DLA cannot support an op should be allowed
    """

-    device_type: Optional[
-        trt.DeviceType
-    ] = None  #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
+    device_type: Optional[trt.DeviceType] = (
+        None  #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
+    )
    gpu_id: int = -1  #: Device ID for target GPU
    dla_core: int = -1  #: Core ID for target DLA core
-    allow_gpu_fallback: bool = False  #: Whether falling back to GPU if DLA cannot support an op should be allowed
+    allow_gpu_fallback: bool = (
+        False  #: Whether falling back to GPU if DLA cannot support an op should be allowed
+    )

    def __init__(self, *args: Any, **kwargs: Any):
        """__init__ Method for torch_tensorrt.Device

        Device accepts one of a few construction patterns
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py	2024-01-31 00:09:31.686685+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py	2024-01-31 00:11:26.186165+00:00
@@ -202,13 +202,13 @@
        "precision": precision,
        "debug": debug,
        "device": device,
        "workspace_size": workspace_size,
        "min_block_size": min_block_size,
-        "torch_executed_ops": torch_executed_ops
-        if torch_executed_ops is not None
-        else set(),
+        "torch_executed_ops": (
+            torch_executed_ops if torch_executed_ops is not None else set()
+        ),
        "pass_through_build_failures": pass_through_build_failures,
        "max_aux_streams": max_aux_streams,
        "version_compatible": version_compatible,
        "optimization_level": optimization_level,
        "use_python_runtime": use_python_runtime,
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Input.py	2024-01-31 00:09:31.682685+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Input.py	2024-01-31 00:11:26.214936+00:00
@@ -26,16 +26,16 @@

    class _ShapeMode(Enum):
        STATIC = 0
        DYNAMIC = 1

-    shape_mode: Optional[
-        _ShapeMode
-    ] = None  #: Is input statically or dynamically shaped
-    shape: Optional[
-        Tuple[int, ...] | Dict[str, Tuple[int, ...]]
-    ] = None  #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
+    shape_mode: Optional[_ShapeMode] = (
+        None  #: Is input statically or dynamically shaped
+    )
+    shape: Optional[Tuple[int, ...] | Dict[str, Tuple[int, ...]]] = (
+        None  #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
+    )
    dtype: _enums.dtype = (
        _enums.dtype.unknown
    )  #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
    _explicit_set_dtype: bool = False
    format: _enums.TensorFormat = (
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py	2024-01-31 00:09:31.686685+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py	2024-01-31 00:11:26.438564+00:00
@@ -26,13 +26,13 @@

from packaging import version

_LOGGER: logging.Logger = logging.getLogger(__name__)

-TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[
-    Callable[[torch.fx.GraphModule], None]
-] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = (
+    Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+)


class UnsupportedOperatorException(RuntimeError):
    pass

@@ -90,13 +90,13 @@
        self.input_specs_iter = 0
        self._cur_node_name: Optional[str] = None
        self._cur_node: Optional[torch.fx.Node] = None
        self._input_names: List[str] = []
        self._output_names: List[str] = []
-        self._itensor_to_tensor_meta: Dict[
-            trt.tensorrt.ITensor, TensorMetadata
-        ] = dict()
+        self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = (
+            dict()
+        )
        self.compilation_settings = compilation_settings

        # Data types for TRT Module output Tensors
        self.output_dtypes = output_dtypes

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py	2024-01-31 00:09:31.686685+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py	2024-01-31 00:11:26.498837+00:00
@@ -322,17 +322,15 @@
    else:
        raise AssertionError(f"Cannot convert {input_val} to TRT constant")


@overload
-def get_positive_dim(dim: int, dim_size: int) -> int:
-    ...
+def get_positive_dim(dim: int, dim_size: int) -> int: ...


@overload
-def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]:
-    ...
+def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]: ...


def get_positive_dim(
    dim: Union[int, Sequence[int]], dim_size: int
) -> Union[int, Tuple[int, ...]]:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py	2024-01-31 00:09:31.686685+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py	2024-01-31 00:11:26.853756+00:00
@@ -5,13 +5,13 @@
from torch._decomp import get_decompositions as get_torch_decompositions
from torch._ops import OpOverload, OpOverloadPacket

aten = torch.ops.aten

-_core_aten_decompositions: Dict[
-    OpOverload, Callable[[Any], Any]
-] = core_aten_decompositions()
+_core_aten_decompositions: Dict[OpOverload, Callable[[Any], Any]] = (
+    core_aten_decompositions()
+)
torch_enabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
    aten._adaptive_avg_pool2d_backward,
    aten.addcdiv,
    aten.addcdiv_,
    aten.addcmul,
@@ -178,13 +178,13 @@
torch_disabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
    aten._softmax.default,
}


-ENABLED_TORCH_DECOMPOSITIONS: Dict[
-    OpOverload, Callable[[Any], Any]
-] = get_torch_decompositions(torch_enabled_decompositions)
+ENABLED_TORCH_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = (
+    get_torch_decompositions(torch_enabled_decompositions)
+)
TORCH_TRT_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = {}


def check_decomp_set_invariants() -> None:
    """Validates no overlap between enabled and disabled decomposition sets"""
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_efficient_attention.py	2024-01-31 00:09:31.686685+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_efficient_attention.py	2024-01-31 00:11:26.863680+00:00
@@ -25,16 +25,14 @@
        )

    return gm


-def efficient_attention_replacement() -> (
-    Tuple[
-        torch.fx.GraphModule,
-        Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
-    ]
-):
+def efficient_attention_replacement() -> Tuple[
+    torch.fx.GraphModule,
+    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
+]:
    """Constructs the original and replacement functions for efficient attention"""

    # Original graph
    def orig(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default(
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py	2024-01-31 00:09:31.686685+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py	2024-01-31 00:11:26.867636+00:00
@@ -20,16 +20,14 @@
        logger.debug(f"Graph after lowering linear:\n{gm.graph}")

    return gm


-def linear_replacement() -> (
-    Tuple[
-        torch.fx.GraphModule,
-        Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
-    ]
-):
+def linear_replacement() -> Tuple[
+    torch.fx.GraphModule,
+    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
+]:
    """Constructs the original and replacement functions for linear"""

    # Original graph
    def orig(
        input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py	2024-01-31 00:09:31.690685+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py	2024-01-31 00:11:26.897584+00:00
@@ -20,16 +20,14 @@
        logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")

    return gm


-def view_replacement() -> (
-    Tuple[
-        torch.fx.GraphModule,
-        Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
-    ]
-):
+def view_replacement() -> Tuple[
+    torch.fx.GraphModule,
+    Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
+]:
    """Constructs the original and replacement functions for view"""

    # Original graph
    def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
        return torch.ops.aten.view.default(input, shape)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py	2024-01-31 00:09:31.690685+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py	2024-01-31 00:11:27.131421+00:00
@@ -99,25 +99,29 @@
                self.engine.get_binding_dtype(idx), Frameworks.TORCH
            )
            for idx in self.output_binding_indices_in_order
        ]
        self.output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
            for idx in self.output_binding_indices_in_order
        ]
        self.hidden_output_dtypes = [
            unified_dtype_converter(
                self.engine.get_binding_dtype(idx), Frameworks.TORCH
            )
            for idx in self.hidden_output_binding_indices_in_order
        ]
        self.hidden_output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
            for idx in self.hidden_output_binding_indices_in_order
        ]

    def _check_initialized(self) -> None:
        if not self.initialized:
@@ -165,13 +169,15 @@
        self.__dict__.update(state)
        if self.engine:
            self.context = self.engine.create_execution_context()

    def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
-        with torch.autograd.profiler.record_function(
-            "PythonTorchTensorRTModule:Forward"
-        ) if self.profiling_enabled else nullcontext():
+        with (
+            torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward")
+            if self.profiling_enabled
+            else nullcontext()
+        ):
            self._check_initialized()

            # If in safe mode, check at each iteration for for whether a switch is required
            if (
                torch_tensorrt.runtime.multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE
@@ -198,13 +204,17 @@
                    torch.cuda.set_device(device_id)

                    inputs = tuple([tensor.to(device) for tensor in inputs])
                    logger.warning(f"Moved all input Tensors to cuda:{device_id}")

-            with torch.autograd.profiler.record_function(
-                "PythonTorchTensorRTModule:ProcessInputs"
-            ) if self.profiling_enabled else nullcontext():
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:ProcessInputs"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                assert len(inputs) == len(
                    self.input_names
                ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."

                contiguous_inputs: List[torch.Tensor] = [i.contiguous() for i in inputs]
@@ -237,13 +247,17 @@

                    self.context.set_binding_shape(
                        idx, tuple(contiguous_inputs[i].shape)
                    )

-            with torch.autograd.profiler.record_function(
-                "PythonTorchTensorRTModule:ProcessOutputs"
-            ) if self.profiling_enabled else nullcontext():
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:ProcessOutputs"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                # create output tensors
                outputs: List[torch.Tensor] = []

                for i, idx in enumerate(self.output_binding_indices_in_order):
                    shape = tuple(self.context.get_binding_shape(idx))
@@ -264,13 +278,17 @@
                        dtype=self.hidden_output_dtypes[i],
                        device=torch.cuda.current_device(),
                    )
                    bindings[idx] = output.data_ptr()

-            with torch.autograd.profiler.record_function(
-                "PythonTorchTensorRTModule:TensorRTRuntime"
-            ) if self.profiling_enabled else nullcontext():
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:TensorRTRuntime"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                self.context.execute_async_v2(
                    bindings, torch.cuda.current_stream().cuda_stream
                )

            if len(outputs) == 1:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/converters/aten_ops_converters.py	2024-01-31 00:09:31.690685+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/converters/aten_ops_converters.py	2024-01-31 00:11:27.533909+00:00
@@ -315,25 +315,21 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    kwargs_new = {
        "input": args[0],
        "kernel_size": args[1],
-        "stride": args[2]
-        if len(args) > 2
-        else (None, None)
-        if len(args[1]) == 2
-        else (None, None, None),
-        "padding": args[3]
-        if len(args) > 3
-        else (0, 0)
-        if len(args[1]) == 2
-        else (0, 0, 0),
-        "dilation": args[4]
-        if len(args) > 4
-        else (1, 1)
-        if len(args[1]) == 2
-        else (1, 1, 1),
+        "stride": (
+            args[2]
+            if len(args) > 2
+            else (None, None) if len(args[1]) == 2 else (None, None, None)
+        ),
+        "padding": (
+            args[3] if len(args) > 3 else (0, 0) if len(args[1]) == 2 else (0, 0, 0)
+        ),
+        "dilation": (
+            args[4] if len(args) > 4 else (1, 1) if len(args[1]) == 2 else (1, 1, 1)
+        ),
        "ceil_mode": args[5] if len(args) > 5 else False,
    }
    return acc_ops_converters.acc_ops_max_poolnd(
        network, target, None, kwargs_new, name
    )
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/fx2trt.py	2024-01-31 00:09:31.690685+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/fx2trt.py	2024-01-31 00:11:27.561743+00:00
@@ -19,13 +19,13 @@
from .observer import Observer
from .utils import get_dynamic_dims, LowerPrecision, unified_dtype_converter, Frameworks

_LOGGER: logging.Logger = logging.getLogger(__name__)

-TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[
-    Callable[[torch.fx.GraphModule], None]
-] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = (
+    Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+)


class TRTInterpreterResult(NamedTuple):
    engine: Any
    input_names: Sequence[str]
@@ -73,13 +73,13 @@
        self.input_specs_iter = 0
        self.validate_input_specs()
        self._cur_node_name: Optional[str] = None
        self._input_names: List[str] = []
        self._output_names: List[str] = []
-        self._itensor_to_tensor_meta: Dict[
-            trt.tensorrt.ITensor, TensorMetadata
-        ] = dict()
+        self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = (
+            dict()
+        )

    def validate_input_specs(self):
        for shape, _, _, shape_ranges, has_batch_dim in self.input_specs:
            if not self.network.has_implicit_batch_dimension:
                assert (
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/lower.py	2024-01-31 00:09:31.690685+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/lower.py	2024-01-31 00:11:27.589095+00:00
@@ -124,25 +124,29 @@
        interpreter = TRTInterpreter(
            mod,
            input_specs=self.lower_setting.input_specs,
            explicit_batch_dimension=self.lower_setting.explicit_batch_dimension,
            explicit_precision=self.lower_setting.explicit_precision,
-            logger_level=trt.Logger.VERBOSE
-            if self.lower_setting.verbose_log
-            else trt.Logger.WARNING,
+            logger_level=(
+                trt.Logger.VERBOSE
+                if self.lower_setting.verbose_log
+                else trt.Logger.WARNING
+            ),
        )

        interp_result: TRTInterpreterResult = interpreter.run(
            max_batch_size=self.lower_setting.max_batch_size,
            max_workspace_size=self.lower_setting.max_workspace_size,
            lower_precision=self.lower_setting.lower_precision,
            strict_type_constraints=self.lower_setting.strict_type_constraints,
            algorithm_selector=algo_selector,
            timing_cache=cache_data,
-            profiling_verbosity=trt.ProfilingVerbosity.DETAILED
-            if self.lower_setting.verbose_profile
-            else trt.ProfilingVerbosity.LAYER_NAMES_ONLY,
+            profiling_verbosity=(
+                trt.ProfilingVerbosity.DETAILED
+                if self.lower_setting.verbose_profile
+                else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
+            ),
            tactic_sources=self.lower_setting.tactic_sources,
        )

        # Update timing cache file if needed
        timing_cache = interp_result.serialized_cache
@@ -295,14 +299,12 @@
                module.half()
                # A custom conversion function can be passed to the lowerer to
                # handle inputs with custom types. By default, just handle
                # tensors and NoneType.
                if fp16_conversion_fn is None:
-                    conversion_fn = (
-                        lambda x: x.half()
-                        if x is not None and x.dtype == torch.float32
-                        else x
+                    conversion_fn = lambda x: (
+                        x.half() if x is not None and x.dtype == torch.float32 else x
                    )
                else:
                    conversion_fn = fp16_conversion_fn

                inputs = tuple(conversion_fn(x) for x in inputs)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py	2024-01-31 00:09:31.690685+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py	2024-01-31 00:11:27.773013+00:00
@@ -194,13 +194,15 @@
                    lowering_start_time = datetime.datetime.now()

                    self.lower_setting.input_specs = generate_input_specs(
                        submod_inputs,
                        self.lower_setting,
-                        additional_submodule_inputs[submod_name]
-                        if additional_submodule_inputs
-                        else None,
+                        (
+                            additional_submodule_inputs[submod_name]
+                            if additional_submodule_inputs
+                            else None
+                        ),
                    )
                    lowered_module = self._lower_func(
                        submod, submod_inputs, self.lower_setting, submod_name
                    )
                    setattr(split_result.split_module, submod_name, lowered_module)
@@ -234,13 +236,15 @@
                if not submod_name.startswith(split_result.non_acc_submodule_prefix):
                    _LOGGER.info(f"ACC submodule graph: {submod.graph}")
                    lowering_start_time = datetime.datetime.now()

                    self.lower_setting.additional_inputs = (
-                        additional_submodule_inputs[submod_name]
-                        if additional_submodule_inputs
-                        else None,
+                        (
+                            additional_submodule_inputs[submod_name]
+                            if additional_submodule_inputs
+                            else None
+                        ),
                    )

                    lowered_module = self._lower_func(
                        submod, submod_inputs, self.lower_setting, submod_name
                    )
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py	2024-01-31 00:09:31.690685+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py	2024-01-31 00:11:28.092992+00:00
@@ -193,13 +193,11 @@
                kwargs2 = {"equal_nan": True}
                if rtol:
                    kwargs2["rtol"] = rtol
                if atol:
                    kwargs2["atol"] = atol
-                kwargs2[
-                    "msg"
-                ] = (
+                kwargs2["msg"] = (
                    lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}"
                )
                # If tensors are on different devices, make sure to compare
                # their copies that are on the same device.
                if x.get_device() != y.get_device():
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass.py	2024-01-31 00:09:31.690685+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass.py	2024-01-31 00:11:28.165440+00:00
@@ -536,13 +536,13 @@
        reshape_batch_size: Optional[fx.Node] = get_reshape_batch_size_as_node(
            maybe_reshape
        )
        if not reshape_batch_size:
            continue
-        reshape_batch_size_inferred_source: Optional[
-            fx.Node
-        ] = get_reshape_batch_size_inferred_source(reshape_batch_size)
+        reshape_batch_size_inferred_source: Optional[fx.Node] = (
+            get_reshape_batch_size_inferred_source(reshape_batch_size)
+        )
        if not reshape_batch_size_inferred_source:
            continue

        reshape_input: fx.Node = maybe_reshape.kwargs["input"]
        if reshape_input == reshape_batch_size_inferred_source:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py	2024-01-31 00:09:31.694685+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py	2024-01-31 00:11:28.586259+00:00
@@ -21,13 +21,15 @@
        inputs = [torch.randn(1, 10)]
        self.run_test(
            Split(),
            inputs,
            expected_ops={
-                acc_ops.split
-                if isinstance(split_size_or_sections, int)
-                else acc_ops.slice_tensor
+                (
+                    acc_ops.split
+                    if isinstance(split_size_or_sections, int)
+                    else acc_ops.slice_tensor
+                )
            },
            test_explicit_batch_dim=False,
        )

    @parameterized.expand(
@@ -68,13 +70,15 @@
        ]
        self.run_test_with_dynamic_shape(
            Split(),
            input_specs,
            expected_ops={
-                acc_ops.split
-                if isinstance(split_size_or_sections, int)
-                else acc_ops.slice_tensor
+                (
+                    acc_ops.split
+                    if isinstance(split_size_or_sections, int)
+                    else acc_ops.slice_tensor
+                )
            },
        )

    # Testing with (-1, -1, -1) results into following error:
    # AssertionError: Can't chunk on dynamic shape dimension!
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/tools/common_fx2trt.py	2024-01-31 00:09:31.698685+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/tools/common_fx2trt.py	2024-01-31 00:11:29.294846+00:00
@@ -152,13 +152,13 @@
            mod.eval()
            if len(expected_ops):
                self.assert_has_op(mod, expected_ops)

            interpreter_result = interpreter.run(
-                lower_precision=LowerPrecision.FP16
-                if fp16_mode
-                else LowerPrecision.FP32
+                lower_precision=(
+                    LowerPrecision.FP16 if fp16_mode else LowerPrecision.FP32
+                )
            )
            trt_mod = TRTModule(
                interpreter_result.engine,
                interpreter_result.input_names,
                interpreter_result.output_names,
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/trt_module.py	2024-01-31 00:09:31.698685+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/trt_module.py	2024-01-31 00:11:29.700875+00:00
@@ -67,25 +67,29 @@
                self.engine.get_binding_dtype(idx), Frameworks.TORCH
            )
            for idx in self.output_binding_indices_in_order
        ]
        self.output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
            for idx in self.output_binding_indices_in_order
        ]
        self.hidden_output_dtypes: Sequence[torch.dtype] = [
            unified_dtype_converter(
                self.engine.get_binding_dtype(idx), Frameworks.TORCH
            )
            for idx in self.hidden_output_binding_indices_in_order
        ]
        self.hidden_output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
            for idx in self.hidden_output_binding_indices_in_order
        ]

    def _check_initialized(self):
        if not self.initialized:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/ts/_compile_spec.py	2024-01-31 00:09:31.698685+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/ts/_compile_spec.py	2024-01-31 00:11:30.258126+00:00
@@ -404,13 +404,13 @@
        "inputs": inputs if inputs is not None else [],
        # "input_signature": input_signature,
        "device": device,
        "disable_tf32": disable_tf32,  # Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas
        "sparse_weights": sparse_weights,  # Enable sparsity for convolution and fully connected layers.
-        "enabled_precisions": enabled_precisions
-        if enabled_precisions is not None
-        else set(),  # Enabling FP16 kernels
+        "enabled_precisions": (
+            enabled_precisions if enabled_precisions is not None else set()
+        ),  # Enabling FP16 kernels
        "refit": refit,  # enable refit
        "debug": debug,  # enable debuggable engine
        "capability": capability,  # Restrict kernel selection to safe gpu kernels or safe dla kernels
        "num_avg_timing_iters": num_avg_timing_iters,  # Number of averaging timing iterations used to select kernels
        "workspace_size": workspace_size,  # Maximum size of workspace given to TensorRT

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/examples/int8/training/vgg16/vgg16.py	2024-01-31 00:17:16.933022+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/int8/training/vgg16/vgg16.py	2024-01-31 00:19:09.220854+00:00
@@ -1,10 +1,11 @@
"""
# Reference
- [Very Deep Convolutional Networks for Large-Scale Image Recognition](
    https://arxiv.org/abs/1409.1556) (ICLR 2015)
"""
+
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import reduce

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Device.py	2024-01-31 00:17:16.941022+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Device.py	2024-01-31 00:19:09.318981+00:00
@@ -30,16 +30,18 @@
        gpu_id (int): Device ID for target GPU
        dla_core (int): Core ID for target DLA core
        allow_gpu_fallback (bool): Whether falling back to GPU if DLA cannot support an op should be allowed
    """

-    device_type: Optional[
-        trt.DeviceType
-    ] = None  #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
+    device_type: Optional[trt.DeviceType] = (
+        None  #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
+    )
    gpu_id: int = -1  #: Device ID for target GPU
    dla_core: int = -1  #: Core ID for target DLA core
-    allow_gpu_fallback: bool = False  #: Whether falling back to GPU if DLA cannot support an op should be allowed
+    allow_gpu_fallback: bool = (
+        False  #: Whether falling back to GPU if DLA cannot support an op should be allowed
+    )

    def __init__(self, *args: Any, **kwargs: Any):
        """__init__ Method for torch_tensorrt.Device

        Device accepts one of a few construction patterns
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py	2024-01-31 00:17:16.941022+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py	2024-01-31 00:19:09.517246+00:00
@@ -202,13 +202,13 @@
        "precision": precision,
        "debug": debug,
        "device": device,
        "workspace_size": workspace_size,
        "min_block_size": min_block_size,
-        "torch_executed_ops": torch_executed_ops
-        if torch_executed_ops is not None
-        else set(),
+        "torch_executed_ops": (
+            torch_executed_ops if torch_executed_ops is not None else set()
+        ),
        "pass_through_build_failures": pass_through_build_failures,
        "max_aux_streams": max_aux_streams,
        "version_compatible": version_compatible,
        "optimization_level": optimization_level,
        "use_python_runtime": use_python_runtime,
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Input.py	2024-01-31 00:17:16.941022+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Input.py	2024-01-31 00:19:09.531697+00:00
@@ -26,16 +26,16 @@

    class _ShapeMode(Enum):
        STATIC = 0
        DYNAMIC = 1

-    shape_mode: Optional[
-        _ShapeMode
-    ] = None  #: Is input statically or dynamically shaped
-    shape: Optional[
-        Tuple[int, ...] | Dict[str, Tuple[int, ...]]
-    ] = None  #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
+    shape_mode: Optional[_ShapeMode] = (
+        None  #: Is input statically or dynamically shaped
+    )
+    shape: Optional[Tuple[int, ...] | Dict[str, Tuple[int, ...]]] = (
+        None  #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
+    )
    dtype: _enums.dtype = (
        _enums.dtype.unknown
    )  #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
    _explicit_set_dtype: bool = False
    format: _enums.TensorFormat = (
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py	2024-01-31 00:17:16.941022+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py	2024-01-31 00:19:09.748835+00:00
@@ -26,13 +26,13 @@

from packaging import version

_LOGGER: logging.Logger = logging.getLogger(__name__)

-TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[
-    Callable[[torch.fx.GraphModule], None]
-] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = (
+    Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+)


class UnsupportedOperatorException(RuntimeError):
    pass

@@ -90,13 +90,13 @@
        self.input_specs_iter = 0
        self._cur_node_name: Optional[str] = None
        self._cur_node: Optional[torch.fx.Node] = None
        self._input_names: List[str] = []
        self._output_names: List[str] = []
-        self._itensor_to_tensor_meta: Dict[
-            trt.tensorrt.ITensor, TensorMetadata
-        ] = dict()
+        self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = (
+            dict()
+        )
        self.compilation_settings = compilation_settings

        # Data types for TRT Module output Tensors
        self.output_dtypes = output_dtypes

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py	2024-01-31 00:17:16.941022+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py	2024-01-31 00:19:09.836900+00:00
@@ -322,17 +322,15 @@
    else:
        raise AssertionError(f"Cannot convert {input_val} to TRT constant")


@overload
-def get_positive_dim(dim: int, dim_size: int) -> int:
-    ...
+def get_positive_dim(dim: int, dim_size: int) -> int: ...


@overload
-def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]:
-    ...
+def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]: ...


def get_positive_dim(
    dim: Union[int, Sequence[int]], dim_size: int
) -> Union[int, Tuple[int, ...]]:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py	2024-01-31 00:17:16.945022+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py	2024-01-31 00:19:10.183159+00:00
@@ -5,13 +5,13 @@
from torch._decomp import get_decompositions as get_torch_decompositions
from torch._ops import OpOverload, OpOverloadPacket

aten = torch.ops.aten

-_core_aten_decompositions: Dict[
-    OpOverload, Callable[[Any], Any]
-] = core_aten_decompositions()
+_core_aten_decompositions: Dict[OpOverload, Callable[[Any], Any]] = (
+    core_aten_decompositions()
+)
torch_enabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
    aten._adaptive_avg_pool2d_backward,
    aten.addcdiv,
    aten.addcdiv_,
    aten.addcmul,
@@ -178,13 +178,13 @@
torch_disabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
    aten._softmax.default,
}


-ENABLED_TORCH_DECOMPOSITIONS: Dict[
-    OpOverload, Callable[[Any], Any]
-] = get_torch_decompositions(torch_enabled_decompositions)
+ENABLED_TORCH_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = (
+    get_torch_decompositions(torch_enabled_decompositions)
+)
TORCH_TRT_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = {}


def check_decomp_set_invariants() -> None:
    """Validates no overlap between enabled and disabled decomposition sets"""
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_efficient_attention.py	2024-01-31 00:17:16.945022+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_efficient_attention.py	2024-01-31 00:19:10.198831+00:00
@@ -25,16 +25,14 @@
        )

    return gm


-def efficient_attention_replacement() -> (
-    Tuple[
-        torch.fx.GraphModule,
-        Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
-    ]
-):
+def efficient_attention_replacement() -> Tuple[
+    torch.fx.GraphModule,
+    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
+]:
    """Constructs the original and replacement functions for efficient attention"""

    # Original graph
    def orig(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default(
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py	2024-01-31 00:17:16.945022+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py	2024-01-31 00:19:10.205789+00:00
@@ -20,16 +20,14 @@
        logger.debug(f"Graph after lowering linear:\n{gm.graph}")

    return gm


-def linear_replacement() -> (
-    Tuple[
-        torch.fx.GraphModule,
-        Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
-    ]
-):
+def linear_replacement() -> Tuple[
+    torch.fx.GraphModule,
+    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
+]:
    """Constructs the original and replacement functions for linear"""

    # Original graph
    def orig(
        input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py	2024-01-31 00:17:16.945022+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py	2024-01-31 00:19:10.228438+00:00
@@ -20,16 +20,14 @@
        logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")

    return gm


-def view_replacement() -> (
-    Tuple[
-        torch.fx.GraphModule,
-        Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
-    ]
-):
+def view_replacement() -> Tuple[
+    torch.fx.GraphModule,
+    Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
+]:
    """Constructs the original and replacement functions for view"""

    # Original graph
    def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
        return torch.ops.aten.view.default(input, shape)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py	2024-01-31 00:17:16.945022+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py	2024-01-31 00:19:10.452215+00:00
@@ -99,25 +99,29 @@
                self.engine.get_binding_dtype(idx), Frameworks.TORCH
            )
            for idx in self.output_binding_indices_in_order
        ]
        self.output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
            for idx in self.output_binding_indices_in_order
        ]
        self.hidden_output_dtypes = [
            unified_dtype_converter(
                self.engine.get_binding_dtype(idx), Frameworks.TORCH
            )
            for idx in self.hidden_output_binding_indices_in_order
        ]
        self.hidden_output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
            for idx in self.hidden_output_binding_indices_in_order
        ]

    def _check_initialized(self) -> None:
        if not self.initialized:
@@ -165,13 +169,15 @@
        self.__dict__.update(state)
        if self.engine:
            self.context = self.engine.create_execution_context()

    def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
-        with torch.autograd.profiler.record_function(
-            "PythonTorchTensorRTModule:Forward"
-        ) if self.profiling_enabled else nullcontext():
+        with (
+            torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward")
+            if self.profiling_enabled
+            else nullcontext()
+        ):
            self._check_initialized()

            # If in safe mode, check at each iteration for for whether a switch is required
            if (
                torch_tensorrt.runtime.multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE
@@ -198,13 +204,17 @@
                    torch.cuda.set_device(device_id)

                    inputs = tuple([tensor.to(device) for tensor in inputs])
                    logger.warning(f"Moved all input Tensors to cuda:{device_id}")

-            with torch.autograd.profiler.record_function(
-                "PythonTorchTensorRTModule:ProcessInputs"
-            ) if self.profiling_enabled else nullcontext():
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:ProcessInputs"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                assert len(inputs) == len(
                    self.input_names
                ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."

                contiguous_inputs: List[torch.Tensor] = [i.contiguous() for i in inputs]
@@ -237,13 +247,17 @@

                    self.context.set_binding_shape(
                        idx, tuple(contiguous_inputs[i].shape)
                    )

-            with torch.autograd.profiler.record_function(
-                "PythonTorchTensorRTModule:ProcessOutputs"
-            ) if self.profiling_enabled else nullcontext():
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:ProcessOutputs"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                # create output tensors
                outputs: List[torch.Tensor] = []

                for i, idx in enumerate(self.output_binding_indices_in_order):
                    shape = tuple(self.context.get_binding_shape(idx))
@@ -264,13 +278,17 @@
                        dtype=self.hidden_output_dtypes[i],
                        device=torch.cuda.current_device(),
                    )
                    bindings[idx] = output.data_ptr()

-            with torch.autograd.profiler.record_function(
-                "PythonTorchTensorRTModule:TensorRTRuntime"
-            ) if self.profiling_enabled else nullcontext():
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:TensorRTRuntime"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                self.context.execute_async_v2(
                    bindings, torch.cuda.current_stream().cuda_stream
                )

            if len(outputs) == 1:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/converters/aten_ops_converters.py	2024-01-31 00:17:16.949022+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/converters/aten_ops_converters.py	2024-01-31 00:19:10.860397+00:00
@@ -315,25 +315,21 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    kwargs_new = {
        "input": args[0],
        "kernel_size": args[1],
-        "stride": args[2]
-        if len(args) > 2
-        else (None, None)
-        if len(args[1]) == 2
-        else (None, None, None),
-        "padding": args[3]
-        if len(args) > 3
-        else (0, 0)
-        if len(args[1]) == 2
-        else (0, 0, 0),
-        "dilation": args[4]
-        if len(args) > 4
-        else (1, 1)
-        if len(args[1]) == 2
-        else (1, 1, 1),
+        "stride": (
+            args[2]
+            if len(args) > 2
+            else (None, None) if len(args[1]) == 2 else (None, None, None)
+        ),
+        "padding": (
+            args[3] if len(args) > 3 else (0, 0) if len(args[1]) == 2 else (0, 0, 0)
+        ),
+        "dilation": (
+            args[4] if len(args) > 4 else (1, 1) if len(args[1]) == 2 else (1, 1, 1)
+        ),
        "ceil_mode": args[5] if len(args) > 5 else False,
    }
    return acc_ops_converters.acc_ops_max_poolnd(
        network, target, None, kwargs_new, name
    )
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/lower.py	2024-01-31 00:17:16.949022+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/lower.py	2024-01-31 00:19:10.881686+00:00
@@ -124,25 +124,29 @@
        interpreter = TRTInterpreter(
            mod,
            input_specs=self.lower_setting.input_specs,
            explicit_batch_dimension=self.lower_setting.explicit_batch_dimension,
            explicit_precision=self.lower_setting.explicit_precision,
-            logger_level=trt.Logger.VERBOSE
-            if self.lower_setting.verbose_log
-            else trt.Logger.WARNING,
+            logger_level=(
+                trt.Logger.VERBOSE
+                if self.lower_setting.verbose_log
+                else trt.Logger.WARNING
+            ),
        )

        interp_result: TRTInterpreterResult = interpreter.run(
            max_batch_size=self.lower_setting.max_batch_size,
            max_workspace_size=self.lower_setting.max_workspace_size,
            lower_precision=self.lower_setting.lower_precision,
            strict_type_constraints=self.lower_setting.strict_type_constraints,
            algorithm_selector=algo_selector,
            timing_cache=cache_data,
-            profiling_verbosity=trt.ProfilingVerbosity.DETAILED
-            if self.lower_setting.verbose_profile
-            else trt.ProfilingVerbosity.LAYER_NAMES_ONLY,
+            profiling_verbosity=(
+                trt.ProfilingVerbosity.DETAILED
+                if self.lower_setting.verbose_profile
+                else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
+            ),
            tactic_sources=self.lower_setting.tactic_sources,
        )

        # Update timing cache file if needed
        timing_cache = interp_result.serialized_cache
@@ -295,14 +299,12 @@
                module.half()
                # A custom conversion function can be passed to the lowerer to
                # handle inputs with custom types. By default, just handle
                # tensors and NoneType.
                if fp16_conversion_fn is None:
-                    conversion_fn = (
-                        lambda x: x.half()
-                        if x is not None and x.dtype == torch.float32
-                        else x
+                    conversion_fn = lambda x: (
+                        x.half() if x is not None and x.dtype == torch.float32 else x
                    )
                else:
                    conversion_fn = fp16_conversion_fn

                inputs = tuple(conversion_fn(x) for x in inputs)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/fx2trt.py	2024-01-31 00:17:16.949022+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/fx2trt.py	2024-01-31 00:19:10.908246+00:00
@@ -19,13 +19,13 @@
from .observer import Observer
from .utils import get_dynamic_dims, LowerPrecision, unified_dtype_converter, Frameworks

_LOGGER: logging.Logger = logging.getLogger(__name__)

-TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[
-    Callable[[torch.fx.GraphModule], None]
-] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = (
+    Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+)


class TRTInterpreterResult(NamedTuple):
    engine: Any
    input_names: Sequence[str]
@@ -73,13 +73,13 @@
        self.input_specs_iter = 0
        self.validate_input_specs()
        self._cur_node_name: Optional[str] = None
        self._input_names: List[str] = []
        self._output_names: List[str] = []
-        self._itensor_to_tensor_meta: Dict[
-            trt.tensorrt.ITensor, TensorMetadata
-        ] = dict()
+        self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = (
+            dict()
+        )

    def validate_input_specs(self):
        for shape, _, _, shape_ranges, has_batch_dim in self.input_specs:
            if not self.network.has_implicit_batch_dimension:
                assert (
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py	2024-01-31 00:17:16.949022+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py	2024-01-31 00:19:11.206768+00:00
@@ -194,13 +194,15 @@
                    lowering_start_time = datetime.datetime.now()

                    self.lower_setting.input_specs = generate_input_specs(
                        submod_inputs,
                        self.lower_setting,
-                        additional_submodule_inputs[submod_name]
-                        if additional_submodule_inputs
-                        else None,
+                        (
+                            additional_submodule_inputs[submod_name]
+                            if additional_submodule_inputs
+                            else None
+                        ),
                    )
                    lowered_module = self._lower_func(
                        submod, submod_inputs, self.lower_setting, submod_name
                    )
                    setattr(split_result.split_module, submod_name, lowered_module)
@@ -234,13 +236,15 @@
                if not submod_name.startswith(split_result.non_acc_submodule_prefix):
                    _LOGGER.info(f"ACC submodule graph: {submod.graph}")
                    lowering_start_time = datetime.datetime.now()

                    self.lower_setting.additional_inputs = (
-                        additional_submodule_inputs[submod_name]
-                        if additional_submodule_inputs
-                        else None,
+                        (
+                            additional_submodule_inputs[submod_name]
+                            if additional_submodule_inputs
+                            else None
+                        ),
                    )

                    lowered_module = self._lower_func(
                        submod, submod_inputs, self.lower_setting, submod_name
                    )
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py	2024-01-31 00:17:16.949022+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py	2024-01-31 00:19:11.444141+00:00
@@ -193,13 +193,11 @@
                kwargs2 = {"equal_nan": True}
                if rtol:
                    kwargs2["rtol"] = rtol
                if atol:
                    kwargs2["atol"] = atol
-                kwargs2[
-                    "msg"
-                ] = (
+                kwargs2["msg"] = (
                    lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}"
                )
                # If tensors are on different devices, make sure to compare
                # their copies that are on the same device.
                if x.get_device() != y.get_device():
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass.py	2024-01-31 00:17:16.949022+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass.py	2024-01-31 00:19:11.470242+00:00
@@ -536,13 +536,13 @@
        reshape_batch_size: Optional[fx.Node] = get_reshape_batch_size_as_node(
            maybe_reshape
        )
        if not reshape_batch_size:
            continue
-        reshape_batch_size_inferred_source: Optional[
-            fx.Node
-        ] = get_reshape_batch_size_inferred_source(reshape_batch_size)
+        reshape_batch_size_inferred_source: Optional[fx.Node] = (
+            get_reshape_batch_size_inferred_source(reshape_batch_size)
+        )
        if not reshape_batch_size_inferred_source:
            continue

        reshape_input: fx.Node = maybe_reshape.kwargs["input"]
        if reshape_input == reshape_batch_size_inferred_source:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py	2024-01-31 00:17:16.953022+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py	2024-01-31 00:19:11.919751+00:00
@@ -21,13 +21,15 @@
        inputs = [torch.randn(1, 10)]
        self.run_test(
            Split(),
            inputs,
            expected_ops={
-                acc_ops.split
-                if isinstance(split_size_or_sections, int)
-                else acc_ops.slice_tensor
+                (
+                    acc_ops.split
+                    if isinstance(split_size_or_sections, int)
+                    else acc_ops.slice_tensor
+                )
            },
            test_explicit_batch_dim=False,
        )

    @parameterized.expand(
@@ -68,13 +70,15 @@
        ]
        self.run_test_with_dynamic_shape(
            Split(),
            input_specs,
            expected_ops={
-                acc_ops.split
-                if isinstance(split_size_or_sections, int)
-                else acc_ops.slice_tensor
+                (
+                    acc_ops.split
+                    if isinstance(split_size_or_sections, int)
+                    else acc_ops.slice_tensor
+                )
            },
        )

    # Testing with (-1, -1, -1) results into following error:
    # AssertionError: Can't chunk on dynamic shape dimension!
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/tools/common_fx2trt.py	2024-01-31 00:17:16.953022+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/tools/common_fx2trt.py	2024-01-31 00:19:12.746113+00:00
@@ -152,13 +152,13 @@
            mod.eval()
            if len(expected_ops):
                self.assert_has_op(mod, expected_ops)

            interpreter_result = interpreter.run(
-                lower_precision=LowerPrecision.FP16
-                if fp16_mode
-                else LowerPrecision.FP32
+                lower_precision=(
+                    LowerPrecision.FP16 if fp16_mode else LowerPrecision.FP32
+                )
            )
            trt_mod = TRTModule(
                interpreter_result.engine,
                interpreter_result.input_names,
                interpreter_result.output_names,
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/trt_module.py	2024-01-31 00:17:16.957022+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/trt_module.py	2024-01-31 00:19:13.194652+00:00
@@ -67,25 +67,29 @@
                self.engine.get_binding_dtype(idx), Frameworks.TORCH
            )
            for idx in self.output_binding_indices_in_order
        ]
        self.output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
            for idx in self.output_binding_indices_in_order
        ]
        self.hidden_output_dtypes: Sequence[torch.dtype] = [
            unified_dtype_converter(
                self.engine.get_binding_dtype(idx), Frameworks.TORCH
            )
            for idx in self.hidden_output_binding_indices_in_order
        ]
        self.hidden_output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
            for idx in self.hidden_output_binding_indices_in_order
        ]

    def _check_initialized(self):
        if not self.initialized:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/ts/_compile_spec.py	2024-01-31 00:17:16.957022+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/ts/_compile_spec.py	2024-01-31 00:19:13.664728+00:00
@@ -404,13 +404,13 @@
        "inputs": inputs if inputs is not None else [],
        # "input_signature": input_signature,
        "device": device,
        "disable_tf32": disable_tf32,  # Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas
        "sparse_weights": sparse_weights,  # Enable sparsity for convolution and fully connected layers.
-        "enabled_precisions": enabled_precisions
-        if enabled_precisions is not None
-        else set(),  # Enabling FP16 kernels
+        "enabled_precisions": (
+            enabled_precisions if enabled_precisions is not None else set()
+        ),  # Enabling FP16 kernels
        "refit": refit,  # enable refit
        "debug": debug,  # enable debuggable engine
        "capability": capability,  # Restrict kernel selection to safe gpu kernels or safe dla kernels
        "num_avg_timing_iters": num_avg_timing_iters,  # Number of averaging timing iterations used to select kernels
        "workspace_size": workspace_size,  # Maximum size of workspace given to TensorRT

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/examples/int8/training/vgg16/vgg16.py	2024-01-31 00:27:05.067360+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/int8/training/vgg16/vgg16.py	2024-01-31 00:28:53.276093+00:00
@@ -1,10 +1,11 @@
"""
# Reference
- [Very Deep Convolutional Networks for Large-Scale Image Recognition](
    https://arxiv.org/abs/1409.1556) (ICLR 2015)
"""
+
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import reduce

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Device.py	2024-01-31 00:27:05.071360+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Device.py	2024-01-31 00:28:53.379906+00:00
@@ -30,16 +30,18 @@
        gpu_id (int): Device ID for target GPU
        dla_core (int): Core ID for target DLA core
        allow_gpu_fallback (bool): Whether falling back to GPU if DLA cannot support an op should be allowed
    """

-    device_type: Optional[
-        trt.DeviceType
-    ] = None  #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
+    device_type: Optional[trt.DeviceType] = (
+        None  #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
+    )
    gpu_id: int = -1  #: Device ID for target GPU
    dla_core: int = -1  #: Core ID for target DLA core
-    allow_gpu_fallback: bool = False  #: Whether falling back to GPU if DLA cannot support an op should be allowed
+    allow_gpu_fallback: bool = (
+        False  #: Whether falling back to GPU if DLA cannot support an op should be allowed
+    )

    def __init__(self, *args: Any, **kwargs: Any):
        """__init__ Method for torch_tensorrt.Device

        Device accepts one of a few construction patterns
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py	2024-01-31 00:27:05.075360+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py	2024-01-31 00:28:53.571569+00:00
@@ -202,13 +202,13 @@
        "precision": precision,
        "debug": debug,
        "device": device,
        "workspace_size": workspace_size,
        "min_block_size": min_block_size,
-        "torch_executed_ops": torch_executed_ops
-        if torch_executed_ops is not None
-        else set(),
+        "torch_executed_ops": (
+            torch_executed_ops if torch_executed_ops is not None else set()
+        ),
        "pass_through_build_failures": pass_through_build_failures,
        "max_aux_streams": max_aux_streams,
        "version_compatible": version_compatible,
        "optimization_level": optimization_level,
        "use_python_runtime": use_python_runtime,
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Input.py	2024-01-31 00:27:05.071360+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Input.py	2024-01-31 00:28:53.588519+00:00
@@ -26,16 +26,16 @@

    class _ShapeMode(Enum):
        STATIC = 0
        DYNAMIC = 1

-    shape_mode: Optional[
-        _ShapeMode
-    ] = None  #: Is input statically or dynamically shaped
-    shape: Optional[
-        Tuple[int, ...] | Dict[str, Tuple[int, ...]]
-    ] = None  #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
+    shape_mode: Optional[_ShapeMode] = (
+        None  #: Is input statically or dynamically shaped
+    )
+    shape: Optional[Tuple[int, ...] | Dict[str, Tuple[int, ...]]] = (
+        None  #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
+    )
    dtype: _enums.dtype = (
        _enums.dtype.unknown
    )  #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
    _explicit_set_dtype: bool = False
    format: _enums.TensorFormat = (
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py	2024-01-31 00:27:05.075360+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py	2024-01-31 00:28:53.809780+00:00
@@ -26,13 +26,13 @@

from packaging import version

_LOGGER: logging.Logger = logging.getLogger(__name__)

-TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[
-    Callable[[torch.fx.GraphModule], None]
-] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = (
+    Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+)


class UnsupportedOperatorException(RuntimeError):
    pass

@@ -90,13 +90,13 @@
        self.input_specs_iter = 0
        self._cur_node_name: Optional[str] = None
        self._cur_node: Optional[torch.fx.Node] = None
        self._input_names: List[str] = []
        self._output_names: List[str] = []
-        self._itensor_to_tensor_meta: Dict[
-            trt.tensorrt.ITensor, TensorMetadata
-        ] = dict()
+        self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = (
+            dict()
+        )
        self.compilation_settings = compilation_settings

        # Data types for TRT Module output Tensors
        self.output_dtypes = output_dtypes

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py	2024-01-31 00:27:05.075360+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py	2024-01-31 00:28:53.897553+00:00
@@ -322,17 +322,15 @@
    else:
        raise AssertionError(f"Cannot convert {input_val} to TRT constant")


@overload
-def get_positive_dim(dim: int, dim_size: int) -> int:
-    ...
+def get_positive_dim(dim: int, dim_size: int) -> int: ...


@overload
-def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]:
-    ...
+def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]: ...


def get_positive_dim(
    dim: Union[int, Sequence[int]], dim_size: int
) -> Union[int, Tuple[int, ...]]:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py	2024-01-31 00:27:05.075360+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py	2024-01-31 00:28:54.236345+00:00
@@ -5,13 +5,13 @@
from torch._decomp import get_decompositions as get_torch_decompositions
from torch._ops import OpOverload, OpOverloadPacket

aten = torch.ops.aten

-_core_aten_decompositions: Dict[
-    OpOverload, Callable[[Any], Any]
-] = core_aten_decompositions()
+_core_aten_decompositions: Dict[OpOverload, Callable[[Any], Any]] = (
+    core_aten_decompositions()
+)
torch_enabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
    aten._adaptive_avg_pool2d_backward,
    aten.addcdiv,
    aten.addcdiv_,
    aten.addcmul,
@@ -178,13 +178,13 @@
torch_disabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
    aten._softmax.default,
}


-ENABLED_TORCH_DECOMPOSITIONS: Dict[
-    OpOverload, Callable[[Any], Any]
-] = get_torch_decompositions(torch_enabled_decompositions)
+ENABLED_TORCH_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = (
+    get_torch_decompositions(torch_enabled_decompositions)
+)
TORCH_TRT_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = {}


def check_decomp_set_invariants() -> None:
    """Validates no overlap between enabled and disabled decomposition sets"""
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py	2024-01-31 00:27:05.079360+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py	2024-01-31 00:28:54.249455+00:00
@@ -20,16 +20,14 @@
        logger.debug(f"Graph after lowering linear:\n{gm.graph}")

    return gm


-def linear_replacement() -> (
-    Tuple[
-        torch.fx.GraphModule,
-        Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
-    ]
-):
+def linear_replacement() -> Tuple[
+    torch.fx.GraphModule,
+    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
+]:
    """Constructs the original and replacement functions for linear"""

    # Original graph
    def orig(
        input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_efficient_attention.py	2024-01-31 00:27:05.079360+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_efficient_attention.py	2024-01-31 00:28:54.255031+00:00
@@ -25,16 +25,14 @@
        )

    return gm


-def efficient_attention_replacement() -> (
-    Tuple[
-        torch.fx.GraphModule,
-        Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
-    ]
-):
+def efficient_attention_replacement() -> Tuple[
+    torch.fx.GraphModule,
+    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
+]:
    """Constructs the original and replacement functions for efficient attention"""

    # Original graph
    def orig(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default(
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py	2024-01-31 00:27:05.079360+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py	2024-01-31 00:28:54.284745+00:00
@@ -20,16 +20,14 @@
        logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")

    return gm


-def view_replacement() -> (
-    Tuple[
-        torch.fx.GraphModule,
-        Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
-    ]
-):
+def view_replacement() -> Tuple[
+    torch.fx.GraphModule,
+    Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
+]:
    """Constructs the original and replacement functions for view"""

    # Original graph
    def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
        return torch.ops.aten.view.default(input, shape)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py	2024-01-31 00:27:05.079360+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py	2024-01-31 00:28:54.508103+00:00
@@ -99,25 +99,29 @@
                self.engine.get_binding_dtype(idx), Frameworks.TORCH
            )
            for idx in self.output_binding_indices_in_order
        ]
        self.output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
            for idx in self.output_binding_indices_in_order
        ]
        self.hidden_output_dtypes = [
            unified_dtype_converter(
                self.engine.get_binding_dtype(idx), Frameworks.TORCH
            )
            for idx in self.hidden_output_binding_indices_in_order
        ]
        self.hidden_output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
            for idx in self.hidden_output_binding_indices_in_order
        ]

    def _check_initialized(self) -> None:
        if not self.initialized:
@@ -165,13 +169,15 @@
        self.__dict__.update(state)
        if self.engine:
            self.context = self.engine.create_execution_context()

    def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
-        with torch.autograd.profiler.record_function(
-            "PythonTorchTensorRTModule:Forward"
-        ) if self.profiling_enabled else nullcontext():
+        with (
+            torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward")
+            if self.profiling_enabled
+            else nullcontext()
+        ):
            self._check_initialized()

            # If in safe mode, check at each iteration for for whether a switch is required
            if (
                torch_tensorrt.runtime.multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE
@@ -198,13 +204,17 @@
                    torch.cuda.set_device(device_id)

                    inputs = tuple([tensor.to(device) for tensor in inputs])
                    logger.warning(f"Moved all input Tensors to cuda:{device_id}")

-            with torch.autograd.profiler.record_function(
-                "PythonTorchTensorRTModule:ProcessInputs"
-            ) if self.profiling_enabled else nullcontext():
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:ProcessInputs"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                assert len(inputs) == len(
                    self.input_names
                ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."

                contiguous_inputs: List[torch.Tensor] = [i.contiguous() for i in inputs]
@@ -237,13 +247,17 @@

                    self.context.set_binding_shape(
                        idx, tuple(contiguous_inputs[i].shape)
                    )

-            with torch.autograd.profiler.record_function(
-                "PythonTorchTensorRTModule:ProcessOutputs"
-            ) if self.profiling_enabled else nullcontext():
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:ProcessOutputs"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                # create output tensors
                outputs: List[torch.Tensor] = []

                for i, idx in enumerate(self.output_binding_indices_in_order):
                    shape = tuple(self.context.get_binding_shape(idx))
@@ -264,13 +278,17 @@
                        dtype=self.hidden_output_dtypes[i],
                        device=torch.cuda.current_device(),
                    )
                    bindings[idx] = output.data_ptr()

-            with torch.autograd.profiler.record_function(
-                "PythonTorchTensorRTModule:TensorRTRuntime"
-            ) if self.profiling_enabled else nullcontext():
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:TensorRTRuntime"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                self.context.execute_async_v2(
                    bindings, torch.cuda.current_stream().cuda_stream
                )

            if len(outputs) == 1:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/converters/aten_ops_converters.py	2024-01-31 00:27:05.079360+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/converters/aten_ops_converters.py	2024-01-31 00:28:54.879960+00:00
@@ -315,25 +315,21 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    kwargs_new = {
        "input": args[0],
        "kernel_size": args[1],
-        "stride": args[2]
-        if len(args) > 2
-        else (None, None)
-        if len(args[1]) == 2
-        else (None, None, None),
-        "padding": args[3]
-        if len(args) > 3
-        else (0, 0)
-        if len(args[1]) == 2
-        else (0, 0, 0),
-        "dilation": args[4]
-        if len(args) > 4
-        else (1, 1)
-        if len(args[1]) == 2
-        else (1, 1, 1),
+        "stride": (
+            args[2]
+            if len(args) > 2
+            else (None, None) if len(args[1]) == 2 else (None, None, None)
+        ),
+        "padding": (
+            args[3] if len(args) > 3 else (0, 0) if len(args[1]) == 2 else (0, 0, 0)
+        ),
+        "dilation": (
+            args[4] if len(args) > 4 else (1, 1) if len(args[1]) == 2 else (1, 1, 1)
+        ),
        "ceil_mode": args[5] if len(args) > 5 else False,
    }
    return acc_ops_converters.acc_ops_max_poolnd(
        network, target, None, kwargs_new, name
    )
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/fx2trt.py	2024-01-31 00:27:05.079360+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/fx2trt.py	2024-01-31 00:28:54.942633+00:00
@@ -19,13 +19,13 @@
from .observer import Observer
from .utils import get_dynamic_dims, LowerPrecision, unified_dtype_converter, Frameworks

_LOGGER: logging.Logger = logging.getLogger(__name__)

-TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[
-    Callable[[torch.fx.GraphModule], None]
-] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = (
+    Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+)


class TRTInterpreterResult(NamedTuple):
    engine: Any
    input_names: Sequence[str]
@@ -73,13 +73,13 @@
        self.input_specs_iter = 0
        self.validate_input_specs()
        self._cur_node_name: Optional[str] = None
        self._input_names: List[str] = []
        self._output_names: List[str] = []
-        self._itensor_to_tensor_meta: Dict[
-            trt.tensorrt.ITensor, TensorMetadata
-        ] = dict()
+        self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = (
+            dict()
+        )

    def validate_input_specs(self):
        for shape, _, _, shape_ranges, has_batch_dim in self.input_specs:
            if not self.network.has_implicit_batch_dimension:
                assert (
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/lower.py	2024-01-31 00:27:05.079360+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/lower.py	2024-01-31 00:28:54.975228+00:00
@@ -124,25 +124,29 @@
        interpreter = TRTInterpreter(
            mod,
            input_specs=self.lower_setting.input_specs,
            explicit_batch_dimension=self.lower_setting.explicit_batch_dimension,
            explicit_precision=self.lower_setting.explicit_precision,
-            logger_level=trt.Logger.VERBOSE
-            if self.lower_setting.verbose_log
-            else trt.Logger.WARNING,
+            logger_level=(
+                trt.Logger.VERBOSE
+                if self.lower_setting.verbose_log
+                else trt.Logger.WARNING
+            ),
        )

        interp_result: TRTInterpreterResult = interpreter.run(
            max_batch_size=self.lower_setting.max_batch_size,
            max_workspace_size=self.lower_setting.max_workspace_size,
            lower_precision=self.lower_setting.lower_precision,
            strict_type_constraints=self.lower_setting.strict_type_constraints,
            algorithm_selector=algo_selector,
            timing_cache=cache_data,
-            profiling_verbosity=trt.ProfilingVerbosity.DETAILED
-            if self.lower_setting.verbose_profile
-            else trt.ProfilingVerbosity.LAYER_NAMES_ONLY,
+            profiling_verbosity=(
+                trt.ProfilingVerbosity.DETAILED
+                if self.lower_setting.verbose_profile
+                else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
+            ),
            tactic_sources=self.lower_setting.tactic_sources,
        )

        # Update timing cache file if needed
        timing_cache = interp_result.serialized_cache
@@ -295,14 +299,12 @@
                module.half()
                # A custom conversion function can be passed to the lowerer to
                # handle inputs with custom types. By default, just handle
                # tensors and NoneType.
                if fp16_conversion_fn is None:
-                    conversion_fn = (
-                        lambda x: x.half()
-                        if x is not None and x.dtype == torch.float32
-                        else x
+                    conversion_fn = lambda x: (
+                        x.half() if x is not None and x.dtype == torch.float32 else x
                    )
                else:
                    conversion_fn = fp16_conversion_fn

                inputs = tuple(conversion_fn(x) for x in inputs)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py	2024-01-31 00:27:05.079360+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py	2024-01-31 00:28:55.162885+00:00
@@ -194,13 +194,15 @@
                    lowering_start_time = datetime.datetime.now()

                    self.lower_setting.input_specs = generate_input_specs(
                        submod_inputs,
                        self.lower_setting,
-                        additional_submodule_inputs[submod_name]
-                        if additional_submodule_inputs
-                        else None,
+                        (
+                            additional_submodule_inputs[submod_name]
+                            if additional_submodule_inputs
+                            else None
+                        ),
                    )
                    lowered_module = self._lower_func(
                        submod, submod_inputs, self.lower_setting, submod_name
                    )
                    setattr(split_result.split_module, submod_name, lowered_module)
@@ -234,13 +236,15 @@
                if not submod_name.startswith(split_result.non_acc_submodule_prefix):
                    _LOGGER.info(f"ACC submodule graph: {submod.graph}")
                    lowering_start_time = datetime.datetime.now()

                    self.lower_setting.additional_inputs = (
-                        additional_submodule_inputs[submod_name]
-                        if additional_submodule_inputs
-                        else None,
+                        (
+                            additional_submodule_inputs[submod_name]
+                            if additional_submodule_inputs
+                            else None
+                        ),
                    )

                    lowered_module = self._lower_func(
                        submod, submod_inputs, self.lower_setting, submod_name
                    )
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py	2024-01-31 00:27:05.079360+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py	2024-01-31 00:28:55.450681+00:00
@@ -193,13 +193,11 @@
                kwargs2 = {"equal_nan": True}
                if rtol:
                    kwargs2["rtol"] = rtol
                if atol:
                    kwargs2["atol"] = atol
-                kwargs2[
-                    "msg"
-                ] = (
+                kwargs2["msg"] = (
                    lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}"
                )
                # If tensors are on different devices, make sure to compare
                # their copies that are on the same device.
                if x.get_device() != y.get_device():
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass.py	2024-01-31 00:27:05.079360+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass.py	2024-01-31 00:28:55.503707+00:00
@@ -536,13 +536,13 @@
        reshape_batch_size: Optional[fx.Node] = get_reshape_batch_size_as_node(
            maybe_reshape
        )
        if not reshape_batch_size:
            continue
-        reshape_batch_size_inferred_source: Optional[
-            fx.Node
-        ] = get_reshape_batch_size_inferred_source(reshape_batch_size)
+        reshape_batch_size_inferred_source: Optional[fx.Node] = (
+            get_reshape_batch_size_inferred_source(reshape_batch_size)
+        )
        if not reshape_batch_size_inferred_source:
            continue

        reshape_input: fx.Node = maybe_reshape.kwargs["input"]
        if reshape_input == reshape_batch_size_inferred_source:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py	2024-01-31 00:27:05.083360+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py	2024-01-31 00:28:56.061610+00:00
@@ -21,13 +21,15 @@
        inputs = [torch.randn(1, 10)]
        self.run_test(
            Split(),
            inputs,
            expected_ops={
-                acc_ops.split
-                if isinstance(split_size_or_sections, int)
-                else acc_ops.slice_tensor
+                (
+                    acc_ops.split
+                    if isinstance(split_size_or_sections, int)
+                    else acc_ops.slice_tensor
+                )
            },
            test_explicit_batch_dim=False,
        )

    @parameterized.expand(
@@ -68,13 +70,15 @@
        ]
        self.run_test_with_dynamic_shape(
            Split(),
            input_specs,
            expected_ops={
-                acc_ops.split
-                if isinstance(split_size_or_sections, int)
-                else acc_ops.slice_tensor
+                (
+                    acc_ops.split
+                    if isinstance(split_size_or_sections, int)
+                    else acc_ops.slice_tensor
+                )
            },
        )

    # Testing with (-1, -1, -1) results into following error:
    # AssertionError: Can't chunk on dynamic shape dimension!
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/tools/common_fx2trt.py	2024-01-31 00:27:05.087360+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/tools/common_fx2trt.py	2024-01-31 00:28:56.742627+00:00
@@ -152,13 +152,13 @@
            mod.eval()
            if len(expected_ops):
                self.assert_has_op(mod, expected_ops)

            interpreter_result = interpreter.run(
-                lower_precision=LowerPrecision.FP16
-                if fp16_mode
-                else LowerPrecision.FP32
+                lower_precision=(
+                    LowerPrecision.FP16 if fp16_mode else LowerPrecision.FP32
+                )
            )
            trt_mod = TRTModule(
                interpreter_result.engine,
                interpreter_result.input_names,
                interpreter_result.output_names,
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/trt_module.py	2024-01-31 00:27:05.087360+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/trt_module.py	2024-01-31 00:28:57.223711+00:00
@@ -67,25 +67,29 @@
                self.engine.get_binding_dtype(idx), Frameworks.TORCH
            )
            for idx in self.output_binding_indices_in_order
        ]
        self.output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
            for idx in self.output_binding_indices_in_order
        ]
        self.hidden_output_dtypes: Sequence[torch.dtype] = [
            unified_dtype_converter(
                self.engine.get_binding_dtype(idx), Frameworks.TORCH
            )
            for idx in self.hidden_output_binding_indices_in_order
        ]
        self.hidden_output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
            for idx in self.hidden_output_binding_indices_in_order
        ]

    def _check_initialized(self):
        if not self.initialized:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/ts/_compile_spec.py	2024-01-31 00:27:05.087360+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/ts/_compile_spec.py	2024-01-31 00:28:57.506540+00:00
@@ -404,13 +404,13 @@
        "inputs": inputs if inputs is not None else [],
        # "input_signature": input_signature,
        "device": device,
        "disable_tf32": disable_tf32,  # Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas
        "sparse_weights": sparse_weights,  # Enable sparsity for convolution and fully connected layers.
-        "enabled_precisions": enabled_precisions
-        if enabled_precisions is not None
-        else set(),  # Enabling FP16 kernels
+        "enabled_precisions": (
+            enabled_precisions if enabled_precisions is not None else set()
+        ),  # Enabling FP16 kernels
        "refit": refit,  # enable refit
        "debug": debug,  # enable debuggable engine
        "capability": capability,  # Restrict kernel selection to safe gpu kernels or safe dla kernels
        "num_avg_timing_iters": num_avg_timing_iters,  # Number of averaging timing iterations used to select kernels
        "workspace_size": workspace_size,  # Maximum size of workspace given to TensorRT

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/examples/int8/training/vgg16/vgg16.py	2024-01-31 00:45:45.923759+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/int8/training/vgg16/vgg16.py	2024-01-31 00:48:49.607123+00:00
@@ -1,10 +1,11 @@
"""
# Reference
- [Very Deep Convolutional Networks for Large-Scale Image Recognition](
    https://arxiv.org/abs/1409.1556) (ICLR 2015)
"""
+
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import reduce

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Device.py	2024-01-31 00:45:45.927759+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Device.py	2024-01-31 00:48:49.705885+00:00
@@ -30,16 +30,18 @@
        gpu_id (int): Device ID for target GPU
        dla_core (int): Core ID for target DLA core
        allow_gpu_fallback (bool): Whether falling back to GPU if DLA cannot support an op should be allowed
    """

-    device_type: Optional[
-        trt.DeviceType
-    ] = None  #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
+    device_type: Optional[trt.DeviceType] = (
+        None  #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
+    )
    gpu_id: int = -1  #: Device ID for target GPU
    dla_core: int = -1  #: Core ID for target DLA core
-    allow_gpu_fallback: bool = False  #: Whether falling back to GPU if DLA cannot support an op should be allowed
+    allow_gpu_fallback: bool = (
+        False  #: Whether falling back to GPU if DLA cannot support an op should be allowed
+    )

    def __init__(self, *args: Any, **kwargs: Any):
        """__init__ Method for torch_tensorrt.Device

        Device accepts one of a few construction patterns
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py	2024-01-31 00:45:45.927759+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py	2024-01-31 00:48:49.878826+00:00
@@ -202,13 +202,13 @@
        "precision": precision,
        "debug": debug,
        "device": device,
        "workspace_size": workspace_size,
        "min_block_size": min_block_size,
-        "torch_executed_ops": torch_executed_ops
-        if torch_executed_ops is not None
-        else set(),
+        "torch_executed_ops": (
+            torch_executed_ops if torch_executed_ops is not None else set()
+        ),
        "pass_through_build_failures": pass_through_build_failures,
        "max_aux_streams": max_aux_streams,
        "version_compatible": version_compatible,
        "optimization_level": optimization_level,
        "use_python_runtime": use_python_runtime,
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Input.py	2024-01-31 00:45:45.927759+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Input.py	2024-01-31 00:48:49.908491+00:00
@@ -26,16 +26,16 @@

    class _ShapeMode(Enum):
        STATIC = 0
        DYNAMIC = 1

-    shape_mode: Optional[
-        _ShapeMode
-    ] = None  #: Is input statically or dynamically shaped
-    shape: Optional[
-        Tuple[int, ...] | Dict[str, Tuple[int, ...]]
-    ] = None  #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
+    shape_mode: Optional[_ShapeMode] = (
+        None  #: Is input statically or dynamically shaped
+    )
+    shape: Optional[Tuple[int, ...] | Dict[str, Tuple[int, ...]]] = (
+        None  #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
+    )
    dtype: _enums.dtype = (
        _enums.dtype.unknown
    )  #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
    _explicit_set_dtype: bool = False
    format: _enums.TensorFormat = (
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py	2024-01-31 00:45:45.931759+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py	2024-01-31 00:48:50.137690+00:00
@@ -26,13 +26,13 @@

from packaging import version

_LOGGER: logging.Logger = logging.getLogger(__name__)

-TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[
-    Callable[[torch.fx.GraphModule], None]
-] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = (
+    Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+)


class UnsupportedOperatorException(RuntimeError):
    pass

@@ -90,13 +90,13 @@
        self.input_specs_iter = 0
        self._cur_node_name: Optional[str] = None
        self._cur_node: Optional[torch.fx.Node] = None
        self._input_names: List[str] = []
        self._output_names: List[str] = []
-        self._itensor_to_tensor_meta: Dict[
-            trt.tensorrt.ITensor, TensorMetadata
-        ] = dict()
+        self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = (
+            dict()
+        )
        self.compilation_settings = compilation_settings

        # Data types for TRT Module output Tensors
        self.output_dtypes = output_dtypes

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py	2024-01-31 00:45:45.931759+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py	2024-01-31 00:48:50.225585+00:00
@@ -322,17 +322,15 @@
    else:
        raise AssertionError(f"Cannot convert {input_val} to TRT constant")


@overload
-def get_positive_dim(dim: int, dim_size: int) -> int:
-    ...
+def get_positive_dim(dim: int, dim_size: int) -> int: ...


@overload
-def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]:
-    ...
+def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]: ...


def get_positive_dim(
    dim: Union[int, Sequence[int]], dim_size: int
) -> Union[int, Tuple[int, ...]]:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_efficient_attention.py	2024-01-31 00:45:45.931759+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_efficient_attention.py	2024-01-31 00:48:50.569988+00:00
@@ -25,16 +25,14 @@
        )

    return gm


-def efficient_attention_replacement() -> (
-    Tuple[
-        torch.fx.GraphModule,
-        Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
-    ]
-):
+def efficient_attention_replacement() -> Tuple[
+    torch.fx.GraphModule,
+    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
+]:
    """Constructs the original and replacement functions for efficient attention"""

    # Original graph
    def orig(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default(
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py	2024-01-31 00:45:45.931759+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py	2024-01-31 00:48:50.570986+00:00
@@ -5,13 +5,13 @@
from torch._decomp import get_decompositions as get_torch_decompositions
from torch._ops import OpOverload, OpOverloadPacket

aten = torch.ops.aten

-_core_aten_decompositions: Dict[
-    OpOverload, Callable[[Any], Any]
-] = core_aten_decompositions()
+_core_aten_decompositions: Dict[OpOverload, Callable[[Any], Any]] = (
+    core_aten_decompositions()
+)
torch_enabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
    aten._adaptive_avg_pool2d_backward,
    aten.addcdiv,
    aten.addcdiv_,
    aten.addcmul,
@@ -178,13 +178,13 @@
torch_disabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
    aten._softmax.default,
}


-ENABLED_TORCH_DECOMPOSITIONS: Dict[
-    OpOverload, Callable[[Any], Any]
-] = get_torch_decompositions(torch_enabled_decompositions)
+ENABLED_TORCH_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = (
+    get_torch_decompositions(torch_enabled_decompositions)
+)
TORCH_TRT_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = {}


def check_decomp_set_invariants() -> None:
    """Validates no overlap between enabled and disabled decomposition sets"""
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py	2024-01-31 00:45:45.931759+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py	2024-01-31 00:48:50.574439+00:00
@@ -20,16 +20,14 @@
        logger.debug(f"Graph after lowering linear:\n{gm.graph}")

    return gm


-def linear_replacement() -> (
-    Tuple[
-        torch.fx.GraphModule,
-        Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
-    ]
-):
+def linear_replacement() -> Tuple[
+    torch.fx.GraphModule,
+    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
+]:
    """Constructs the original and replacement functions for linear"""

    # Original graph
    def orig(
        input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py	2024-01-31 00:45:45.931759+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py	2024-01-31 00:48:50.605714+00:00
@@ -20,16 +20,14 @@
        logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")

    return gm


-def view_replacement() -> (
-    Tuple[
-        torch.fx.GraphModule,
-        Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
-    ]
-):
+def view_replacement() -> Tuple[
+    torch.fx.GraphModule,
+    Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
+]:
    """Constructs the original and replacement functions for view"""

    # Original graph
    def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
        return torch.ops.aten.view.default(input, shape)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py	2024-01-31 00:45:45.931759+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py	2024-01-31 00:48:50.828338+00:00
@@ -99,25 +99,29 @@
                self.engine.get_binding_dtype(idx), Frameworks.TORCH
            )
            for idx in self.output_binding_indices_in_order
        ]
        self.output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
            for idx in self.output_binding_indices_in_order
        ]
        self.hidden_output_dtypes = [
            unified_dtype_converter(
                self.engine.get_binding_dtype(idx), Frameworks.TORCH
            )
            for idx in self.hidden_output_binding_indices_in_order
        ]
        self.hidden_output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
            for idx in self.hidden_output_binding_indices_in_order
        ]

    def _check_initialized(self) -> None:
        if not self.initialized:
@@ -165,13 +169,15 @@
        self.__dict__.update(state)
        if self.engine:
            self.context = self.engine.create_execution_context()

    def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
-        with torch.autograd.profiler.record_function(
-            "PythonTorchTensorRTModule:Forward"
-        ) if self.profiling_enabled else nullcontext():
+        with (
+            torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward")
+            if self.profiling_enabled
+            else nullcontext()
+        ):
            self._check_initialized()

            # If in safe mode, check at each iteration for for whether a switch is required
            if (
                torch_tensorrt.runtime.multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE
@@ -198,13 +204,17 @@
                    torch.cuda.set_device(device_id)

                    inputs = tuple([tensor.to(device) for tensor in inputs])
                    logger.warning(f"Moved all input Tensors to cuda:{device_id}")

-            with torch.autograd.profiler.record_function(
-                "PythonTorchTensorRTModule:ProcessInputs"
-            ) if self.profiling_enabled else nullcontext():
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:ProcessInputs"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                assert len(inputs) == len(
                    self.input_names
                ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."

                contiguous_inputs: List[torch.Tensor] = [i.contiguous() for i in inputs]
@@ -237,13 +247,17 @@

                    self.context.set_binding_shape(
                        idx, tuple(contiguous_inputs[i].shape)
                    )

-            with torch.autograd.profiler.record_function(
-                "PythonTorchTensorRTModule:ProcessOutputs"
-            ) if self.profiling_enabled else nullcontext():
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:ProcessOutputs"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                # create output tensors
                outputs: List[torch.Tensor] = []

                for i, idx in enumerate(self.output_binding_indices_in_order):
                    shape = tuple(self.context.get_binding_shape(idx))
@@ -264,13 +278,17 @@
                        dtype=self.hidden_output_dtypes[i],
                        device=torch.cuda.current_device(),
                    )
                    bindings[idx] = output.data_ptr()

-            with torch.autograd.profiler.record_function(
-                "PythonTorchTensorRTModule:TensorRTRuntime"
-            ) if self.profiling_enabled else nullcontext():
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:TensorRTRuntime"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                self.context.execute_async_v2(
                    bindings, torch.cuda.current_stream().cuda_stream
                )

            if len(outputs) == 1:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/converters/aten_ops_converters.py	2024-01-31 00:45:45.935759+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/converters/aten_ops_converters.py	2024-01-31 00:48:51.223278+00:00
@@ -315,25 +315,21 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    kwargs_new = {
        "input": args[0],
        "kernel_size": args[1],
-        "stride": args[2]
-        if len(args) > 2
-        else (None, None)
-        if len(args[1]) == 2
-        else (None, None, None),
-        "padding": args[3]
-        if len(args) > 3
-        else (0, 0)
-        if len(args[1]) == 2
-        else (0, 0, 0),
-        "dilation": args[4]
-        if len(args) > 4
-        else (1, 1)
-        if len(args[1]) == 2
-        else (1, 1, 1),
+        "stride": (
+            args[2]
+            if len(args) > 2
+            else (None, None) if len(args[1]) == 2 else (None, None, None)
+        ),
+        "padding": (
+            args[3] if len(args) > 3 else (0, 0) if len(args[1]) == 2 else (0, 0, 0)
+        ),
+        "dilation": (
+            args[4] if len(args) > 4 else (1, 1) if len(args[1]) == 2 else (1, 1, 1)
+        ),
        "ceil_mode": args[5] if len(args) > 5 else False,
    }
    return acc_ops_converters.acc_ops_max_poolnd(
        network, target, None, kwargs_new, name
    )
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/lower.py	2024-01-31 00:45:45.935759+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/lower.py	2024-01-31 00:48:51.278595+00:00
@@ -124,25 +124,29 @@
        interpreter = TRTInterpreter(
            mod,
            input_specs=self.lower_setting.input_specs,
            explicit_batch_dimension=self.lower_setting.explicit_batch_dimension,
            explicit_precision=self.lower_setting.explicit_precision,
-            logger_level=trt.Logger.VERBOSE
-            if self.lower_setting.verbose_log
-            else trt.Logger.WARNING,
+            logger_level=(
+                trt.Logger.VERBOSE
+                if self.lower_setting.verbose_log
+                else trt.Logger.WARNING
+            ),
        )

        interp_result: TRTInterpreterResult = interpreter.run(
            max_batch_size=self.lower_setting.max_batch_size,
            max_workspace_size=self.lower_setting.max_workspace_size,
            lower_precision=self.lower_setting.lower_precision,
            strict_type_constraints=self.lower_setting.strict_type_constraints,
            algorithm_selector=algo_selector,
            timing_cache=cache_data,
-            profiling_verbosity=trt.ProfilingVerbosity.DETAILED
-            if self.lower_setting.verbose_profile
-            else trt.ProfilingVerbosity.LAYER_NAMES_ONLY,
+            profiling_verbosity=(
+                trt.ProfilingVerbosity.DETAILED
+                if self.lower_setting.verbose_profile
+                else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
+            ),
            tactic_sources=self.lower_setting.tactic_sources,
        )

        # Update timing cache file if needed
        timing_cache = interp_result.serialized_cache
@@ -295,14 +299,12 @@
                module.half()
                # A custom conversion function can be passed to the lowerer to
                # handle inputs with custom types. By default, just handle
                # tensors and NoneType.
                if fp16_conversion_fn is None:
-                    conversion_fn = (
-                        lambda x: x.half()
-                        if x is not None and x.dtype == torch.float32
-                        else x
+                    conversion_fn = lambda x: (
+                        x.half() if x is not None and x.dtype == torch.float32 else x
                    )
                else:
                    conversion_fn = fp16_conversion_fn

                inputs = tuple(conversion_fn(x) for x in inputs)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/fx2trt.py	2024-01-31 00:45:45.935759+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/fx2trt.py	2024-01-31 00:48:51.347285+00:00
@@ -19,13 +19,13 @@
from .observer import Observer
from .utils import get_dynamic_dims, LowerPrecision, unified_dtype_converter, Frameworks

_LOGGER: logging.Logger = logging.getLogger(__name__)

-TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[
-    Callable[[torch.fx.GraphModule], None]
-] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = (
+    Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+)


class TRTInterpreterResult(NamedTuple):
    engine: Any
    input_names: Sequence[str]
@@ -73,13 +73,13 @@
        self.input_specs_iter = 0
        self.validate_input_specs()
        self._cur_node_name: Optional[str] = None
        self._input_names: List[str] = []
        self._output_names: List[str] = []
-        self._itensor_to_tensor_meta: Dict[
-            trt.tensorrt.ITensor, TensorMetadata
-        ] = dict()
+        self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = (
+            dict()
+        )

    def validate_input_specs(self):
        for shape, _, _, shape_ranges, has_batch_dim in self.input_specs:
            if not self.network.has_implicit_batch_dimension:
                assert (
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py	2024-01-31 00:45:45.935759+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py	2024-01-31 00:48:51.661386+00:00
@@ -194,13 +194,15 @@
                    lowering_start_time = datetime.datetime.now()

                    self.lower_setting.input_specs = generate_input_specs(
                        submod_inputs,
                        self.lower_setting,
-                        additional_submodule_inputs[submod_name]
-                        if additional_submodule_inputs
-                        else None,
+                        (
+                            additional_submodule_inputs[submod_name]
+                            if additional_submodule_inputs
+                            else None
+                        ),
                    )
                    lowered_module = self._lower_func(
                        submod, submod_inputs, self.lower_setting, submod_name
                    )
                    setattr(split_result.split_module, submod_name, lowered_module)
@@ -234,13 +236,15 @@
                if not submod_name.startswith(split_result.non_acc_submodule_prefix):
                    _LOGGER.info(f"ACC submodule graph: {submod.graph}")
                    lowering_start_time = datetime.datetime.now()

                    self.lower_setting.additional_inputs = (
-                        additional_submodule_inputs[submod_name]
-                        if additional_submodule_inputs
-                        else None,
+                        (
+                            additional_submodule_inputs[submod_name]
+                            if additional_submodule_inputs
+                            else None
+                        ),
                    )

                    lowered_module = self._lower_func(
                        submod, submod_inputs, self.lower_setting, submod_name
                    )
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py	2024-01-31 00:45:45.935759+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py	2024-01-31 00:48:51.780751+00:00
@@ -193,13 +193,11 @@
                kwargs2 = {"equal_nan": True}
                if rtol:
                    kwargs2["rtol"] = rtol
                if atol:
                    kwargs2["atol"] = atol
-                kwargs2[
-                    "msg"
-                ] = (
+                kwargs2["msg"] = (
                    lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}"
                )
                # If tensors are on different devices, make sure to compare
                # their copies that are on the same device.
                if x.get_device() != y.get_device():
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass.py	2024-01-31 00:45:45.935759+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass.py	2024-01-31 00:48:51.873287+00:00
@@ -536,13 +536,13 @@
        reshape_batch_size: Optional[fx.Node] = get_reshape_batch_size_as_node(
            maybe_reshape
        )
        if not reshape_batch_size:
            continue
-        reshape_batch_size_inferred_source: Optional[
-            fx.Node
-        ] = get_reshape_batch_size_inferred_source(reshape_batch_size)
+        reshape_batch_size_inferred_source: Optional[fx.Node] = (
+            get_reshape_batch_size_inferred_source(reshape_batch_size)
+        )
        if not reshape_batch_size_inferred_source:
            continue

        reshape_input: fx.Node = maybe_reshape.kwargs["input"]
        if reshape_input == reshape_batch_size_inferred_source:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py	2024-01-31 00:45:45.939759+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py	2024-01-31 00:48:52.316229+00:00
@@ -21,13 +21,15 @@
        inputs = [torch.randn(1, 10)]
        self.run_test(
            Split(),
            inputs,
            expected_ops={
-                acc_ops.split
-                if isinstance(split_size_or_sections, int)
-                else acc_ops.slice_tensor
+                (
+                    acc_ops.split
+                    if isinstance(split_size_or_sections, int)
+                    else acc_ops.slice_tensor
+                )
            },
            test_explicit_batch_dim=False,
        )

    @parameterized.expand(
@@ -68,13 +70,15 @@
        ]
        self.run_test_with_dynamic_shape(
            Split(),
            input_specs,
            expected_ops={
-                acc_ops.split
-                if isinstance(split_size_or_sections, int)
-                else acc_ops.slice_tensor
+                (
+                    acc_ops.split
+                    if isinstance(split_size_or_sections, int)
+                    else acc_ops.slice_tensor
+                )
            },
        )

    # Testing with (-1, -1, -1) results into following error:
    # AssertionError: Can't chunk on dynamic shape dimension!
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/tools/common_fx2trt.py	2024-01-31 00:45:45.943759+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/tools/common_fx2trt.py	2024-01-31 00:48:53.051354+00:00
@@ -152,13 +152,13 @@
            mod.eval()
            if len(expected_ops):
                self.assert_has_op(mod, expected_ops)

            interpreter_result = interpreter.run(
-                lower_precision=LowerPrecision.FP16
-                if fp16_mode
-                else LowerPrecision.FP32
+                lower_precision=(
+                    LowerPrecision.FP16 if fp16_mode else LowerPrecision.FP32
+                )
            )
            trt_mod = TRTModule(
                interpreter_result.engine,
                interpreter_result.input_names,
                interpreter_result.output_names,
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/trt_module.py	2024-01-31 00:45:45.943759+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/trt_module.py	2024-01-31 00:48:53.689620+00:00
@@ -67,25 +67,29 @@
                self.engine.get_binding_dtype(idx), Frameworks.TORCH
            )
            for idx in self.output_binding_indices_in_order
        ]
        self.output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
            for idx in self.output_binding_indices_in_order
        ]
        self.hidden_output_dtypes: Sequence[torch.dtype] = [
            unified_dtype_converter(
                self.engine.get_binding_dtype(idx), Frameworks.TORCH
            )
            for idx in self.hidden_output_binding_indices_in_order
        ]
        self.hidden_output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
            for idx in self.hidden_output_binding_indices_in_order
        ]

    def _check_initialized(self):
        if not self.initialized:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/ts/_compile_spec.py	2024-01-31 00:45:45.943759+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/ts/_compile_spec.py	2024-01-31 00:48:54.072639+00:00
@@ -404,13 +404,13 @@
        "inputs": inputs if inputs is not None else [],
        # "input_signature": input_signature,
        "device": device,
        "disable_tf32": disable_tf32,  # Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas
        "sparse_weights": sparse_weights,  # Enable sparsity for convolution and fully connected layers.
-        "enabled_precisions": enabled_precisions
-        if enabled_precisions is not None
-        else set(),  # Enabling FP16 kernels
+        "enabled_precisions": (
+            enabled_precisions if enabled_precisions is not None else set()
+        ),  # Enabling FP16 kernels
        "refit": refit,  # enable refit
        "debug": debug,  # enable debuggable engine
        "capability": capability,  # Restrict kernel selection to safe gpu kernels or safe dla kernels
        "num_avg_timing_iters": num_avg_timing_iters,  # Number of averaging timing iterations used to select kernels
        "workspace_size": workspace_size,  # Maximum size of workspace given to TensorRT

@github-actions github-actions bot added component: lowering Issues re: The lowering / preprocessing passes component: conversion Issues re: Conversion stage labels Jan 31, 2024
@peri044 peri044 merged commit 18a462f into release/2.2 Jan 31, 2024
17 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla signed component: api [Python] Issues re: Python API component: build system Issues re: Build system component: conversion Issues re: Conversion stage component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: fx component: lowering Issues re: The lowering / preprocessing passes component: runtime component: tests Issues re: Tests documentation Improvements or additions to documentation
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants