forked from open-mmlab/mmdeploy
[CustomOps] TensorRT Gather Topk Ops (open-mmlab#1033)
* add gather topk
* add shape inference and document
* fix faster rcnn
* reshape topk
* fix
q.yao authored Sep 19, 2022 · 1 parent 50bd6b1 · commit 0caeaf2
Showing 11 changed files with 476 additions and 20 deletions.
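The new GatherTopk TensorRT plugin gathers rows of a data tensor along the gathered axis using a batched indices tensor, such as the indices produced by a TopK op: for data shaped (batch..., N, trailing...) and indices shaped (batch..., K), the output is shaped (batch..., K, trailing...) with output[b, k, ...] = data[b, indices[b, k], ...]. The following is a minimal CPU sketch of that semantics; it is not part of the diff, and the function name and flattened-shape convention are illustrative only.

// Minimal CPU reference for the GatherTopk semantics added by this commit.
// Not part of the diff; the name and the flattened-shape convention are illustrative.
#include <cstdint>

// data:    batch x num_input x channel (flattened)
// indices: batch x num_indices
// output:  batch x num_indices x channel (flattened)
template <typename scalar_t>
void gather_topk_reference(const scalar_t* data, const int32_t* indices, scalar_t* output,
                           int batch, int num_input, int num_indices, int channel) {
  for (int b = 0; b < batch; ++b) {
    for (int k = 0; k < num_indices; ++k) {
      const int n = indices[b * num_indices + k];  // row selected by TopK
      for (int c = 0; c < channel; ++c) {
        output[(b * num_indices + k) * channel + c] = data[(b * num_input + n) * channel + c];
      }
    }
  }
}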
csrc/mmdeploy/backend_ops/tensorrt/gather_topk/gather_topk.cpp (150 additions, 0 deletions)
@@ -0,0 +1,150 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include "gather_topk.hpp"

#include <assert.h>
#include <stdio.h>

#include <chrono>

#include "NvInferVersion.h"
#include "gather_topk_kernel.hpp"
#include "trt_serialize.hpp"

namespace mmdeploy {
namespace {
static const char *PLUGIN_VERSION{"1"};
static const char *PLUGIN_NAME{"GatherTopk"};
}  // namespace

GatherTopk::GatherTopk(const std::string &name) : TRTPluginBase(name) {}

GatherTopk::GatherTopk(const std::string name, const void *data, size_t length)
    : TRTPluginBase(name) {}

nvinfer1::IPluginV2DynamicExt *GatherTopk::clone() const TRT_NOEXCEPT {
  GatherTopk *plugin = new GatherTopk(mLayerName);
  plugin->setPluginNamespace(getPluginNamespace());

  return plugin;
}

nvinfer1::DimsExprs GatherTopk::getOutputDimensions(
    int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
    nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT {
  assert(inputs[0].nbDims >= inputs[1].nbDims);
  nvinfer1::DimsExprs ret;
  ret.nbDims = inputs[0].nbDims;
  for (int i = 0; i < inputs[1].nbDims; ++i) {
    ret.d[i] = inputs[1].d[i];
  }
  for (int i = inputs[1].nbDims; i < inputs[0].nbDims; ++i) {
    ret.d[i] = inputs[0].d[i];
  }
  return ret;
}

bool GatherTopk::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc,
                                           int nbInputs, int nbOutputs) TRT_NOEXCEPT {
  switch (pos) {
    case 0:
      // data
      return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT &&
              ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR) ||
             (ioDesc[pos].type == nvinfer1::DataType::kINT32 &&
              ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR);
    case 1:
      // indices
      return ioDesc[pos].type == nvinfer1::DataType::kINT32 &&
             ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR;
    case 2:
      // output
      return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format;
    default:
      return true;
  }
  return true;
}

void GatherTopk::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
                                 const nvinfer1::DynamicPluginTensorDesc *outputs,
                                 int nbOutputs) TRT_NOEXCEPT {}

size_t GatherTopk::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
                                    const nvinfer1::PluginTensorDesc *outputs,
                                    int nbOutputs) const TRT_NOEXCEPT {
  return 0;
}

int GatherTopk::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
                        const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
                        void *const *outputs, void *workSpace, cudaStream_t stream) TRT_NOEXCEPT {
  const int *dims = &(inputDesc[0].dims.d[0]);
  const int *indices_dims = &(inputDesc[1].dims.d[0]);
  int nbDims = inputDesc[0].dims.nbDims;
  int indice_nbDims = inputDesc[1].dims.nbDims;

  const void *data = inputs[0];
  const void *indices = inputs[1];
  void *output = outputs[0];

  auto data_type = inputDesc[0].type;

  switch (data_type) {
    case nvinfer1::DataType::kFLOAT:
      gather_topk_impl<float>((float *)data, (int *)indices, dims, nbDims, indices_dims,
                              indice_nbDims, (float *)output, stream);
      break;

    case nvinfer1::DataType::kINT32:
      gather_topk_impl<int>((int *)data, (int *)indices, dims, nbDims, indices_dims, indice_nbDims,
                            (int *)output, stream);
      break;
    default:
      break;
  }

  return 0;
}

nvinfer1::DataType GatherTopk::getOutputDataType(int index, const nvinfer1::DataType *inputTypes,
                                                 int nbInputs) const TRT_NOEXCEPT {
  return inputTypes[0];
}

// IPluginV2 Methods
const char *GatherTopk::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; }

const char *GatherTopk::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; }

int GatherTopk::getNbOutputs() const TRT_NOEXCEPT { return 1; }

size_t GatherTopk::getSerializationSize() const TRT_NOEXCEPT { return 0; }

void GatherTopk::serialize(void *buffer) const TRT_NOEXCEPT {}

GatherTopkCreator::GatherTopkCreator() {
  mPluginAttributes.clear();
  mFC.nbFields = mPluginAttributes.size();
  mFC.fields = mPluginAttributes.data();
}

const char *GatherTopkCreator::getPluginName() const TRT_NOEXCEPT { return PLUGIN_NAME; }

const char *GatherTopkCreator::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; }

nvinfer1::IPluginV2 *GatherTopkCreator::createPlugin(
    const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT {
  auto *plugin = new GatherTopk(name);
  plugin->setPluginNamespace(getPluginNamespace());
  return plugin;
}

nvinfer1::IPluginV2 *GatherTopkCreator::deserializePlugin(const char *name, const void *serialData,
                                                          size_t serialLength) TRT_NOEXCEPT {
  auto plugin = new GatherTopk(name, serialData, serialLength);
  plugin->setPluginNamespace(getPluginNamespace());
  return plugin;
}

REGISTER_TENSORRT_PLUGIN(GatherTopkCreator);
}  // namespace mmdeploy
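getOutputDimensions above implements the shape rule that the output takes the indices tensor's dimensions first and then appends the remaining trailing dimensions of the data tensor. Below is a small stand-alone illustration of the same rule on plain integer shapes; it is a sketch under the stated assumption, not code from the commit.

// Illustration of the output-shape rule in getOutputDimensions (assumption: plain int
// vectors stand in for nvinfer1::DimsExprs). Example: data (1, 1000, 4) gathered with
// indices (1, 100) yields output (1, 100, 4).
#include <vector>

std::vector<int> gather_topk_output_shape(const std::vector<int>& data_dims,
                                          const std::vector<int>& indices_dims) {
  std::vector<int> out(indices_dims);  // leading dims come from indices
  for (size_t i = indices_dims.size(); i < data_dims.size(); ++i) {
    out.push_back(data_dims[i]);       // trailing dims come from data
  }
  return out;
}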
csrc/mmdeploy/backend_ops/tensorrt/gather_topk/gather_topk.hpp (64 additions, 0 deletions)
@@ -0,0 +1,64 @@
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef TRT_SCATTERND_HPP
#define TRT_SCATTERND_HPP
#include <cublas_v2.h>

#include <memory>
#include <string>
#include <vector>

#include "trt_plugin_base.hpp"

namespace mmdeploy {
class GatherTopk : public TRTPluginBase {
 public:
  GatherTopk(const std::string &name);

  GatherTopk(const std::string name, const void *data, size_t length);

  GatherTopk() = delete;

  // IPluginV2DynamicExt Methods
  nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override;
  nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs,
                                          int nbInputs, nvinfer1::IExprBuilder &exprBuilder)
      TRT_NOEXCEPT override;
  bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs,
                                 int nbOutputs) TRT_NOEXCEPT override;
  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
                       const nvinfer1::DynamicPluginTensorDesc *out,
                       int nbOutputs) TRT_NOEXCEPT override;
  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
                          const nvinfer1::PluginTensorDesc *outputs,
                          int nbOutputs) const TRT_NOEXCEPT override;
  int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
              const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
              void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override;

  // IPluginV2Ext Methods
  nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes,
                                       int nbInputs) const TRT_NOEXCEPT override;

  // IPluginV2 Methods
  const char *getPluginType() const TRT_NOEXCEPT override;
  const char *getPluginVersion() const TRT_NOEXCEPT override;
  int getNbOutputs() const TRT_NOEXCEPT override;
  size_t getSerializationSize() const TRT_NOEXCEPT override;
  void serialize(void *buffer) const TRT_NOEXCEPT override;
};

class GatherTopkCreator : public TRTPluginCreatorBase {
 public:
  GatherTopkCreator();

  const char *getPluginName() const TRT_NOEXCEPT override;

  const char *getPluginVersion() const TRT_NOEXCEPT override;
  nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc)
      TRT_NOEXCEPT override;

  nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData,
                                         size_t serialLength) TRT_NOEXCEPT override;
};
}  // namespace mmdeploy
#endif  // TRT_SCATTERND_HPP
csrc/mmdeploy/backend_ops/tensorrt/gather_topk/gather_topk_kernel.cu (46 additions, 0 deletions)
@@ -0,0 +1,46 @@
// Copyright (c) OpenMMLab. All rights reserved.

#include <functional>
#include <numeric>
#include <vector>

#include "common_cuda_helper.hpp"
#include "gather_topk_kernel.hpp"
#include "trt_plugin_helper.hpp"

template <typename scalar_t>
__global__ void gather_topk_kernel(const scalar_t* input, const int* indices, scalar_t* output,
                                   int batch, int num_input, int num_indices, int channel) {
  CUDA_1D_KERNEL_LOOP(index, batch * num_indices * channel) {
    const int b_id = index / (num_indices * channel);
    const int n_id = (index / channel) % num_indices;
    const int c_id = index % channel;

    const int input_n_id = indices[b_id * num_indices + n_id];
    const scalar_t value = input[b_id * num_input * channel + input_n_id * channel + c_id];
    output[b_id * num_indices * channel + n_id * channel + c_id] = value;
  }
}

template <typename scalar_t>
void gather_topk_impl(const scalar_t* input, const int* indices, const int* dims, int nbDims,
                      const int* indices_dims, int indice_nbDims, scalar_t* output,
                      cudaStream_t stream) {
  int batch = 1;
  for (int i = 0; i < indice_nbDims - 1; ++i) batch *= dims[i];
  int num_input = dims[indice_nbDims - 1];
  int num_indices = indices_dims[indice_nbDims - 1];
  int channel = 1;
  for (int i = indice_nbDims; i < nbDims; ++i) channel *= dims[i];
  const int col_block = DIVUP(batch * num_indices * channel, THREADS_PER_BLOCK);
  gather_topk_kernel<<<col_block, THREADS_PER_BLOCK, 0, stream>>>(input, indices, output, batch,
                                                                  num_input, num_indices, channel);
}

template void gather_topk_impl<float>(const float* input, const int* indices, const int* dims,
                                      int nbDims, const int* indices_dims, int indice_nbDims,
                                      float* output, cudaStream_t stream);

template void gather_topk_impl<int32_t>(const int32_t* input, const int* indices, const int* dims,
                                        int nbDims, const int* indices_dims, int indice_nbDims,
                                        int32_t* output, cudaStream_t stream);
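gather_topk_impl above collapses the shapes into four extents: batch (product of the leading indices dimensions), num_input (the data extent on the gathered axis), num_indices (the last indices dimension), and channel (product of the data dimensions after that axis), then launches one thread per output element. The following host-side usage sketch is not part of the commit; the shapes, pointer names, and wrapper function are hypothetical.

// Hypothetical host-side call of gather_topk_impl (not part of this commit).
// data: 1 x 1000 x 4 floats, indices: 1 x 100 ints -> output: 1 x 100 x 4 floats.
#include <cuda_runtime.h>
#include "gather_topk_kernel.hpp"

void run_gather_topk_example(const float* d_data, const int* d_indices, float* d_output,
                             cudaStream_t stream) {
  const int dims[] = {1, 1000, 4};      // data shape, nbDims = 3
  const int indices_dims[] = {1, 100};  // indices shape, indice_nbDims = 2
  // batch = 1, num_input = 1000, num_indices = 100, channel = 4 are derived inside the impl.
  gather_topk_impl<float>(d_data, d_indices, dims, /*nbDims=*/3, indices_dims,
                          /*indice_nbDims=*/2, d_output, stream);
}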
csrc/mmdeploy/backend_ops/tensorrt/gather_topk/gather_topk_kernel.hpp (10 additions, 0 deletions)
@@ -0,0 +1,10 @@
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef TRT_GRID_SAMPLER_KERNEL_HPP
#define TRT_GRID_SAMPLER_KERNEL_HPP
#include <cuda_runtime.h>

template <typename scalar_t>
void gather_topk_impl(const scalar_t* input, const int* indices, const int* dims, int nbDims,
                      const int* indices_dims, int indice_nbDims, scalar_t* output,
                      cudaStream_t stream);
#endif  // TRT_GRID_SAMPLER_KERNEL_HPP
(The remaining changed files in this commit failed to load and are not shown here.)