Skip to content

Commit

Permalink
support more than 3D ncnn:softmax
Browse files Browse the repository at this point in the history
  • Loading branch information
yanghaoqi committed Nov 27, 2023
1 parent ba191be commit 748ade1
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 4 deletions.
38 changes: 36 additions & 2 deletions modules/Nncase.Modules.Ncnn/Passes/Rules/Ncnn/LowerSoftmax.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
using Nncase.PatternMatch;

using static Nncase.IR.F.Ncnn;
using static Nncase.IR.F.Tensors;
using static Nncase.IR.TypePatternUtility;
using static Nncase.PatternMatch.F.NN;
using static Nncase.PatternMatch.Utility;
Expand All @@ -23,12 +24,45 @@ public partial class LowerSoftmax : RewriteRule<Pattern>
{
/// <inheritdoc/>
public override Pattern Pattern { get; } = IsSoftmax(
IsWildcard("input") with { TypePattern = IsFloat() & HasRank(x => x <= 3) },
IsWildcard("input") with { TypePattern = IsFloat() },
IsTensorConst("axis"));

// squeeze softmax to 3D,set axis to 1
private (List<int> NewShape, int NewAxis) GetFixedShapeAndAxis(List<int> oldShape, int oldAxis)
{
int positive_axis = oldAxis < 0 ? oldShape.Count + oldAxis : oldAxis;
var newShape = new List<int> { 1, oldShape[positive_axis], 1 };
for (int i = 0; i < positive_axis; i++)
{
newShape[0] *= oldShape[i];
}

for (int i = positive_axis + 1; i < oldShape.Count; i++)
{
newShape[2] *= oldShape[i];
}

return (newShape, 1);
}

private Expr? GetReplace(Expr input, int axis)
{
var newInput = new Var(input.CheckedType);
return new Call(new Fusion("ncnn", NcnnSoftmax(newInput, axis), new[] { newInput }), input);
if (input.CheckedShape.Rank > 3)
{
var (newShape, newAxis) = GetFixedShapeAndAxis(input.CheckedShape.ToValueList(), axis);

var inRes = Reshape(input, newShape.ToArray());
var inResO = new Var(inRes.CheckedType);

var ncnnSoftmaxCall = new Call(new Fusion("ncnn", NcnnSoftmax(inResO, newAxis), new[] { inResO }), inRes);

var outRes = Reshape(ncnnSoftmaxCall, input.CheckedShape.ToValueList().ToArray());
return outRes;
}
else
{
return new Call(new Fusion("ncnn", NcnnSoftmax(newInput, axis), new[] { newInput }), input);
}
}
}
5 changes: 3 additions & 2 deletions tests/importer/onnx_/basic/test_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,9 @@ def _make_module(in_shape, axis, op_version):


in_shapes = [
[2, 3, 8, 1],
[1, 3, 8, 5],
[2, 32, 3, 3],
[2, 32, 64, 128],
[1, 113, 228, 65],
]

axes = [
Expand Down

0 comments on commit 748ade1

Please sign in to comment.