Skip to content

Commit

Permalink
add unfold
Browse files Browse the repository at this point in the history
  • Loading branch information
grimoire committed Aug 15, 2021
1 parent 6d18b74 commit 8cf19f2
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 2 deletions.
3 changes: 3 additions & 0 deletions torch2trt_dynamic/converters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,7 @@
from .nms import convert_nms
from .roi_align import convert_roi_align, convert_RoiAlign
from .roi_pool import convert_roi_pool, convert_RoIPool
from .unfold import convert_unfold

# adaptive_avg_pool1d
__all__ += ['convert_adaptive_avg_pool1d']
Expand Down Expand Up @@ -391,5 +392,7 @@
__all__ += ['convert_roi_align', 'convert_RoiAlign']
# roi_pool
__all__ += ['convert_roi_pool', 'convert_RoIPool']
# unfold
__all__ += ['convert_unfold']
except Exception:
print('plugin not found.')
1 change: 0 additions & 1 deletion torch2trt_dynamic/converters/roll.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import tensorrt as trt
import torch

from torch2trt_dynamic.torch2trt_dynamic import (get_arg, slice_shape_trt,
tensor_trt_get_shape_trt,
tensorrt_converter, trt_)
Expand Down
25 changes: 25 additions & 0 deletions torch2trt_dynamic/converters/unfold.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from torch2trt_dynamic.plugins import create_torchunfold_plugin
from torch2trt_dynamic.torch2trt_dynamic import (get_arg, tensorrt_converter,
trt_)


@tensorrt_converter('torch.nn.functional.unfold')
def convert_unfold(ctx):
    """Convert a traced ``torch.nn.functional.unfold`` call to TensorRT.

    Reads the call's arguments from ``ctx``, builds a
    ``TorchUnfoldPluginDynamic`` plugin carrying the unfold parameters
    (kernel_size, dilation, padding, stride), adds it to the network as a
    plugin-v2 layer, and binds the layer's output to the traced output
    tensor via ``output._trt``.
    """
    # Renamed from `input` to avoid shadowing the builtin.
    tensor = ctx.method_args[0]
    # NOTE(review): kernel_size is a required argument of F.unfold, so the
    # default of 0 should never actually be used.
    kernel_size = get_arg(ctx, 'kernel_size', pos=1, default=0)
    dilation = get_arg(ctx, 'dilation', pos=2, default=1)
    padding = get_arg(ctx, 'padding', pos=3, default=0)
    stride = get_arg(ctx, 'stride', pos=4, default=1)
    output = ctx.method_return
    tensor_trt = trt_(ctx.network, tensor)

    # id() of the source tensor makes the layer name unique per input.
    plugin = create_torchunfold_plugin(
        'unfold_' + str(id(tensor)),
        kernel_size=kernel_size,
        dilation=dilation,
        padding=padding,
        stride=stride)

    layer = ctx.network.add_plugin_v2(inputs=[tensor_trt], plugin=plugin)

    output._trt = layer.get_output(0)
4 changes: 3 additions & 1 deletion torch2trt_dynamic/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .create_torchembedding_plugin import create_torchembedding_plugin
from .create_torchflip_plugin import create_torchflip_plugin
from .create_torchgather_plugin import create_torchgather_plugin
from .create_torchunfold_plugin import create_torchunfold_plugin
from .globals import load_plugin_library

__all__ = [
Expand All @@ -25,7 +26,8 @@
'create_torchflip_plugin', 'create_torchcummaxmin_plugin',
'create_torchcum_plugin', 'create_dcn_plugin', 'create_nms_plugin',
'create_roiextractor_plugin', 'create_roipool_plugin',
'create_torchembedding_plugin', 'create_torchbmm_plugin'
'create_torchembedding_plugin', 'create_torchbmm_plugin',
'create_torchunfold_plugin'
]

load_plugin_library()
39 changes: 39 additions & 0 deletions torch2trt_dynamic/plugins/create_torchunfold_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import numpy as np
import tensorrt as trt


def _append_int_pair_field(pfc, name, value):
    """Append *value* to *pfc* as an INT32 plugin field named *name*.

    A bare int is expanded to ``(value, value)`` so the plugin always
    receives a two-element array, mirroring F.unfold's int-or-tuple
    argument convention.
    """
    if isinstance(value, int):
        value = (value, value)
    pfc.append(
        trt.PluginField(name, np.array(value, dtype=np.int32),
                        trt.PluginFieldType.INT32))


def create_torchunfold_plugin(layer_name, kernel_size, dilation, padding,
                              stride):
    """Create a ``TorchUnfoldPluginDynamic`` TensorRT plugin instance.

    Args:
        layer_name (str): name given to the created plugin layer.
        kernel_size (int | tuple[int, int]): unfold sliding-window size.
        dilation (int | tuple[int, int]): element spacing in the window.
        padding (int | tuple[int, int]): implicit zero padding.
        stride (int | tuple[int, int]): window stride.

    Returns:
        The plugin object produced by the registered
        ``TorchUnfoldPluginDynamic`` (version '1') creator.
    """
    creator = trt.get_plugin_registry().get_plugin_creator(
        'TorchUnfoldPluginDynamic', '1', '')

    pfc = trt.PluginFieldCollection()

    # All four parameters share the same int-or-pair encoding; one helper
    # replaces four copy-pasted stanzas.
    _append_int_pair_field(pfc, 'kernel_size', kernel_size)
    _append_int_pair_field(pfc, 'dilation', dilation)
    _append_int_pair_field(pfc, 'padding', padding)
    _append_int_pair_field(pfc, 'stride', stride)

    return creator.create_plugin(layer_name, pfc)

0 comments on commit 8cf19f2

Please sign in to comment.