Skip to content

Commit

Permalink
feat/fix: Working Windows build + partial Dynamo support
Browse files Browse the repository at this point in the history
  • Loading branch information
gs-olive committed Dec 18, 2023
1 parent 78756c4 commit 49d142a
Show file tree
Hide file tree
Showing 14 changed files with 513 additions and 372 deletions.
84 changes: 0 additions & 84 deletions BUILD

This file was deleted.

26 changes: 0 additions & 26 deletions py/BUILD

This file was deleted.

17 changes: 9 additions & 8 deletions py/torch_tensorrt/_Device.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import sys
from typing import Any, Optional, Tuple

Expand All @@ -11,7 +12,6 @@
# from torch_tensorrt import _enums
import tensorrt as trt
import torch
from torch_tensorrt import logging

try:
from torch_tensorrt import _C
Expand All @@ -21,6 +21,9 @@
)


logger = logging.getLogger(__name__)


class Device(object):
"""
Defines a device that can be used to specify target devices for engines
Expand Down Expand Up @@ -72,8 +75,7 @@ def __init__(self, *args: Any, **kwargs: Any):
else:
self.dla_core = id
self.gpu_id = 0
logging.log(
logging.Level.Warning,
logger.info(
"Setting GPU id to 0 for device because device 0 manages DLA on Xavier",
)

Expand All @@ -86,8 +88,7 @@ def __init__(self, *args: Any, **kwargs: Any):
self.gpu_id = kwargs["gpu_id"]
else:
self.gpu_id = 0
logging.log(
logging.Level.Warning,
logger.info(
"Setting GPU id to 0 for device because device 0 manages DLA on Xavier",
)
else:
Expand Down Expand Up @@ -122,7 +123,7 @@ def __str__(self) -> str:
def __repr__(self) -> str:
return self.__str__()

def _to_internal(self) -> _C.Device:
def _to_internal(self) -> "_C.Device":
internal_dev = _C.Device()
if self.device_type == trt.DeviceType.GPU:
internal_dev.device_type = _C.DeviceType.GPU
Expand Down Expand Up @@ -152,8 +153,8 @@ def _from_torch_device(cls, torch_dev: torch.device) -> Self:

@classmethod
def _current_device(cls) -> Self:
dev = _C._get_current_device()
return cls(gpu_id=dev.gpu_id)
# dev = _C._get_current_device()
return cls(gpu_id=torch.cuda.current_device())

@staticmethod
def _parse_device_str(s: str) -> Tuple[trt.DeviceType, int]:
Expand Down
18 changes: 9 additions & 9 deletions py/torch_tensorrt/_Input.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,15 +227,15 @@ def _supported_input_size_type(input_size: Any) -> bool:
def _parse_dtype(dtype: Any) -> _enums.dtype:
if isinstance(dtype, torch.dtype):
if dtype == torch.long:
return _enums.dtype.long
return _enums.dtype.int64
elif dtype == torch.int32:
return _enums.dtype.int32
elif dtype == torch.half:
return _enums.dtype.half
return _enums.dtype.float16
elif dtype == torch.float:
return _enums.dtype.float
return _enums.dtype.float32
elif dtype == torch.float64:
return _enums.dtype.double
return _enums.dtype.float64
elif dtype == torch.bool:
return _enums.dtype.bool
else:
Expand All @@ -255,24 +255,24 @@ def _parse_dtype(dtype: Any) -> _enums.dtype:

@staticmethod
def _to_torch_dtype(dtype: _enums.dtype) -> torch.dtype:
if dtype == _enums.dtype.long:
if dtype == _enums.dtype.int64:
return torch.long
elif dtype == _enums.dtype.int32:
return torch.int32
elif dtype == _enums.dtype.half:
elif dtype == _enums.dtype.float16:
return torch.half
elif dtype == _enums.dtype.float:
elif dtype == _enums.dtype.float32:
return torch.float
elif dtype == _enums.dtype.bool:
return torch.bool
elif dtype == _enums.dtype.double:
elif dtype == _enums.dtype.float64:
return torch.float64
else:
# Default torch_dtype used in FX path
return torch.float32

def is_trt_dtype(self) -> bool:
return bool(self.dtype != _enums.dtype.long)
return bool(self.dtype != _enums.dtype.int64)

@staticmethod
def _parse_format(format: Any) -> _enums.TensorFormat:
Expand Down
4 changes: 2 additions & 2 deletions py/torch_tensorrt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def _find_lib(name: str, paths: List[str]) -> str:

def _register_with_torch() -> None:
trtorch_dir = os.path.dirname(__file__)
torch.ops.load_library(trtorch_dir + "/lib/libtorchtrt.so")
torch.ops.load_library(os.path.join(trtorch_dir, "lib", "libtorchtrt.so"))


_register_with_torch()
# _register_with_torch()
7 changes: 3 additions & 4 deletions py/torch_tensorrt/_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@

import torch
import torch.fx
import torch_tensorrt.ts
import torch_tensorrt
from torch_tensorrt._enums import dtype
from torch_tensorrt._Input import Input
from torch_tensorrt._utils import sanitized_torch_version
from torch_tensorrt.fx import InputTensorSpec
from torch_tensorrt.fx.lower import compile as fx_compile
from torch_tensorrt.fx.utils import LowerPrecision
from torch_tensorrt.ts._compiler import compile as torchscript_compile
from typing_extensions import TypeGuard

from packaging import version
Expand Down Expand Up @@ -192,12 +191,12 @@ def compile(
elif target_ir == _IRType.fx:
if (
torch.float16 in enabled_precisions_set
or torch_tensorrt.dtype.half in enabled_precisions_set
or torch_tensorrt.dtype.float16 in enabled_precisions_set
):
lower_precision = LowerPrecision.FP16
elif (
torch.float32 in enabled_precisions_set
or torch_tensorrt.dtype.float in enabled_precisions_set
or torch_tensorrt.dtype.float32 in enabled_precisions_set
):
lower_precision = LowerPrecision.FP32
else:
Expand Down
29 changes: 27 additions & 2 deletions py/torch_tensorrt/_enums.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,28 @@
from torch_tensorrt._C import EngineCapability, TensorFormat, dtype # noqa: F401
from enum import Enum, auto

from tensorrt import DeviceType # noqa: F401
# from tensorrt import DeviceType # noqa: F401


class dtype(Enum):
float32 = auto()
half = auto()
float16 = auto()
float = auto()
int8 = auto()
int32 = auto()
int = auto()
int64 = auto()
float64 = auto()
bool = auto()
unknown = auto()


class TensorFormat(Enum):
contiguous = auto()
channels_last = auto()


class EngineCapability(Enum):
safe_gpu = auto()
safe_dla = auto()
default = auto()
17 changes: 14 additions & 3 deletions py/torch_tensorrt/_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Any

import torch
from torch_tensorrt import _C
from torch_tensorrt._version import __version__


Expand All @@ -16,7 +15,14 @@ def get_build_info() -> str:
Returns:
str: String containing the build information for torch_tensorrt distribution
"""
core_build_info = _C.get_build_info()
try:
from torch_tensorrt import _C

core_build_info = _C.get_build_info()
except:
core_build_info = ""
print("Unable to get _C build info, _C extensions unavailable")

build_info = str(
"Torch-TensorRT Version: "
+ str(__version__)
Expand All @@ -30,7 +36,12 @@ def get_build_info() -> str:


def set_device(gpu_id: int) -> None:
_C.set_device(gpu_id)
try:
from torch_tensorrt import _C

_C.set_device(gpu_id)
except:
print("Unable to set_device, _C extensions unavailable")


def sanitized_torch_version() -> Any:
Expand Down
4 changes: 2 additions & 2 deletions py/torch_tensorrt/dynamo/_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,12 @@ def compile(

if (
torch.float16 in enabled_precisions
or torch_tensorrt.dtype.half in enabled_precisions
or torch_tensorrt.dtype.float16 in enabled_precisions
):
precision = torch.float16
elif (
torch.float32 in enabled_precisions
or torch_tensorrt.dtype.float in enabled_precisions
or torch_tensorrt.dtype.float32 in enabled_precisions
):
precision = torch.float32
elif len(enabled_precisions) == 0:
Expand Down
Loading

0 comments on commit 49d142a

Please sign in to comment.