Skip to content

Commit

Permalink
[CustomOps] TensorRT Gather Topk Ops (open-mmlab#1033)
Browse files Browse the repository at this point in the history
* add gather topk

* add shape inference and document

* fix faster rcnn

* reshape topk

* fix
  • Loading branch information
q.yao authored Sep 19, 2022
1 parent 50bd6b1 commit 0caeaf2
Show file tree
Hide file tree
Showing 11 changed files with 476 additions and 20 deletions.
150 changes: 150 additions & 0 deletions csrc/mmdeploy/backend_ops/tensorrt/gather_topk/gather_topk.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include "gather_topk.hpp"

#include <assert.h>
#include <stdio.h>

#include <chrono>

#include "NvInferVersion.h"
#include "gather_topk_kernel.hpp"
#include "trt_serialize.hpp"

namespace mmdeploy {
namespace {
// Plugin identity reported to the TensorRT plugin registry.
// `static` is redundant inside an anonymous namespace; `constexpr` makes the
// pointers compile-time constants.
constexpr const char *PLUGIN_VERSION{"1"};
constexpr const char *PLUGIN_NAME{"GatherTopk"};
}  // namespace

// Construct a fresh plugin instance identified by `name`.
GatherTopk::GatherTopk(const std::string &name) : TRTPluginBase(name) {}

// Deserialization constructor. The plugin carries no state (see
// getSerializationSize() returning 0), so `data`/`length` are intentionally
// unused.
GatherTopk::GatherTopk(const std::string name, const void *data, size_t length)
    : TRTPluginBase(name) {}

// Produce a functionally identical copy of this plugin for a new engine/context.
nvinfer1::IPluginV2DynamicExt *GatherTopk::clone() const TRT_NOEXCEPT {
  auto *copy = new GatherTopk(mLayerName);
  copy->setPluginNamespace(getPluginNamespace());
  return copy;
}

// Output shape is the indices shape extended with the data tensor's trailing
// dims: out[i] = indices[i] for i < rank(indices), otherwise data[i].
// E.g. data (A, G0, C), indices (A, G1) -> output (A, G1, C).
nvinfer1::DimsExprs GatherTopk::getOutputDimensions(
    int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
    nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT {
  assert(inputs[0].nbDims >= inputs[1].nbDims);
  const int indices_rank = inputs[1].nbDims;
  nvinfer1::DimsExprs out;
  out.nbDims = inputs[0].nbDims;
  for (int i = 0; i < out.nbDims; ++i) {
    out.d[i] = (i < indices_rank) ? inputs[1].d[i] : inputs[0].d[i];
  }
  return out;
}

// Accepted I/O combinations:
//   pos 0 (data):    float32 or int32, linear layout
//   pos 1 (indices): int32, linear layout
//   pos 2 (output):  same type and format as the data input
// (Removed the unreachable `return true` after the exhaustive switch and the
// duplicated kLINEAR comparison in case 0.)
bool GatherTopk::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc,
                                           int nbInputs, int nbOutputs) TRT_NOEXCEPT {
  switch (pos) {
    case 0:
      // data
      return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT ||
              ioDesc[pos].type == nvinfer1::DataType::kINT32) &&
             ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR;
    case 1:
      // indices
      return ioDesc[pos].type == nvinfer1::DataType::kINT32 &&
             ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR;
    case 2:
      // output must mirror the data input exactly
      return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format;
    default:
      return true;
  }
}

// No-op: the plugin caches no per-build configuration; all shape handling is
// done dynamically in enqueue().
void GatherTopk::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
                                 const nvinfer1::DynamicPluginTensorDesc *outputs,
                                 int nbOutputs) TRT_NOEXCEPT {}

// The gather kernel writes directly into the output tensor, so no scratch
// workspace is needed.
size_t GatherTopk::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
                                    const nvinfer1::PluginTensorDesc *outputs,
                                    int nbOutputs) const TRT_NOEXCEPT {
  return 0;
}

// Launch the gather kernel on `stream`.
// inputs[0]: data tensor, inputs[1]: int32 indices, outputs[0]: gathered result.
// Returns 0 on success, 1 on an unsupported data type (previously the default
// case silently returned success). Also replaces the C-style casts, which cast
// away `const` needlessly, with named casts.
int GatherTopk::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
                        const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
                        void *const *outputs, void *workSpace, cudaStream_t stream) TRT_NOEXCEPT {
  const int *dims = &(inputDesc[0].dims.d[0]);
  const int *indices_dims = &(inputDesc[1].dims.d[0]);
  const int nbDims = inputDesc[0].dims.nbDims;
  const int indice_nbDims = inputDesc[1].dims.nbDims;

  const void *data = inputs[0];
  const void *indices = inputs[1];
  void *output = outputs[0];

  switch (inputDesc[0].type) {
    case nvinfer1::DataType::kFLOAT:
      gather_topk_impl<float>(static_cast<const float *>(data), static_cast<const int *>(indices),
                              dims, nbDims, indices_dims, indice_nbDims,
                              static_cast<float *>(output), stream);
      break;
    case nvinfer1::DataType::kINT32:
      gather_topk_impl<int>(static_cast<const int *>(data), static_cast<const int *>(indices),
                            dims, nbDims, indices_dims, indice_nbDims,
                            static_cast<int *>(output), stream);
      break;
    default:
      // Unreachable when supportsFormatCombination() is honored; report failure
      // instead of pretending the gather ran.
      return 1;
  }

  return 0;
}

// The output always has the same dtype as the data input (input 0).
nvinfer1::DataType GatherTopk::getOutputDataType(int index, const nvinfer1::DataType *inputTypes,
                                                 int nbInputs) const TRT_NOEXCEPT {
  return inputTypes[0];
}

// IPluginV2 Methods
const char *GatherTopk::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; }

const char *GatherTopk::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; }

// Single output: the gathered tensor.
int GatherTopk::getNbOutputs() const TRT_NOEXCEPT { return 1; }

// The plugin is stateless (no attributes), so nothing is serialized.
size_t GatherTopk::getSerializationSize() const TRT_NOEXCEPT { return 0; }

void GatherTopk::serialize(void *buffer) const TRT_NOEXCEPT {}

// The op takes no attributes, so the plugin field collection is left empty.
GatherTopkCreator::GatherTopkCreator() {
  mPluginAttributes.clear();
  mFC.nbFields = mPluginAttributes.size();
  mFC.fields = mPluginAttributes.data();
}

const char *GatherTopkCreator::getPluginName() const TRT_NOEXCEPT { return PLUGIN_NAME; }

const char *GatherTopkCreator::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; }

// Build a new plugin instance; `fc` is ignored since the op has no attributes.
nvinfer1::IPluginV2 *GatherTopkCreator::createPlugin(
    const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT {
  auto *plugin = new GatherTopk(name);
  plugin->setPluginNamespace(getPluginNamespace());
  return plugin;
}

// Recreate a plugin from a serialized engine; the payload is empty because the
// plugin is stateless.
nvinfer1::IPluginV2 *GatherTopkCreator::deserializePlugin(const char *name, const void *serialData,
                                                          size_t serialLength) TRT_NOEXCEPT {
  auto plugin = new GatherTopk(name, serialData, serialLength);
  plugin->setPluginNamespace(getPluginNamespace());
  return plugin;
}

REGISTER_TENSORRT_PLUGIN(GatherTopkCreator);
} // namespace mmdeploy
64 changes: 64 additions & 0 deletions csrc/mmdeploy/backend_ops/tensorrt/gather_topk/gather_topk.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Copyright (c) OpenMMLab. All rights reserved.
// Fix: include guard was TRT_SCATTERND_HPP, copy-pasted from the ScatterND
// plugin header — a duplicate guard can silently exclude one of the two
// headers when both are included in a translation unit.
#ifndef TRT_GATHER_TOPK_HPP
#define TRT_GATHER_TOPK_HPP
#include <cublas_v2.h>

#include <memory>
#include <string>
#include <vector>

#include "trt_plugin_base.hpp"

namespace mmdeploy {
// TensorRT dynamic-shape plugin implementing a batched gather along the
// top-k axis: output = data[..., indices, ...]. See enqueue() for the shape
// contract.
class GatherTopk : public TRTPluginBase {
 public:
  GatherTopk(const std::string &name);

  // Deserialization constructor; data/length are unused (stateless plugin).
  GatherTopk(const std::string name, const void *data, size_t length);

  GatherTopk() = delete;

  // IPluginV2DynamicExt Methods
  nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override;
  nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs,
                                          int nbInputs, nvinfer1::IExprBuilder &exprBuilder)
      TRT_NOEXCEPT override;
  bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs,
                                 int nbOutputs) TRT_NOEXCEPT override;
  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
                       const nvinfer1::DynamicPluginTensorDesc *out,
                       int nbOutputs) TRT_NOEXCEPT override;
  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
                          const nvinfer1::PluginTensorDesc *outputs,
                          int nbOutputs) const TRT_NOEXCEPT override;
  int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
              const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
              void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override;

  // IPluginV2Ext Methods
  nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes,
                                       int nbInputs) const TRT_NOEXCEPT override;

  // IPluginV2 Methods
  const char *getPluginType() const TRT_NOEXCEPT override;
  const char *getPluginVersion() const TRT_NOEXCEPT override;
  int getNbOutputs() const TRT_NOEXCEPT override;
  size_t getSerializationSize() const TRT_NOEXCEPT override;
  void serialize(void *buffer) const TRT_NOEXCEPT override;
};

// Factory registered with the TensorRT plugin registry for GatherTopk.
class GatherTopkCreator : public TRTPluginCreatorBase {
 public:
  GatherTopkCreator();

  const char *getPluginName() const TRT_NOEXCEPT override;

  const char *getPluginVersion() const TRT_NOEXCEPT override;
  nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc)
      TRT_NOEXCEPT override;

  nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData,
                                         size_t serialLength) TRT_NOEXCEPT override;
};
}  // namespace mmdeploy
#endif  // TRT_GATHER_TOPK_HPP
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright (c) OpenMMLab. All rights reserved.

#include <functional>
#include <numeric>
#include <vector>

#include "common_cuda_helper.hpp"
#include "gather_topk_kernel.hpp"
#include "trt_plugin_helper.hpp"

// One thread per output element. The flat index decomposes as
// (b_id, n_id, c_id) over (batch, num_indices, channel); the row to copy is
// looked up per batch from `indices`.
template <typename scalar_t>
__global__ void gather_topk_kernel(const scalar_t* input, const int* indices, scalar_t* output,
                                   int batch, int num_input, int num_indices, int channel) {
  CUDA_1D_KERNEL_LOOP(index, batch * num_indices * channel) {
    const int b_id = index / (num_indices * channel);  // batch position
    const int n_id = (index / channel) % num_indices;  // gathered-row position
    const int c_id = index % channel;                  // offset within trailing dims

    // `indices` holds per-batch offsets into the `num_input` axis of `input`.
    // NOTE(review): indices are not bounds-checked; out-of-range values read
    // out of bounds — callers must guarantee 0 <= idx < num_input.
    const int input_n_id = indices[b_id * num_indices + n_id];
    const scalar_t value = input[b_id * num_input * channel + input_n_id * channel + c_id];
    output[b_id * num_indices * channel + n_id * channel + c_id] = value;
  }
}

// Host-side launcher. Shape contract (mirrors GatherTopk::getOutputDimensions):
//   input:   (A0, ..., An, G0, C0, ...)  described by dims / nbDims
//   indices: (A0, ..., An, G1)           described by indices_dims / indice_nbDims
//   output:  (A0, ..., An, G1, C0, ...)
// where batch = prod(A0..An), num_input = G0, num_indices = G1,
// channel = prod(C0...). Kernel launch is asynchronous on `stream`.
template <typename scalar_t>
void gather_topk_impl(const scalar_t* input, const int* indices, const int* dims, int nbDims,
                      const int* indices_dims, int indice_nbDims, scalar_t* output,
                      cudaStream_t stream) {
  // collapse all leading batch dims into one
  int batch = 1;
  for (int i = 0; i < indice_nbDims - 1; ++i) batch *= dims[i];
  int num_input = dims[indice_nbDims - 1];
  int num_indices = indices_dims[indice_nbDims - 1];
  // collapse all trailing dims after the gather axis into one
  int channel = 1;
  for (int i = indice_nbDims; i < nbDims; ++i) channel *= dims[i];
  const int col_block = DIVUP(batch * num_indices * channel, THREADS_PER_BLOCK);
  gather_topk_kernel<<<col_block, THREADS_PER_BLOCK, 0, stream>>>(input, indices, output, batch,
                                                                  num_input, num_indices, channel);
}

// Explicit instantiations for the dtypes dispatched from GatherTopk::enqueue.
template void gather_topk_impl<float>(const float* input, const int* indices, const int* dims,
                                      int nbDims, const int* indices_dims, int indice_nbDims,
                                      float* output, cudaStream_t stream);

template void gather_topk_impl<int32_t>(const int32_t* input, const int* indices, const int* dims,
                                        int nbDims, const int* indices_dims, int indice_nbDims,
                                        int32_t* output, cudaStream_t stream);
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright (c) OpenMMLab. All rights reserved.
// Fix: include guard was TRT_GRID_SAMPLER_KERNEL_HPP, copy-pasted from the
// grid-sampler kernel header — a duplicate guard can silently exclude one of
// the two headers when both are included in a translation unit.
#ifndef TRT_GATHER_TOPK_KERNEL_HPP
#define TRT_GATHER_TOPK_KERNEL_HPP
#include <cuda_runtime.h>

// Launches the batched gather kernel on `stream`.
//   input:   (A0, ..., An, G0, C0, ...)  described by dims / nbDims
//   indices: (A0, ..., An, G1)           described by indices_dims / indice_nbDims
//   output:  (A0, ..., An, G1, C0, ...)
template <typename scalar_t>
void gather_topk_impl(const scalar_t* input, const int* indices, const int* dims, int nbDims,
                      const int* indices_dims, int indice_nbDims, scalar_t* output,
                      cudaStream_t stream);
#endif  // TRT_GATHER_TOPK_KERNEL_HPP
42 changes: 42 additions & 0 deletions docs/en/06-custom-ops/tensorrt.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@
- [Inputs](#inputs-9)
- [Outputs](#outputs-9)
- [Type Constraints](#type-constraints-9)
- [GatherTopk](#gathertopk)
- [Description](#description-10)
- [Parameters](#parameters-10)
- [Inputs](#inputs-10)
- [Outputs](#outputs-10)
- [Type Constraints](#type-constraints-10)

<!-- TOC -->

Expand Down Expand Up @@ -447,3 +453,39 @@ None
#### Type Constraints

- T:tensor(float32, Linear)

### GatherTopk

#### Description

TensorRT 8.2~8.4 may give unexpected results for multi-index gather expressions such as:

```python
data[batch_index, bbox_index, ...]
```

Read [this](https://github.com/NVIDIA/TensorRT/issues/2299) for more details.

#### Parameters

None

#### Inputs

<dl>
<dt><tt>inputs[0]</tt>: T</dt>
<dd>Tensor to be gathered, with shape (A0, ..., An, G0, C0, ...).</dd>

<dt><tt>inputs[1]</tt>: tensor(int32, Linear)</dt>
<dd>Tensor of indices, with shape (A0, ..., An, G1).</dd>
</dl>

#### Outputs

<dl>
<dt><tt>outputs[0]</tt>: T</dt>
<dd>Output tensor, with shape (A0, ..., An, G1, C0, ...).</dd>
</dl>

#### Type Constraints

- T:tensor(float32, Linear), tensor(int32, Linear)
42 changes: 42 additions & 0 deletions docs/zh_cn/06-custom-ops/tensorrt.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@
- [Inputs](#inputs-9)
- [Outputs](#outputs-9)
- [Type Constraints](#type-constraints-9)
- [GatherTopk](#gathertopk)
- [Description](#description-10)
- [Parameters](#parameters-10)
- [Inputs](#inputs-10)
- [Outputs](#outputs-10)
- [Type Constraints](#type-constraints-10)

<!-- TOC -->

Expand Down Expand Up @@ -447,3 +453,39 @@ None
#### Type Constraints

- T:tensor(float32, Linear)

### GatherTopk

#### Description

TensorRT 8.2~8.4 may give unexpected results for multi-index gather expressions such as:

```python
data[batch_index, bbox_index, ...]
```

Read [this](https://github.com/NVIDIA/TensorRT/issues/2299) for more details.

#### Parameters

None

#### Inputs

<dl>
<dt><tt>inputs[0]</tt>: T</dt>
<dd>Tensor to be gathered, with shape (A0, ..., An, G0, C0, ...).</dd>

<dt><tt>inputs[1]</tt>: tensor(int32, Linear)</dt>
<dd>Tensor of indices, with shape (A0, ..., An, G1).</dd>
</dl>

#### Outputs

<dl>
<dt><tt>outputs[0]</tt>: T</dt>
<dd>Output tensor, with shape (A0, ..., An, G1, C0, ...).</dd>
</dl>

#### Type Constraints

- T:tensor(float32, Linear), tensor(int32, Linear)
7 changes: 7 additions & 0 deletions mmdeploy/codebase/mmdet/core/post_processing/bbox_nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,13 @@ def multiclass_nms_static(ctx,
pre_top_k, keep_top_k, iou_threshold,
score_threshold, -1)

# retain shape info
batch_size = boxes.size(0)

dets_shape = dets.shape
label_shape = labels.shape
dets = dets.reshape([batch_size, *dets_shape[1:]])
labels = labels.reshape([batch_size, *label_shape[1:]])
return dets, labels


Expand Down
Loading

0 comments on commit 0caeaf2

Please sign in to comment.