Skip to content

Commit

Permalink
chore: Mypy and bug fixes for enums, removing uses of FX dtypes in
Browse files Browse the repository at this point in the history
dynamo

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Mar 22, 2024
1 parent 950a55f commit e9ef3ca
Show file tree
Hide file tree
Showing 55 changed files with 775 additions and 568 deletions.
16 changes: 8 additions & 8 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
exclude: ^.github/actions/assigner/dist
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.5.0
hooks:
- id: check-yaml
- id: trailing-whitespace
Expand All @@ -16,38 +16,38 @@ repos:
- --fix=lf
exclude: ^docs
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v16.0.6
rev: v18.1.1
hooks:
- id: clang-format
types_or: [c++, c, cuda]
- repo: https://github.com/keith/pre-commit-buildifier
rev: 6.1.0.2
rev: 6.4.0
hooks:
- id: buildifier
args:
- --warnings=all
- id: buildifier-lint
- repo: https://github.com/abravalheri/validate-pyproject
rev: v0.13
rev: v0.16
hooks:
- id: validate-pyproject
- repo: https://github.com/pycqa/isort
rev: 5.12.0
rev: 5.13.2
hooks:
- id: isort
name: isort (python)
- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v1.4.1'
rev: 'v1.9.0'
hooks:
- id: mypy
exclude: "^py/torch_tensorrt/fx|^examples|^tests|^py/torch_tensorrt/dynamo/_experimental|^tools|^docs|noxfile.py|setup.py|versions.py"
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.0.278
rev: v0.3.3
hooks:
- id: ruff
- repo: https://github.com/psf/black
rev: 24.1.1
rev: 24.3.0
hooks:
- id: black
exclude: ^examples/custom_converters/elu_converter/setup.py|^docs
Expand Down
70 changes: 37 additions & 33 deletions py/torch_tensorrt/_Device.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
import sys
from typing import Any, Optional, Tuple
from __future__ import annotations

import logging
import sys
from typing import Any, Optional, Tuple

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self

import tensorrt as trt
import torch
from torch_tensorrt._enums import DeviceType
from torch_tensorrt._features import ENABLED_FEATURES

import tensorrt as trt


class Device(object):
"""
Defines a device that can be used to specify target devices for engines
Expand All @@ -24,7 +27,9 @@ class Device(object):
allow_gpu_fallback (bool): Whether falling back to GPU if DLA cannot support an op should be allowed
"""

device_type: Optional[DeviceType] = None #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
device_type: DeviceType = (
DeviceType.UNKNOWN
) #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
gpu_id: int = -1 #: Device ID for target GPU
dla_core: int = -1 #: Core ID for target DLA core
allow_gpu_fallback: bool = (
Expand Down Expand Up @@ -52,7 +57,6 @@ def __init__(self, *args: Any, **kwargs: Any):
- Device(dla_core=0, allow_gpu_fallback=True)
- Device(gpu_id=1)
"""
print(args, kwargs)
if len(args) == 1:
if not isinstance(args[0], str):
raise TypeError(
Expand All @@ -63,7 +67,7 @@ def __init__(self, *args: Any, **kwargs: Any):
if self.device_type == DeviceType.DLA:
self.dla_core = id
self.gpu_id = 0
logging.warn(
logging.warning(
"Setting GPU id to 0 for device because device 0 manages DLA on AGX Devices",
)
else:
Expand All @@ -76,12 +80,11 @@ def __init__(self, *args: Any, **kwargs: Any):
if "gpu_id" in kwargs:
self.gpu_id = kwargs["gpu_id"]


if self.dla_core >= 0:
self.device_type = DeviceType.DLA
if self.gpu_id != 0:
self.gpu_id = 0
logging.warn(
logging.warning(
"Setting GPU id to 0 for device because device 0 manages DLA on AGX Platforms",
)
else:
Expand All @@ -93,9 +96,7 @@ def __init__(self, *args: Any, **kwargs: Any):

else:
raise ValueError(
"Unexpected number of positional arguments for class Device \n Found {} arguments, expected either zero or a single positional arguments".format(
len(args)
)
f"Unexpected number of positional arguments for class Device \n Found {len(args)} arguments, expected either zero or a single positional arguments"
)

if "allow_gpu_fallback" in kwargs:
Expand All @@ -107,17 +108,20 @@ def __init__(self, *args: Any, **kwargs: Any):
if isinstance(kwargs["device_type"], trt.DeviceType):
self.device_type = DeviceType._from(kwargs["device_type"])


def __str__(self) -> str:
suffix = ")" if self.device_type == DeviceType.GPU else ", dla_core={}, allow_gpu_fallback={})".format(self.dla_core, self.allow_gpu_fallback)
return "Device(type={}, gpu_id={}{}".format(self.device_type, self.gpu_id, suffix)

suffix = (
")"
if self.device_type == DeviceType.GPU
else f", dla_core={self.dla_core}, allow_gpu_fallback={self.allow_gpu_fallback})"
)
dev_str: str = f"Device(type={self.device_type}, gpu_id={self.gpu_id}{suffix}"
return dev_str

def __repr__(self) -> str:
return self.__str__()

@classmethod
def _from(cls, d: Optional[Self | torch.device | str]) -> Self:
def _from(cls, d: Optional[Self | torch.device | str]) -> Device:
"""Cast a device-type to torch_tensorrt.Device
Returns the corresponding torch_tensorrt.Device
Expand All @@ -137,11 +141,11 @@ def _from(cls, d: Optional[Self | torch.device | str]) -> Self:
return cls(d)

@classmethod
def _from_torch_device(cls, torch_dev: torch.device) -> Self:
def _from_torch_device(cls, torch_dev: torch.device) -> Device:
return cls._from(torch_dev)

@classmethod
def _current_device(cls) -> Self:
def _current_device(cls) -> Device:
dev_id = torch.cuda.current_device()
return cls(gpu_id=dev_id)

Expand All @@ -166,19 +170,19 @@ def to(self, t: type) -> torch.device:
raise TypeError("Unsupported target type for device conversion")

def _to_serialized_rt_device(self) -> str:
if ENABLED_FEATURES.torch_tensorrt_runtime:
delim = torch.ops.tensorrt.SERIALIZED_RT_DEVICE_DELIM()[0]
dev_info = torch.cuda.get_device_properties(self.gpu_id)
rt_info = [
self.gpu_id,
dev_info.major,
dev_info.minor,
int(self.device_type.to(trt.DeviceType)),
dev_info.name
]
rt_info = [str(i) for i in rt_info]
packed_rt_info = delim.join(rt_info)
logging.debug("Serialized Device Info: {}".format(packed_rt_info))
return packed_rt_info
else:
if not ENABLED_FEATURES.torch_tensorrt_runtime:
raise NotImplementedError("Torch-TensorRT runtime is not available")

delim = torch.ops.tensorrt.SERIALIZED_RT_DEVICE_DELIM()[0]
dev_info = torch.cuda.get_device_properties(self.gpu_id)
rt_info = [
self.gpu_id,
dev_info.major,
dev_info.minor,
int(self.device_type.to(trt.DeviceType)), # type: ignore[arg-type]
dev_info.name,
]
rt_info = [str(i) for i in rt_info]
packed_rt_info: str = delim.join(rt_info)
logging.debug(f"Serialized Device Info: {packed_rt_info}")
return packed_rt_info
30 changes: 12 additions & 18 deletions py/torch_tensorrt/_Input.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@
from typing import Any, Dict, List, Optional, Sequence, Tuple

import torch
import numpy as np
import tensorrt as trt

from torch_tensorrt._enums import dtype, memory_format


class Input(object):
"""
Defines an input to a module in terms of expected shape, data type and tensor format.
Expand All @@ -30,12 +28,12 @@ class _ShapeMode(Enum):
STATIC = 0
DYNAMIC = 1

shape_mode: Optional[
_ShapeMode
] = None #: Is input statically or dynamically shaped
shape: Optional[
Tuple[int, ...] | Dict[str, Tuple[int, ...]]
] = None #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
shape_mode: Optional[_ShapeMode] = (
None #: Is input statically or dynamically shaped
)
shape: Optional[Tuple[int, ...] | Dict[str, Tuple[int, ...]]] = (
None #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
)
dtype: dtype = (
dtype.unknown
) #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
Expand All @@ -47,7 +45,6 @@ class _ShapeMode(Enum):
DOMAIN_OFFSET: float = 2.0
low_tensor_domain_incl: float = 0.0
high_tensor_domain_excl: float = low_tensor_domain_incl + DOMAIN_OFFSET
torch_dtype: torch.dtype = torch.float32
torch_tensor: torch.Tensor = None
name: str = ""

Expand Down Expand Up @@ -153,16 +150,11 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:

else:
raise ValueError(
"Unexpected number of positional arguments for class Input \n Found {} arguments, expected either zero or a single positional arguments".format(
len(args)
)
f"Unexpected number of positional arguments for class Input \n Found {len(args)} arguments, expected either zero or a single positional arguments"
)

if "dtype" in kwargs:
self.dtype = dtype._from(kwargs["dtype"])
self.torch_dtype = self.dtype.to(torch.dtype, use_default=True)
print(self.dtype)
print(self.torch_dtype)

if self.dtype != dtype.unknown:
self._explicit_set_dtype = True
Expand Down Expand Up @@ -352,7 +344,9 @@ def example_tensor(
)
else:
if isinstance(self.shape, tuple):
return torch.rand(self.shape).to(dtype=self.torch_dtype)
return torch.rand(self.shape).to(
dtype=self.dtype.to(torch.dtype, use_default=True)
)
else:
RuntimeError(
f"Input shape is dynamic but shapes are not provided as sequence (found: {self.shape})"
Expand All @@ -371,7 +365,7 @@ def example_tensor(

if isinstance(self.shape, dict):
return torch.rand(self.shape[optimization_profile_field]).to(
dtype=self.torch_dtype
dtype=self.dtype.to(torch.dtype, use_default=True)
)
else:
raise RuntimeError(
Expand Down
10 changes: 6 additions & 4 deletions py/torch_tensorrt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,27 +80,30 @@ def _find_lib(name: str, paths: List[str]) -> str:
for lib in LINUX_LIBS:
ctypes.CDLL(_find_lib(lib, LINUX_PATHS))

import logging

import torch
from torch_tensorrt._features import ENABLED_FEATURES, _enabled_features_str

import logging
_LOGGER = logging.getLogger(__name__)
_LOGGER.debug(_enabled_features_str())


def _register_with_torch() -> None:
trtorch_dir = os.path.dirname(__file__)
if os.path.isfile(trtorch_dir + "/lib/libtorchtrt.so"):
assert ENABLED_FEATURES.torchscript_frontend == True
assert ENABLED_FEATURES.torch_tensorrt_runtime== True
assert ENABLED_FEATURES.torch_tensorrt_runtime == True
torch.ops.load_library(trtorch_dir + "/lib/libtorchtrt.so")
elif os.path.isfile(trtorch_dir + "/lib/libtorchtrt_runtime.so"):
assert ENABLED_FEATURES.torch_tensorrt_runtime == True
torch.ops.load_library(trtorch_dir + "/lib/libtorchtrt_runtime.so")


_register_with_torch()

from torch_tensorrt._enums import dtype, memory_format, DeviceType # noqa: F401
from torch_tensorrt._Device import Device # noqa: F401
from torch_tensorrt._enums import DeviceType, dtype, memory_format # noqa: F401
from torch_tensorrt._Input import Input # noqa: F401
from torch_tensorrt.runtime import * # noqa: F403

Expand All @@ -115,4 +118,3 @@ def _register_with_torch() -> None:
from torch_tensorrt import dynamo # noqa: F401

from torch_tensorrt._compile import * # noqa: F403

Loading

0 comments on commit e9ef3ca

Please sign in to comment.