Skip to content

Commit

Permalink
add conv3d and pool3d, fix unsqueeze2 (#405)
Browse files Browse the repository at this point in the history
* add conv3d and pool3d, fix unsqueeze2

* add AdaptiveAvgPool2D, AdaptiveAvgPool3D, MaxPool3D test and fix Conv2D, Conv3D

* add nn_AdaptiveAvgPool1D test

* add data_format assert

* update data_format assert
  • Loading branch information
yeliang2258 authored Nov 28, 2021
1 parent ebfcc0c commit ebb667e
Show file tree
Hide file tree
Showing 11 changed files with 713 additions and 36 deletions.
4 changes: 3 additions & 1 deletion docs/en/op_list.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,14 @@
| concat | 1~12 |
| conv2d | 1~12 |
| conv2d_transpose | 1~12 |
| conv3d | 1~12 |
| depthwise_conv2d_transpose | 1~12 |
| collect_fpn_proposals | 11~12 |
| cumsum | 11~12 |
| deformable_conv | 11~12 |
| depthwise_conv2d | 1~12 |
| distribute_fpn_proposals | 11~12 |
| dist | 7~12 |
| dist | 7~12 |
| dropout | 7~12 |
| dot | 7~13 |
| elementwise_add | 7~12 |
Expand Down Expand Up @@ -111,6 +112,7 @@
| pad3d | 1~12 |
| pixel_shuffle | 11~12 |
| pool2d | 1~12 | limited supported |
| pool3d | 1~12 | limited supported |
| pow | 8~12 |
| prior_box | 1~12 |
| prelu | 1~12 |
Expand Down
4 changes: 3 additions & 1 deletion docs/zh/op_list.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,14 @@
| concat | 1~12 |
| conv2d | 1~12 |
| conv2d_transpose | 1~12 |
| conv3d | 1~12 |
| collect_fpn_proposals | 11~12 |
| cumsum | 11~12 |
| deformable_conv | 11~12 |
| depthwise_conv2d | 1~12 |
| depthwise_conv2d_transpose | 1~12 |
| distribute_fpn_proposals | 11~12 |
| dist | 7~12 |
| dist | 7~12 |
| dropout | 7~12 |
| dot | 7~13 |
| elementwise_add | 7~12 |
Expand Down Expand Up @@ -111,6 +112,7 @@
| pad3d | 1~12 |
| pixel_shuffle | 11~12 |
| pool2d | 1~12 | limited supported |
| pool3d | 1~12 | limited supported |
| pow | 8~12 |
| prior_box | 1~12 |
| prelu | 1~12 |
Expand Down
126 changes: 119 additions & 7 deletions paddle2onnx/op_mapper/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,25 +23,28 @@
from paddle2onnx import utils


@op_mapper(['conv2d', 'depthwise_conv2d'])
@op_mapper(['conv2d', 'depthwise_conv2d', 'conv3d'])
class Conv():
support_opset_version_range = (1, 12)

@classmethod
def opset_1(cls, graph, node, **kw):
kernel_shape = node.input_shape('Filter', 0)
dilations = node.attr('dilations')
kernel_shape = kernel_shape[-2:]
kernel_shape = kernel_shape[2:]
strides = node.attr('strides')
group = node.attr('groups')
pads = node.attr('paddings')
assert node.attrs['data_format'] == 'NCHW', "The conv data format should be 'NCHW', but received data format " \
"is %s." % node.attrs['data_format']
assert node.attrs['data_format'] == 'NCHW' or node.attrs['data_format'] == 'NCDHW', \
"The conv data format should be 'NCHW' or 'NCDHW', but received data format " \
"is %s." % node.attrs['data_format']
# onnx padding is [x1_begin, x2_begin...x1_end, x2_end, ...]
if len(pads) == 4:
pads = [pads[i] for i in [0, 2, 1, 3]]
if len(pads) == 2:
if len(pads) == 2 or len(pads) == 3:
pads = pads + pads
elif len(pads) == 4:
pads = [pads[i] for i in [0, 2, 1, 3]]
elif len(pads) == 6:
pads = [pads[i] for i in [0, 2, 4, 1, 3, 5]]
attrs = {
'dilations': dilations,
'kernel_shape': kernel_shape,
Expand Down Expand Up @@ -114,6 +117,9 @@ def is_same_span(cls, in_size, out_size):

@classmethod
def opset_1(cls, graph, node, **kw):
assert node.attrs['data_format'] == 'NCHW', \
"The conv data format should be 'NCHW', but received data format " \
"is %s." % node.attrs['data_format']
if node.attr('global_pooling') or (node.attr('adaptive') and
node.attr('ksize') == [1, 1]):
onnx_node = graph.make_node(
Expand Down Expand Up @@ -190,6 +196,112 @@ def opset_1(cls, graph, node, **kw):
attrs=attrs)


@op_mapper('pool3d')
class Pool3D():
support_opset_version_range = (1, 12)
pool_type = {
'max': ('MaxPool', 'GlobalMaxPool'),
'avg': ('AveragePool', 'GlobalAveragePool')
}

@classmethod
def is_same_span(cls, in_size, out_size):
spans = []
for i in range(out_size):
start = math.floor(i * (in_size / out_size))
end = math.ceil((i + 1) * (in_size / out_size))
spans.append(end - start)
if len(set(spans)) == 1:
return True
return False

@classmethod
def opset_1(cls, graph, node, **kw):
assert node.attrs['data_format'] == 'NCDHW', \
"The conv data format should be 'NCDHW', but received data format " \
"is %s." % node.attrs['data_format']

if node.attr('global_pooling') or (node.attr('adaptive') and
node.attr('ksize') == [1, 1, 1]):
onnx_node = graph.make_node(
cls.pool_type[node.attr('pooling_type')][1],
inputs=node.input('X'),
outputs=node.output('Out'))
elif node.attr('adaptive'):
# if pool is adaptive, check if input shape of pool is fixed.
mapper_helper.is_static_shape(node.input_shape('X', 0))
input_d, input_h, input_w = node.input_shape('X', 0)[2:]
output_d, output_h, output_w = node.output_shape('Out', 0)[2:]
stride_d = int(input_d / output_d)
stride_h = int(input_h / output_h)
stride_w = int(input_w / output_w)

kernel_d = input_d - (output_d - 1) * stride_d
kernel_h = input_h - (output_h - 1) * stride_h
kernel_w = input_w - (output_w - 1) * stride_w

#check if kernel_size is fixed.
if not cls.is_same_span(input_h, output_h) or not cls.is_same_span(
input_w, output_w) or not cls.is_same_span(input_d,
output_d):
raise Exception(
"Cannot convert adaptive pool with input_size: {}, output_size: {}"
.format(
node.input_shape('X', 0), node.output_shape('Out', 0)))
else:
attrs = {
'kernel_shape': (kernel_d, kernel_h, kernel_w),
'strides': (stride_d, stride_h, stride_w),
}
if node.attr('ceil_mode') and graph.opset_version < 10:
raise Exception(
"Cannot convert pool with ceil_model == True to ONNX Opset version < 10."
)
elif graph.opset_version > 10:
attrs['ceil_mode'] = node.attr('ceil_mode')
auto_pad = node.attr('padding_algorithm')
if auto_pad == 'SAME':
attrs['auto_pad'] = 'SAME_UPPER'
elif auto_pad == 'VALID':
attrs['auto_pad'] = 'VALID'
if node.attr('pooling_type') == 'avg':
attrs['count_include_pad'] = not node.attr('exclusive')
onnx_node = graph.make_node(
cls.pool_type[node.attr('pooling_type')][0],
inputs=node.input('X'),
outputs=node.output('Out'),
attrs=attrs)
else:
input_shape = node.input_shape('X', 0)
k_size = node.attr('ksize')
paddings = node.attr('paddings')
if input_shape[2] > 0 and input_shape[2] + paddings[0] < k_size[0]:
k_size[0] = input_shape[2] + paddings[0]
if input_shape[3] > 0 and input_shape[3] + paddings[1] < k_size[1]:
k_size[1] = input_shape[3] + paddings[1]
if input_shape[4] > 0 and input_shape[4] + paddings[2] < k_size[2]:
k_size[2] = input_shape[4] + paddings[2]
attrs = {
'kernel_shape': k_size,
'strides': node.attr('strides'),
'pads': node.attr('paddings') + node.attr('paddings'),
}
if node.attr('ceil_mode') and graph.opset_version < 10:
raise Exception(
"Cannot convert pool with ceil_model == True to ONNX Opset version < 10"
)
elif graph.opset_version >= 10:
attrs['ceil_mode'] = node.attr('ceil_mode')

if node.attr('pooling_type') == 'avg':
attrs['count_include_pad'] = not node.attr('exclusive')
onnx_node = graph.make_node(
cls.pool_type[node.attr('pooling_type')][0],
inputs=node.input('X'),
outputs=node.output('Out'),
attrs=attrs)


@op_mapper('elu')
class ELU():
support_opset_version_range = (1, 12)
Expand Down
49 changes: 28 additions & 21 deletions paddle2onnx/op_mapper/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,7 @@ def opset_1(cls, graph, node, **kw):
axis = axis + len(node.input_shape('X', 0))

node = graph.make_node(
'Concat',
inputs=inputs,
outputs=node.output('Out'),
axis=axis)
'Concat', inputs=inputs, outputs=node.output('Out'), axis=axis)


@op_mapper('assign')
Expand Down Expand Up @@ -83,6 +80,7 @@ def opset_1(cls, graph, node, **kw):
outputs=node.output('Y'),
axis=axis)


@op_mapper('unstack')
class Unstack():
support_opset_version_range = (1, 12)
Expand All @@ -91,10 +89,10 @@ class Unstack():
def opset_1(cls, graph, node, **kw):
print(node)
graph.make_node(
'Split',
inputs=node.input('X'),
outputs=node.output('Y'),
axis=node.attr('axis'))
'Split',
inputs=node.input('X'),
outputs=node.output('Y'),
axis=node.attr('axis'))


@op_mapper('expand_as_v2')
Expand Down Expand Up @@ -769,11 +767,22 @@ class Unsqueeze():

@classmethod
def opset_1(cls, graph, node, **kw):
graph.make_node(
'Unsqueeze',
inputs=node.input('X'),
outputs=node.output('Out'),
axes=node.attr('axes'))
if len(node.attr('axes')) > 0:
graph.make_node(
'Unsqueeze',
inputs=node.input('X'),
outputs=node.output('Out'),
axes=node.attr('axes'))
else:
axis_input = node.input('AxesTensor')
for name, param in graph.parameters.items():
if name in axis_input:
axis_data = param.attribute[0].t.int64_data
graph.make_node(
'Unsqueeze',
inputs=node.input('X'),
outputs=node.output('Out'),
axes=axis_data)


@op_mapper('reciprocal')
Expand Down Expand Up @@ -894,23 +903,21 @@ def convert_padding(cls, node, **kw):
#TODO support pads is Variable
if node.attr('data_format') == 'NCHW':
onnx_paddings = [
0, 0, paddings[0], paddings[2],
0, 0, paddings[1], paddings[3]
0, 0, paddings[0], paddings[2], 0, 0, paddings[1], paddings[3]
]
elif node.attr('data_format') == 'NHWC':
onnx_paddings = [
0, paddings[0], paddings[2], 0,
0, paddings[1], paddings[3], 0
0, paddings[0], paddings[2], 0, 0, paddings[1], paddings[3], 0
]
elif node.attr('data_format') == 'NCDHW':
onnx_paddings = [
0, 0, paddings[4], paddings[2], paddings[0],
0, 0, paddings[5], paddings[3], paddings[1]
0, 0, paddings[4], paddings[2], paddings[0], 0, 0, paddings[5],
paddings[3], paddings[1]
]
elif node.attr('data_format') == 'NDHWC':
onnx_paddings = [
0, paddings[4], paddings[2], paddings[0], 0,
0, paddings[5], paddings[3], paddings[1], 0
0, paddings[4], paddings[2], paddings[0], 0, 0, paddings[5],
paddings[3], paddings[1], 0
]
return onnx_paddings

Expand Down
50 changes: 50 additions & 0 deletions tests/test_nn_AdaptiveAvgPool1D.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from onnxbase import APIOnnx
from onnxbase import randtool


class Net(paddle.nn.Layer):
"""
simple Net
"""

def __init__(self):
super(Net, self).__init__()
self._avg_pool = paddle.nn.AdaptiveAvgPool1D(output_size=3)

def forward(self, inputs):
"""
forward
"""
x = self._avg_pool(inputs)
return x


def test_AdaptiveAvgPool1D_base():
"""
api: paddle.nn.AdaptiveAvgPool1D
op version: 9, 10, 11, 12
"""
op = Net()
op.eval()
# net, name, ver_list, delta=1e-6, rtol=1e-5
obj = APIOnnx(op, 'nn_AdaptiveAvgPool1D', [9, 10, 11, 12])
obj.set_input_data(
"input_data",
paddle.to_tensor(
randtool("float", -1, 1, [3, 1, 10]).astype('float32')))
obj.run()
Loading

0 comments on commit ebb667e

Please sign in to comment.