Add support for Torchvision ops (triton-inference-server#1750)
* Add support for Torchvision ops

* review edits

* add test for torchvision op model

* additional review edits

* update docs for torchvision ops
CoderHam authored and deadeyegoodwin committed Jul 31, 2020
1 parent 810a807 commit 88bb843
Showing 11 changed files with 181 additions and 5 deletions.
14 changes: 12 additions & 2 deletions Dockerfile
@@ -161,6 +161,7 @@ ARG TRITON_CONTAINER_VERSION=20.08dev

# libgoogle-glog0v5 is needed by caffe2 libraries.
# libcurl4-openssl-dev is needed for GCS
# python3-dev is needed by Torchvision
RUN apt-get update && \
    apt-get install -y --no-install-recommends \
            autoconf \
@@ -176,6 +177,7 @@ RUN apt-get update && \
            rapidjson-dev \
            libb64-dev \
            patchelf \
            python3-dev \
            software-properties-common && \
    if [ $(cat /etc/os-release | grep 'VERSION_ID="16.04"' | wc -l) -ne 0 ]; then \
        apt-get install -y --no-install-recommends \
@@ -224,7 +226,7 @@ COPY --from=tritonserver_pytorch /opt/conda/lib/libmkl_intel_lp64.so /opt/triton
COPY --from=tritonserver_pytorch /opt/conda/lib/libmkl_rt.so /opt/tritonserver/lib/pytorch/
COPY --from=tritonserver_pytorch /opt/conda/lib/libmkl_vml_def.so /opt/tritonserver/lib/pytorch/

# LibTorch headers and libraries
# LibTorch and Torchvision headers and libraries
COPY --from=tritonserver_pytorch /opt/conda/lib/python3.6/site-packages/torch/include \
     /opt/tritonserver/include/torch
COPY --from=tritonserver_pytorch /opt/conda/lib/python3.6/site-packages/torch/lib/libtorch.so \
@@ -235,6 +237,10 @@ COPY --from=tritonserver_pytorch /opt/conda/lib/python3.6/site-packages/torch/li
     /opt/tritonserver/lib/pytorch/
COPY --from=tritonserver_pytorch /opt/conda/lib/python3.6/site-packages/torch/lib/libcaffe2_nvrtc.so \
     /opt/tritonserver/lib/pytorch/
COPY --from=tritonserver_pytorch /opt/pytorch/vision/torchvision/csrc \
     /opt/tritonserver/include/torchvision/torchvision/
COPY --from=tritonserver_pytorch /opt/pytorch/vision/build/libtorchvision.so \
     /opt/tritonserver/lib/pytorch/
RUN cd /opt/tritonserver/lib/pytorch && \
    for i in `find . -mindepth 1 -maxdepth 1 -type f -name '*\.so*'`; do \
        patchelf --set-rpath '$ORIGIN' $i; \
@@ -330,7 +336,7 @@ RUN LIBCUDA_FOUND=$(ldconfig -p | grep -v compat | awk '{print $1}' | grep libcu
    -DTRITON_ENABLE_PYTORCH=ON \
    -DTRITON_ENABLE_ENSEMBLE=ON \
    -DTRITON_ONNXRUNTIME_INCLUDE_PATHS="/opt/tritonserver/include/onnxruntime" \
    -DTRITON_PYTORCH_INCLUDE_PATHS="/opt/tritonserver/include/torch;/opt/tritonserver/include/torch/torch/csrc/api/include;/opt/tritonserver/include/torchvision;/usr/include/python3.6" \
    -DTRITON_EXTRA_LIB_PATHS="/opt/tritonserver/lib;/opt/tritonserver/lib/tensorflow;/opt/tritonserver/lib/pytorch;/opt/tritonserver/lib/onnx" \
    ../build && \
    make -j16 server && \
@@ -371,6 +377,10 @@ LABEL com.nvidia.tritonserver.version="${TRITON_SERVER_VERSION}"

ENV PATH /opt/tritonserver/bin:${PATH}

# Need to include pytorch in LD_LIBRARY_PATH since Torchvision loads custom
# ops from that path
ENV LD_LIBRARY_PATH /opt/tritonserver/lib/pytorch/:$LD_LIBRARY_PATH

ENV TF_ADJUST_HUE_FUSED 1
ENV TF_ADJUST_SATURATION_FUSED 1
ENV TF_ENABLE_WINOGRAD_NONFUSED 1
1 change: 1 addition & 0 deletions Dockerfile.QA
@@ -328,5 +328,6 @@ RUN rm -fr qa/L0_copyrights qa/L0_build_variants && \
pip3 install --upgrade qa/pkgs/triton*.whl

ENV LD_LIBRARY_PATH /opt/tritonserver/qa/clients:${LD_LIBRARY_PATH}

ENV GOPATH /root/go
ENV PATH $PATH:$GOPATH/bin
4 changes: 2 additions & 2 deletions docs/build.rst
@@ -244,8 +244,8 @@ paths can be specified by separating them with a semicolon, for
example, -DTRITON_EXTRA_LIB_PATHS="/path/a;/path/b".

For the PyTorch backend you must also provide the path to the PyTorch
headers using the -DTRITON_PYTORCH_INCLUDE_PATHS option. Multiple paths
can be specified by separating them with a semicolon.
and Torchvision headers using the -DTRITON_PYTORCH_INCLUDE_PATHS option.
Multiple paths can be specified by separating them with a semicolon.
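For example, the Dockerfile in this commit passes
-DTRITON_PYTORCH_INCLUDE_PATHS="/opt/tritonserver/include/torch;/opt/tritonserver/include/torch/torch/csrc/api/include;/opt/tritonserver/include/torchvision;/usr/include/python3.6";
adjust these paths to wherever the headers are installed.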

Configure Triton Build
......................
4 changes: 4 additions & 0 deletions docs/custom_operation.rst
@@ -136,6 +136,10 @@ seriously, if there are custom layer name conflicts across multiple
shared libraries or the handles used to register them in PyTorch, there
is currently no way to handle it.

Starting with the 20.07 release of Triton, the `TorchVision
operations <https://github.com/pytorch/vision>`_ are included with the
PyTorch backend and so do not have to be explicitly added as custom
operations.
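
For instance, a TorchScript model that calls one of these operations can
be exported and served without building a custom-op library (a minimal
sketch; the module and file names are illustrative, not part of this
commit)::

    import torch
    import torchvision  # importing torchvision registers its TorchScript ops

    class NMSNet(torch.nn.Module):
        def forward(self, boxes: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
            # torchvision.ops.nms is bundled with the PyTorch backend
            return torchvision.ops.nms(boxes, scores, 0.5)

    torch.jit.script(NMSNet()).save("model.pt")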

When building the custom operations shared library it is important to
use the same version of PyTorch as is being used in Triton. You can
find the PyTorch version in the `Triton Release Notes
8 changes: 8 additions & 0 deletions qa/L0_custom_ops/test.sh
@@ -41,6 +41,7 @@ CLIENT_LOG="./client.log"
ZERO_OUT_TEST=zero_out_test.py
CUDA_OP_TEST=cuda_op_test.py
MOD_OP_TEST=mod_op_test.py
VISION_OP_TEST=vision_op_test.py
ONNX_OP_TEST=onnx_op_test.py

SERVER=/opt/tritonserver/bin/tritonserver
@@ -120,6 +121,13 @@ if [ $? -ne 0 ]; then
    RET=1
fi

python $VISION_OP_TEST -v -m libtorch_visionop >>$CLIENT_LOG 2>&1
if [ $? -ne 0 ]; then
    cat $CLIENT_LOG
    echo -e "\n***\n*** Test Failed\n***"
    RET=1
fi

set -e

if [ $RET -eq 0 ]; then
83 changes: 83 additions & 0 deletions qa/L0_custom_ops/vision_op_test.py
@@ -0,0 +1,83 @@
#!/usr/bin/python

# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import argparse
import numpy as np
import os
import sys
from builtins import range
import tritongrpcclient as grpcclient
import tritonhttpclient as httpclient
from tritonclientutils import np_to_triton_dtype

FLAGS = None

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-v', '--verbose', action="store_true", required=False, default=False,
                        help='Enable verbose output')
    parser.add_argument('-u', '--url', type=str, required=False, default='localhost:8000',
                        help='Inference server URL. Default is localhost:8000.')
    parser.add_argument('-i', '--protocol', type=str, required=False, default='http',
                        help='Protocol ("http"/"grpc") used to ' +
                        'communicate with inference service. Default is "http".')
    parser.add_argument('-m', '--model', type=str, required=True,
                        help='Name of model.')

    FLAGS = parser.parse_args()
    if (FLAGS.protocol != "http") and (FLAGS.protocol != "grpc"):
        print("unexpected protocol \"{}\", expects \"http\" or \"grpc\"".format(FLAGS.protocol))
        exit(1)

    client_util = httpclient if FLAGS.protocol == "http" else grpcclient

    # Run the libtorch_visionop model, which depends on a torchvision custom operation
    model_name = FLAGS.model

    # Create the inference context for the model.
    client = client_util.InferenceServerClient(FLAGS.url, verbose=FLAGS.verbose)

    # Create the data for one input tensor.
    input_data = np.random.rand(1, 16, 10, 10).astype(np.float32)

    inputs = []
    inputs.append(client_util.InferInput(
        "INPUT__0", input_data.shape, np_to_triton_dtype(input_data.dtype)))
    inputs[0].set_data_from_numpy(input_data)

    results = client.infer(model_name, inputs)

    # We expect 1 result of shape [1, 33, 12, 14].
    output_data = results.as_numpy('OUTPUT__0')
    if output_data is None:
        print("error: expected 'OUTPUT__0'")
        sys.exit(1)

    if output_data.shape != (1, 33, 12, 14):
        print("error: incorrect shape " + str(output_data.shape) + " for 'OUTPUT__0'")
        sys.exit(1)
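
For a manual run outside of test.sh, invocations along these lines should
work (8000 and 8001 are Triton's default HTTP and gRPC ports):

    python vision_op_test.py -v -m libtorch_visionop
    python vision_op_test.py -i grpc -u localhost:8001 -m libtorch_visionop
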
57 changes: 57 additions & 0 deletions qa/common/gen_qa_custom_ops_models.py
@@ -346,6 +346,62 @@ def create_moduloop_modelconfig(models_dir, model_version):
    with open(config_dir + "/config.pbtxt", "w") as cfile:
        cfile.write(config)

# Use Torchvision ops
def create_visionop_modelfile(models_dir, model_version):
    model_name = "libtorch_visionop"

    class CustomVisionNet(nn.Module):
        def __init__(self):
            super(CustomVisionNet, self).__init__()
            self.conv2 = ops.misc.ConvTranspose2d(16, 33, (3, 5))

        def forward(self, input0):
            return self.conv2(input0)

    visionCustomModel = CustomVisionNet()
    example_input0 = torch.rand(1, 16, 10, 10, dtype=torch.float32)
    traced = torch.jit.trace(visionCustomModel, (example_input0,))

    model_version_dir = models_dir + "/" + \
        model_name + "/" + str(model_version)

    try:
        os.makedirs(model_version_dir)
    except OSError as ex:
        pass  # ignore existing dir

    traced.save(model_version_dir + "/model.pt")


def create_visionop_modelconfig(models_dir, model_version):
    model_name = "libtorch_visionop"
    config_dir = models_dir + "/" + model_name
    config = '''
name: "{}"
platform: "pytorch_libtorch"
max_batch_size: 0
input [
  {{
    name: "INPUT__0"
    data_type: TYPE_FP32
    dims: [ 1, 16, 10, 10 ]
  }}
]
output [
  {{
    name: "OUTPUT__0"
    data_type: TYPE_FP32
    dims: [ 1, 33, 12, 14 ]
  }}
]
'''.format(model_name)

    try:
        os.makedirs(config_dir)
    except OSError as ex:
        pass  # ignore existing dir

    with open(config_dir + "/config.pbtxt", "w") as cfile:
        cfile.write(config)

def create_zero_out_models(models_dir):
    model_version = 1
@@ -436,5 +492,6 @@ def create_modulo_op_models(models_dir):
if FLAGS.libtorch:
    import torch
    from torch import nn
    from torchvision import ops
    import torch.utils.cpp_extension
    create_modulo_op_models(FLAGS.models_dir)
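
For reference, the [1, 33, 12, 14] shape that the test and config above
expect follows from the transposed-convolution size formula: with unit
stride and no padding, out = in + kernel - 1, so 10+3-1 = 12 and
10+5-1 = 14, with 33 output channels. A standalone check (a sketch using
the plain torch.nn layer, which torchvision's ops.misc.ConvTranspose2d
mirrors for non-empty batches):

    import torch

    # ConvTranspose2d with default stride=1, padding=0: out = in + kernel - 1
    conv = torch.nn.ConvTranspose2d(16, 33, (3, 5))
    out = conv(torch.rand(1, 16, 10, 10))
    assert out.shape == (1, 33, 12, 14)  # 33 channels, 10+3-1, 10+5-1
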
4 changes: 4 additions & 0 deletions src/backends/pytorch/CMakeLists.txt
@@ -48,6 +48,10 @@ add_library(
${LIBTORCH_SRCS} ${LIBTORCH_HDRS}
)
set_target_properties(libtorch-backend-library PROPERTIES CXX_STANDARD 14)

# Need to turn unused-but-set-variable off due to Torchvision
set_target_properties(libtorch-backend-library PROPERTIES COMPILE_FLAGS -Wno-unused-but-set-variable)

add_dependencies(libtorch-backend-library proto-library)
target_include_directories(libtorch-backend-library PRIVATE ${TRITON_PYTORCH_INCLUDE_PATHS})

3 changes: 2 additions & 1 deletion src/backends/pytorch/libtorch_backend.h
@@ -25,7 +25,8 @@
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once

#include <torch/script.h> // One-stop header.
#include <torch/script.h> // One-stop header for TorchScript
#include <torchvision/vision.h> // Torchvision header
#include <set>
#include <string>
#include <unordered_map>
7 changes: 7 additions & 0 deletions src/backends/pytorch/libtorch_backend_factory.cc
@@ -26,6 +26,13 @@

#include "src/backends/pytorch/libtorch_backend_factory.h"

#include <torchvision/DeformConv.h>
#include <torchvision/PSROIAlign.h>
#include <torchvision/PSROIPool.h>
#include <torchvision/ROIAlign.h>
#include <torchvision/ROIPool.h>
#include <torchvision/empty_tensor_op.h>
#include <torchvision/nms.h>
#include <memory>
#include <string>
#include <vector>
1 change: 1 addition & 0 deletions src/servers/CMakeLists.txt
@@ -394,6 +394,7 @@ if(${TRITON_ENABLE_CAFFE2} OR ${TRITON_ENABLE_PYTORCH})
PUBLIC -ltorch
PUBLIC -ltorch_cpu
PUBLIC -ltorch_cuda
PUBLIC -ltorchvision
PUBLIC -lcaffe2_detectron_ops_gpu
PUBLIC -lcaffe2_nvrtc
PUBLIC -lc10