
pytorch/TensorRT


Torch-TensorRT

Easily achieve the best inference performance for any PyTorch model on the NVIDIA platform.



Torch-TensorRT brings the power of TensorRT to PyTorch. Reduce inference latency by up to 5x compared to eager execution with just one line of code.

Installation

Stable versions of Torch-TensorRT are published on PyPI:

pip install torch-tensorrt

Nightly versions of Torch-TensorRT are published on the PyTorch package index:

pip install --pre torch-tensorrt --index-url https://download.pytorch.org/whl/nightly/cu124

Torch-TensorRT is also distributed in the ready-to-run NVIDIA NGC PyTorch Container, which includes all dependencies at the proper versions as well as example notebooks.

For more advanced installation methods, please see here

Quickstart

Option 1: torch.compile

You can use Torch-TensorRT anywhere you use torch.compile:

import torch
import torch_tensorrt

model = MyModel().eval().cuda() # define your model here
x = torch.randn((1, 3, 224, 224)).cuda() # define what the inputs to the model will look like

optimized_model = torch.compile(model, backend="tensorrt")
optimized_model(x) # compiled on first run

optimized_model(x) # this will be fast!
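The comments above describe lazy compilation: the expensive optimization happens on the first call, and later calls reuse the result. The pure-Python toy below illustrates only that caching pattern. `LazyCompiled`, `_compile`, and the length-based "shape" key are hypothetical names, and `torch.compile` uses its own, more elaborate guards to decide when a new input (for example, one with a different shape) requires recompiling.

```python
from typing import Callable, Dict, List, Tuple

Shape = Tuple[int, ...]
Fn = Callable[[List[float]], List[float]]


class LazyCompiled:
    """Toy compile-on-first-call wrapper (illustration only, not torch.compile)."""

    def __init__(self, fn: Fn):
        self.fn = fn
        self.cache: Dict[Shape, Fn] = {}  # one "compiled" callable per input shape
        self.compile_count = 0

    def _compile(self, shape: Shape) -> Fn:
        # Stand-in for the expensive step (e.g. building a TensorRT engine).
        self.compile_count += 1
        return self.fn

    def __call__(self, x: List[float]) -> List[float]:
        shape = (len(x),)  # toy "shape" of a flat list
        if shape not in self.cache:  # first call for this shape: slow path
            self.cache[shape] = self._compile(shape)
        return self.cache[shape](x)  # later calls: cached fast path
```

With `double = LazyCompiled(lambda xs: [2 * v for v in xs])`, the first call `double([1, 2, 3])` compiles once, a second call with another length-3 list reuses the cache, and a length-2 list triggers a second compile.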

Option 2: Export

If you want to optimize your model ahead-of-time and/or deploy in a C++ environment, Torch-TensorRT provides an export-style workflow that serializes an optimized module. This module can be deployed in PyTorch or with libtorch (i.e. without a Python dependency).

Step 1: Optimize + serialize

import torch
import torch_tensorrt

model = MyModel().eval().cuda() # define your model here
inputs = [torch.randn((1, 3, 224, 224)).cuda()] # define a list of representative inputs here

trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
# An ExportedProgram (.ep) can only be loaded from Python; for C++ deployment, save a TorchScript file instead
torch_tensorrt.save(trt_gm, "trt.ep", inputs=inputs)
torch_tensorrt.save(trt_gm, "trt.ts", output_format="torchscript", inputs=inputs)

Step 2: Deploy

Deployment in PyTorch:
import torch
import torch_tensorrt

inputs = [torch.randn((1, 3, 224, 224)).cuda()] # your inputs go here

# You can run this in a new Python session!
model = torch.export.load("trt.ep").module()
# model = torch_tensorrt.load("trt.ep").module() # this also works
model(*inputs)
Deployment in C++:
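#include "torch/script.h"
#include "torch_tensorrt/torch_tensorrt.h"

int main() {
  auto trt_mod = torch::jit::load("trt.ts");
  // Replace with your inputs; this matches the shape used at compile time
  auto input_tensor = torch::randn({1, 3, 224, 224}).to(torch::kCUDA);
  auto results = trt_mod.forward({input_tensor});
  return 0;
}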

Further resources

Platform Support

Platform | Support
Linux AMD64 / GPU | Supported
Windows / GPU | Supported (Dynamo only)
Linux aarch64 / GPU | Native Compilation Supported on JetPack-4.4+ (use v1.0.0 for the time being)
Linux aarch64 / DLA | Native Compilation Supported on JetPack-4.4+ (use v1.0.0 for the time being)
Linux ppc64le / GPU | Not supported

Note: Refer to the NVIDIA L4T PyTorch NGC container for PyTorch libraries on JetPack.
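To check a host against the platform table, a lookup like the following can help. This is a sketch under assumptions: `SUPPORT` and `support_status` are hypothetical names, and mapping table rows to Python's `platform.system()` / `platform.machine()` strings (`x86_64` for Linux AMD64, `AMD64` for 64-bit Windows) is an interpretation, not an official mapping. It does not detect a GPU, DLA, or JetPack version.

```python
import platform
from typing import Optional

AARCH64 = "Native Compilation Supported on JetPack-4.4+ (use v1.0.0 for the time being)"

# Support status from the Platform Support table, keyed by
# (platform.system(), platform.machine()); the key mapping is an assumption.
SUPPORT = {
    ("Linux", "x86_64"): "Supported",
    ("Windows", "AMD64"): "Supported (Dynamo only)",
    ("Linux", "aarch64"): AARCH64,
    ("Linux", "ppc64le"): "Not supported",
}


def support_status(system: Optional[str] = None, machine: Optional[str] = None) -> str:
    """Look up the table's status for a host (defaults to the current machine)."""
    system = system or platform.system()
    machine = machine or platform.machine()
    return SUPPORT.get((system, machine), "Not listed")
```

Calling `support_status()` with no arguments reports on the current machine; a supported entry still requires a compatible NVIDIA GPU and driver.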

Dependencies

The following dependencies are used to verify the test cases. Torch-TensorRT can work with other versions, but the tests are not guaranteed to pass.

  • Bazel 6.3.2
  • Libtorch 2.5.0.dev (latest nightly, built with CUDA 12.4)
  • CUDA 12.4
  • TensorRT 10.6.0.26
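A quick way to see whether a Python environment matches the tested versions is sketched below. It assumes the pip distributions `torch` and `tensorrt` stand in for Libtorch and TensorRT, compares only the major.minor.patch release, and does not check Bazel or the CUDA toolkit. `TESTED`, `release`, `installed_versions`, and `untested` are hypothetical helpers.

```python
import re
from importlib import metadata
from typing import Dict, Optional, Tuple

# Versions from the Dependencies list; pip package names are an assumption.
TESTED = {"torch": "2.5.0.dev", "tensorrt": "10.6.0.26"}


def release(version: str) -> Optional[Tuple[int, ...]]:
    """Extract (major, minor, patch), ignoring dev/pre-release and local tags."""
    m = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    return tuple(int(g) for g in m.groups()) if m else None


def installed_versions(packages=TESTED) -> Dict[str, Optional[str]]:
    """Return the installed pip version of each package, or None if absent."""
    found: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


def untested(installed: Dict[str, Optional[str]]) -> Dict[str, Optional[str]]:
    """Return packages that are missing or whose release differs from TESTED."""
    return {
        name: version
        for name, version in installed.items()
        if version is None or release(version) != release(TESTED[name])
    }
```

`print(untested(installed_versions()))` prints an empty dict when both packages match the tested releases.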

Deprecation Policy

Deprecation is used to inform developers that some APIs and tools are no longer recommended for use. Beginning with version 2.3, Torch-TensorRT has the following deprecation policy:

Deprecation notices are communicated in the Release Notes. Deprecated API functions will have a statement in the source documenting when they were deprecated. Deprecated methods and classes will issue deprecation warnings at runtime, if they are used. Torch-TensorRT provides a 6-month migration period after the deprecation. APIs and tools continue to work during the migration period. After the migration period ends, APIs and tools are removed in a manner consistent with semantic versioning.
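The runtime side of this policy can be sketched as follows, assuming the six-month migration period is counted in calendar months from the deprecation date. `deprecated` and `earliest_removal` are hypothetical helpers, not Torch-TensorRT's actual implementation: the decorator emits a `DeprecationWarning` on each call while the wrapped API keeps working, and the message names the earliest date after which removal is allowed.

```python
import calendar
import functools
import warnings
from datetime import date

MIGRATION_MONTHS = 6  # migration period stated in the policy


def earliest_removal(deprecated_on: date, months: int = MIGRATION_MONTHS) -> date:
    """End of the migration period (day clamped to the target month's length)."""
    month_index = deprecated_on.month - 1 + months
    year = deprecated_on.year + month_index // 12
    month = month_index % 12 + 1
    day = min(deprecated_on.day, calendar.monthrange(year, month)[1])
    return date(year, month, day)


def deprecated(deprecated_on: date, replacement: str):
    """Hypothetical decorator: warn at call time but keep the API working."""
    def wrap(fn):
        @functools.wraps(fn)
        def inner(*args, **kwargs):
            warnings.warn(
                f"{fn.__name__} was deprecated on {deprecated_on.isoformat()} and may be "
                f"removed after {earliest_removal(deprecated_on).isoformat()}; "
                f"use {replacement} instead.",
                DeprecationWarning,
                stacklevel=2,  # attribute the warning to the caller
            )
            return fn(*args, **kwargs)
        return inner
    return wrap
```

Python hides `DeprecationWarning` by default unless it is triggered from code in `__main__`; running with `-W default::DeprecationWarning` makes it visible everywhere.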

Contributing

Take a look at CONTRIBUTING.md.

License

The Torch-TensorRT license can be found in the LICENSE file. It is licensed under a BSD-style license.