Skip to content

Commit

Permalink
[ONNX] Change AtenScatterReduce to AtenScatterReduceTwoOp for onnx.Sc…
Browse files Browse the repository at this point in the history
…atterElements

This enables the AtenScatterReduceTwoOp lowering to tm_tensor/linalg_ext.
Also remove the incorrect AtenScatterReduce-to-linalg lowering pattern.
  • Loading branch information
AmosLewis committed Oct 5, 2024
1 parent e632755 commit faec62d
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 169 deletions.
7 changes: 4 additions & 3 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -645,10 +645,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(

Value cstStrReduction =
rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), reduction);

rewriter.replaceOpWithNewOp<Torch::AtenScatterReduceOp>(
Value cstTrue =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
rewriter.replaceOpWithNewOp<Torch::AtenScatterReduceTwoOp>(
binder.op, resultType, data, constAxis, indices, updates,
cstStrReduction);
cstStrReduction, cstTrue);
return success();
});
patterns.onOp(
Expand Down
148 changes: 0 additions & 148 deletions lib/Conversion/TorchToLinalg/DataMovement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2124,152 +2124,6 @@ class ConvertAtenSliceScatterOp
};
} // namespace

namespace {
/// Builds the per-element payload used when lowering a scatter-reduce op
/// inside a `linalg.generic` region.
///
/// \param b           Builder positioned inside the `linalg.generic` body.
/// \param loc         Location attached to every created operation.
/// \param op          Original op; only used to report errors.
/// \param payloadArgs Region block arguments: `payloadArgs[0]` is the index
///                    element and `payloadArgs[1]` is the source element.
/// \param self        Destination tensor that is read from and scattered into.
/// \param dim         Dimension forwarded as `scatter_dims` to
///                    `tensor.scatter`.
/// \param reduceMode  Reduction mode; only "sum" is handled.
/// \param resultType  Result type of the created `tensor.scatter`; its element
///                    type selects between integer and floating-point add.
/// \returns The `tensor.scatter` result, or a null `Value` (after emitting an
///          error on `op`) when the mode or element type is unsupported.
static Value
createLinalgPayloadForReduceScatterOp(OpBuilder &b, Location loc, Operation *op,
                                      ValueRange payloadArgs, Value self,
                                      int64_t dim, std::string reduceMode,
                                      RankedTensorType resultType) {
  Type resultElementType = resultType.getElementType();

  // Validate everything up front so no IR is emitted on the failure paths.
  // TODO add more reduce mode here.
  if (reduceMode != "sum") {
    op->emitError("only sum lowering in createLinalgPayloadForReduceScatterOp");
    return nullptr;
  }
  const bool isFloatElement = isa<mlir::FloatType>(resultElementType);
  const bool isIntegerElement = isa<mlir::IntegerType>(resultElementType);
  if (!isFloatElement && !isIntegerElement) {
    // Previously this case left the accumulator null and still fed it to
    // tensor.from_elements below.
    op->emitError("only integer and float element types are supported in "
                  "createLinalgPayloadForReduceScatterOp");
    return nullptr;
  }

  Value index = castIntToIndex(b, loc, /*abstractindexElement*/ payloadArgs[0]);
  // Get the element at self[index].
  Value selfElement =
      b.create<tensor::ExtractOp>(loc, self, ValueRange{index});
  // Get the element at src[index], converted to the result element type.
  Value srcElement = convertScalarToDtype(
      b, loc, /*abstractSrcElement*/ payloadArgs[1], resultElementType);

  // accumulatorElement = selfElement + srcElement;
  Value accumulatorElement;
  if (isFloatElement)
    accumulatorElement = b.create<arith::AddFOp>(loc, selfElement, srcElement);
  else
    accumulatorElement = b.create<arith::AddIOp>(loc, selfElement, srcElement);

  // Prepare source, indices, scatter_dims for scatter op.
  Value accumulatorElementTensor =
      b.create<tensor::FromElementsOp>(loc, ValueRange{accumulatorElement});
  Value indexTensor = b.create<tensor::FromElementsOp>(loc, ValueRange{index});
  // Use real array storage: `ArrayRef<int64_t> x{dim}` would bind to a
  // temporary initializer_list and dangle once the declaration ends.
  const int64_t scatterDims[] = {dim};
  auto scatter = b.create<tensor::ScatterOp>(
      loc, resultType, /*source*/ accumulatorElementTensor,
      /*dest*/ self, /*indices*/ indexTensor,
      /*scatter_dims*/ b.getDenseI64ArrayAttr(scatterDims),
      /*unique*/ b.getUnitAttr());
  return scatter;
}

/// Conversion pattern for `AtenScatterReduceOp` that emits a `linalg.generic`
/// whose payload is built by `createLinalgPayloadForReduceScatterOp`.
///
/// Supported subset: constant "sum" reduce mode, constant `dim`, rank-1
/// `self`/`index`/`src`, and integer or floating-point result elements. Any
/// other input is rejected with a match failure before IR is created.
///
/// NOTE(review): the payload's `tensor.scatter` result is never yielded; the
/// region yields the incoming output element (`payloadArgs[2]`) instead, so
/// the generic's result does not appear to carry the reduction. Confirm the
/// intended semantics before relying on this lowering.
class ConvertAtenScatterReduceOp
    : public OpConversionPattern<AtenScatterReduceOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenScatterReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();

    // Get reduce mode, it could be "sum", "prod", "mean", "amax", "amin".
    std::string reduceMode;
    if (!matchPattern(op.getReduce(), m_TorchConstantStr(reduceMode)))
      return rewriter.notifyMatchFailure(
          op, "only support constant str reduce mode");
    // TODO: add "prod", "mean", "amax", "amin" mode.
    if (reduceMode != "sum")
      return rewriter.notifyMatchFailure(
          op, "Only support sum reduce mode for now");

    // Get dim.
    int64_t dim;
    if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
      return rewriter.notifyMatchFailure(op, "dim must be constant");

    // Prepare input.
    auto self = adaptor.getSelf();
    auto selfType = cast<RankedTensorType>(self.getType());
    int64_t selfRank = selfType.getRank();
    // TODO: add more input rank support.
    // Require exactly rank 1: a rank-0 tensor used to slip past the old
    // `selfRank > 1` guard.
    if (selfRank != 1)
      return rewriter.notifyMatchFailure(op,
                                         "Only support self rank==1 for now");
    // Wrap a negative dim and reject anything still out of range, so only a
    // valid dimension is forwarded as `scatter_dims`.
    if (dim < 0)
      dim += selfRank;
    if (dim < 0 || dim >= selfRank)
      return rewriter.notifyMatchFailure(op, "dim is out of range for self");

    // The generic below has one parallel loop and identity indexing maps, so
    // index and src must have the same rank as self.
    Value index = adaptor.getIndex();
    auto indexType = cast<RankedTensorType>(index.getType());
    int64_t indexRank = indexType.getRank();
    Value src = adaptor.getSrc();
    auto srcType = cast<RankedTensorType>(src.getType());
    int64_t srcRank = srcType.getRank();
    if (indexRank != selfRank || srcRank != selfRank)
      return rewriter.notifyMatchFailure(
          op, "Only support index and src rank==1 for now");

    // Prepare result type.
    const TypeConverter *typeConverter = getTypeConverter();
    RankedTensorType resultType = cast<RankedTensorType>(
        typeConverter->convertType(op->getResult(0).getType()));
    // The "sum" payload only knows integer and floating-point addition.
    Type resultElementType = resultType.getElementType();
    if (!isa<mlir::FloatType>(resultElementType) &&
        !isa<mlir::IntegerType>(resultElementType))
      return rewriter.notifyMatchFailure(
          op, "Only support integer and float element types for now");

    // All checks passed; IR creation starts here.
    // Prepare index: cast to a fully dynamic shape.
    SmallVector<int64_t> indexAbstractSizes(indexRank, kUnknownSize);
    auto abstractIndexType =
        RankedTensorType::get(makeShapeLLVMCompatible(indexAbstractSizes),
                              indexType.getElementType());
    Value abstractindex =
        rewriter.create<tensor::CastOp>(loc, abstractIndexType, index);

    // Prepare src: cast to a fully dynamic shape.
    SmallVector<int64_t> srcAbstractSizes(srcRank, kUnknownSize);
    auto abstractSrcType = RankedTensorType::get(
        makeShapeLLVMCompatible(srcAbstractSizes), srcType.getElementType());
    Value abstractSrc =
        rewriter.create<tensor::CastOp>(loc, abstractSrcType, src);

    // Prepare indexingMaps and iteratorTypes.
    SmallVector<AffineMap, 3> indexingMaps = {
        rewriter.getMultiDimIdentityMap(indexRank),
        rewriter.getMultiDimIdentityMap(srcRank),
        rewriter.getMultiDimIdentityMap(selfRank),
    };
    // One parallel iterator per loop dimension (selfRank is 1 here).
    SmallVector<utils::IteratorType> iteratorTypes(
        selfRank, utils::IteratorType::parallel);

    // Implementation of scatter and reduce in linalg.generic.
    bool err = false;
    Value result =
        rewriter
            .create<linalg::GenericOp>(
                loc, /*resultTensorTypes=*/self.getType(),
                /*inputs=*/ValueRange({abstractindex, abstractSrc}),
                /*outputs=*/self, indexingMaps, iteratorTypes,
                [&](OpBuilder &builder, Location loc, ValueRange payloadArgs) {
                  // Scatter result after reduce accumulation.
                  Value scatter = createLinalgPayloadForReduceScatterOp(
                      builder, loc, op, payloadArgs, self, dim, reduceMode,
                      resultType);
                  // Return selfElements to itself, nothing change but a
                  // placeholder.
                  if (scatter) {
                    builder.create<linalg::YieldOp>(
                        loc, /*selfElement*/ payloadArgs[2]);
                  }
                  err = !scatter;
                })
            .getResult(0);

    if (err)
      return rewriter.notifyMatchFailure(
          op,
          "failed to create linalg.generic operation for reduce scatter op");
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);

    return success();
  }
};
} // namespace

namespace {
class ConvertAtenViewAsComplexOp
: public OpConversionPattern<AtenViewAsComplexOp> {
Expand Down Expand Up @@ -2810,8 +2664,6 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
patterns.add<ConvertAtenCopyOp>(typeConverter, context);
target.addIllegalOp<AtenSliceScatterOp>();
patterns.add<ConvertAtenSliceScatterOp>(typeConverter, context);
target.addIllegalOp<AtenScatterReduceOp>();
patterns.add<ConvertAtenScatterReduceOp>(typeConverter, context);
target.addIllegalOp<AtenViewAsComplexOp>();
patterns.add<ConvertAtenViewAsComplexOp>(typeConverter, context);
target.addIllegalOp<AtenViewAsRealOp>();
Expand Down
38 changes: 20 additions & 18 deletions test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -261,15 +261,16 @@ func.func @test_scatter_elements_with_axis(%arg0: !torch.vtensor<[1,5],f32>, %ar

// CHECK-LABEL: func.func @test_scatter_elements_with_duplicate_indices
func.func @test_scatter_elements_with_duplicate_indices(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[AXIS:.*]] = torch.constant.int 1
// CHECK: %[[ZERO:.+]] = torch.constant.int 0
// CHECK: %[[ONE:.+]] = torch.constant.int 1
// CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[AXIS]]
// CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[ONE]]
// CHECK: %[[CMP:.+]] = torch.aten.lt.Scalar %arg1, %[[ZERO]]
// CHECK: %[[WHERE:.+]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1
// CHECK: %[[STR:.*]] = torch.constant.str "sum"
// CHECK: torch.aten.scatter.reduce %arg0, %[[AXIS]], %[[WHERE]], %arg2, %str : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str -> !torch.vtensor<[1,5],f32>
// CHECK: %[[AXIS:.*]] = torch.constant.int 1
// CHECK: %[[ZERO:.*]] = torch.constant.int 0
// CHECK: %[[FIVE:.*]] = torch.constant.int 1
// CHECK: %[[SZ:.*]] = torch.aten.size.int %arg0, %[[AXIS]] : !torch.vtensor<[1,5],f32>, !torch.int -> !torch.int
// CHECK: %[[ADD:.*]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[FIVE]] : !torch.vtensor<[1,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,2],si64>
// CHECK: %[[CMP:.*]] = torch.aten.lt.Scalar %arg1, %[[ZERO]] : !torch.vtensor<[1,2],si64>, !torch.int -> !torch.vtensor<[1,2],i1>
// CHECK: %[[WHERE:.*]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1 : !torch.vtensor<[1,2],i1>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],si64> -> !torch.vtensor<[1,2],si64>
// CHECK: %[[STR:.*]] = torch.constant.str "sum"
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: torch.aten.scatter_reduce.two %arg0, %[[AXIS]], %[[WHERE]], %arg2, %[[STR]], %[[TRUE]] : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str, !torch.bool -> !torch.vtensor<[1,5],f32>
%0 = torch.operator "onnx.ScatterElements"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64, torch.onnx.reduction = "add"} : (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32>
return %0 : !torch.vtensor<[1,5],f32>
}
Expand All @@ -294,15 +295,16 @@ func.func @test_scatter_elements_without_axis(%arg0: !torch.vtensor<[3,3],f32>,

// CHECK-LABEL: func.func @test_scatter_elements_with_reduction_mul
func.func @test_scatter_elements_with_reduction_mul(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[AXIS:.*]] = torch.constant.int 1
// CHECK: %[[ZERO:.+]] = torch.constant.int 0
// CHECK: %[[ONE:.+]] = torch.constant.int 1
// CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[AXIS]]
// CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[ONE]]
// CHECK: %[[CMP:.+]] = torch.aten.lt.Scalar %arg1, %[[ZERO]]
// CHECK: %[[WHERE:.+]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1
// CHECK: %[[STR:.*]] = torch.constant.str "prod"
// CHECK: torch.aten.scatter.reduce %arg0, %[[AXIS]], %[[WHERE]], %arg2, %str : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str -> !torch.vtensor<[1,5],f32>
// CHECK: %[[AXIS:.*]] = torch.constant.int 1
// CHECK: %[[ZERO:.*]] = torch.constant.int 0
// CHECK: %[[FIVE:.*]] = torch.constant.int 1
// CHECK: %[[SZ:.*]] = torch.aten.size.int %arg0, %[[AXIS]] : !torch.vtensor<[1,5],f32>, !torch.int -> !torch.int
// CHECK: %[[ADD:.*]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[FIVE]] : !torch.vtensor<[1,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,2],si64>
// CHECK: %[[CMP:.*]] = torch.aten.lt.Scalar %arg1, %[[ZERO]] : !torch.vtensor<[1,2],si64>, !torch.int -> !torch.vtensor<[1,2],i1>
// CHECK: %[[WHERE:.*]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1 : !torch.vtensor<[1,2],i1>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],si64> -> !torch.vtensor<[1,2],si64>
// CHECK: %[[STR:.*]] = torch.constant.str "prod"
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: torch.aten.scatter_reduce.two %arg0, %[[AXIS]], %[[WHERE]], %arg2, %[[STR]], %[[TRUE]] : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str, !torch.bool -> !torch.vtensor<[1,5],f32>
%0 = torch.operator "onnx.ScatterElements"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64, torch.onnx.reduction = "mul"} : (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32>
return %0 : !torch.vtensor<[1,5],f32>
}
Expand Down

0 comments on commit faec62d

Please sign in to comment.