Skip to content

Commit

Permalink
[Linalg] Add torch.scatter.reduce to linalg lowering
Browse files Browse the repository at this point in the history
  • Loading branch information
AmosLewis committed Oct 2, 2024
1 parent 9cf12e1 commit bcd5dd0
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 2 deletions.
148 changes: 148 additions & 0 deletions lib/Conversion/TorchToLinalg/DataMovement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2124,6 +2124,152 @@ class ConvertAtenSliceScatterOp
};
} // namespace

namespace {
/// Builds the per-element payload used by the scatter.reduce lowering.
///
/// `payloadArgs[0]` is treated as the index element and `payloadArgs[1]` as
/// the src element. The element `self[index]` is combined with the src element
/// (converted to the result element type) according to `reduceMode`, and a
/// `tensor.scatter` of the combined value into `self` along `dim` is created.
///
/// \param b           builder positioned inside the payload region.
/// \param loc         location attached to every op created here.
/// \param op          op being lowered; only used to report errors.
/// \param payloadArgs block arguments of the enclosing payload region.
/// \param self        destination tensor read from and scattered into.
/// \param dim         scatter dimension forwarded as `scatter_dims`.
/// \param reduceMode  reduction mode; only "sum" is handled.
/// \param resultType  type of the created `tensor.scatter` result.
/// \returns the scatter result, or a null Value after emitting an error on
///          `op` when the reduce mode or the element type is unsupported.
static Value
createLinalgPayloadForReduceScatterOp(OpBuilder &b, Location loc, Operation *op,
                                      ValueRange payloadArgs, Value self,
                                      int64_t dim, std::string reduceMode,
                                      RankedTensorType resultType) {
  Type resultElementType = resultType.getElementType();

  // Validate everything up front so that no ops are created on failure.
  // TODO add more reduce mode here.
  if (reduceMode != "sum") {
    op->emitError("only sum lowering in createLinalgPayloadForReduceScatterOp");
    return nullptr;
  }
  const bool isFloat = isa<mlir::FloatType>(resultElementType);
  const bool isInteger = isa<mlir::IntegerType>(resultElementType);
  if (!isFloat && !isInteger) {
    // Without this guard the accumulator below would stay null and be handed
    // to tensor.from_elements.
    op->emitError("only float and integer element types are supported in "
                  "createLinalgPayloadForReduceScatterOp");
    return nullptr;
  }

  Value index = castIntToIndex(b, loc, /*abstractindexElement*/ payloadArgs[0]);
  // Get the element at self[index].
  Value selfElement = b.create<tensor::ExtractOp>(loc, self, ValueRange{index});
  // Get the element at src[index].
  Value srcElement = convertScalarToDtype(
      b, loc, /*abstractSrcElement*/ payloadArgs[1], resultElementType);

  // accumulatorElement = selfElement + srcElement;
  Value accumulatorElement;
  if (isFloat)
    accumulatorElement = b.create<arith::AddFOp>(loc, selfElement, srcElement);
  else
    accumulatorElement = b.create<arith::AddIOp>(loc, selfElement, srcElement);

  // Prepare source, indices, scatter_dims for scatter op.
  Value accumulatorElementTensor =
      b.create<tensor::FromElementsOp>(loc, ValueRange{accumulatorElement});
  Value indexTensor = b.create<tensor::FromElementsOp>(loc, ValueRange{index});
  // Build the attribute directly from a braced list: the temporary array only
  // has to live for this one call. Storing it in a named ArrayRef first would
  // leave the ArrayRef pointing at a dead initializer_list.
  auto scatterDims = b.getDenseI64ArrayAttr({dim});
  auto scatter = b.create<tensor::ScatterOp>(
      loc, resultType, /*source*/ accumulatorElementTensor,
      /*dest*/ self, /*indices*/ indexTensor,
      /*scatter_dims*/ scatterDims,
      /*unique*/ b.getUnitAttr());
  return scatter;
}

/// Conversion pattern for `aten.scatter.reduce`.
///
/// Supported subset: constant "sum" reduce mode, constant `dim`, rank-1
/// `self`/`index`/`src`, and float or integer result element type. Anything
/// else is reported as a match failure before any IR is created.
///
/// The op is rewritten to a `linalg.generic` whose inputs are `index` and
/// `src` (cast to fully dynamic shapes) and whose output is `self`; the payload
/// is built by `createLinalgPayloadForReduceScatterOp`.
class ConvertAtenScatterReduceOp
    : public OpConversionPattern<AtenScatterReduceOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenScatterReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();

    // Get reduce mode, it could be "sum", "prod", "mean", "amax", "amin".
    std::string reduceMode;
    if (!matchPattern(op.getReduce(), m_TorchConstantStr(reduceMode)))
      return rewriter.notifyMatchFailure(
          op, "only support constant str reduce mode");
    // TODO: add "prod", "mean", "amax", "amin" mode.
    if (reduceMode != "sum")
      return rewriter.notifyMatchFailure(
          op, "Only support sum reduce mode for now");

    // Get dim.
    int64_t dim;
    if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
      return rewriter.notifyMatchFailure(op, "dim must be constant");

    // Prepare input.
    auto self = adaptor.getSelf();
    auto selfType = cast<RankedTensorType>(self.getType());
    int64_t selfRank = selfType.getRank();
    // Normalize a negative dim so that only a non-negative dimension is ever
    // forwarded to the payload builder.
    if (dim < 0)
      dim += selfRank;
    // TODO: add more input rank support.
    // Rank 0 is rejected as well: the generic below uses one loop dimension.
    if (selfRank != 1 || dim != 0)
      return rewriter.notifyMatchFailure(op,
                                         "Only support self rank==1 for now");

    // All operands share the iteration space of the generic below (identity
    // indexing maps), so their ranks have to agree with self.
    Value index = adaptor.getIndex();
    auto indexType = cast<RankedTensorType>(index.getType());
    int64_t indexRank = indexType.getRank();
    Value src = adaptor.getSrc();
    auto srcType = cast<RankedTensorType>(src.getType());
    int64_t srcRank = srcType.getRank();
    if (indexRank != selfRank || srcRank != selfRank)
      return rewriter.notifyMatchFailure(
          op, "index and src must have the same rank as self");

    // Prepare result type.
    const TypeConverter *typeConverter = getTypeConverter();
    RankedTensorType resultType = cast<RankedTensorType>(
        typeConverter->convertType(op->getResult(0).getType()));
    // The payload only knows how to add floats and integers; bail out here
    // rather than after the generic op has been partially built.
    Type resultElementType = resultType.getElementType();
    if (!isa<mlir::FloatType>(resultElementType) &&
        !isa<mlir::IntegerType>(resultElementType))
      return rewriter.notifyMatchFailure(
          op, "only float and integer element types are supported");

    // Prepare index: cast to a fully dynamic shape.
    SmallVector<int64_t> indexAbstractSizes(indexRank, kUnknownSize);
    auto abstractIndexType =
        RankedTensorType::get(makeShapeLLVMCompatible(indexAbstractSizes),
                              indexType.getElementType());
    Value abstractindex =
        rewriter.create<tensor::CastOp>(loc, abstractIndexType, index);

    // Prepare src: cast to a fully dynamic shape.
    SmallVector<int64_t> srcAbstractSizes(srcRank, kUnknownSize);
    auto abstractSrcType = RankedTensorType::get(
        makeShapeLLVMCompatible(srcAbstractSizes), srcType.getElementType());
    Value abstractSrc =
        rewriter.create<tensor::CastOp>(loc, abstractSrcType, src);

    // Prepare indexingMaps and iteratorTypes.
    SmallVector<AffineMap, 3> indexingMaps = {
        rewriter.getMultiDimIdentityMap(indexRank),
        rewriter.getMultiDimIdentityMap(srcRank),
        rewriter.getMultiDimIdentityMap(selfRank),
    };
    // One parallel iterator per loop dimension (selfRank == 1 here).
    SmallVector<utils::IteratorType> iteratorTypes(
        selfRank, utils::IteratorType::parallel);

    // Implementation of scatter and reduce in linalg.generic.
    // NOTE(review): the payload yields the incoming output element unchanged
    // and the scatter value is only checked for null here — confirm that the
    // reduction actually reaches the result of the generic op.
    bool err = false;
    Value result =
        rewriter
            .create<linalg::GenericOp>(
                loc, /*resultTensorTypes=*/self.getType(),
                /*inputs=*/ValueRange({abstractindex, abstractSrc}),
                /*outputs=*/self, indexingMaps, iteratorTypes,
                [&](OpBuilder &builder, Location nestedLoc,
                    ValueRange payloadArgs) {
                  // Scatter result after reduce accumulation.
                  Value scatter = createLinalgPayloadForReduceScatterOp(
                      builder, nestedLoc, op, payloadArgs, self, dim,
                      reduceMode, resultType);
                  // Return selfElements to itself, nothing change but a
                  // placeholder.
                  if (scatter) {
                    builder.create<linalg::YieldOp>(
                        nestedLoc, /*selfElement*/ payloadArgs[2]);
                  }
                  err = !scatter;
                })
            .getResult(0);

    if (err)
      return rewriter.notifyMatchFailure(
          op,
          "failed to create linalg.generic operation for reduce scatter op");
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);

    return success();
  }
};
} // namespace

namespace {
class ConvertAtenViewAsComplexOp
: public OpConversionPattern<AtenViewAsComplexOp> {
Expand Down Expand Up @@ -2664,6 +2810,8 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
patterns.add<ConvertAtenCopyOp>(typeConverter, context);
target.addIllegalOp<AtenSliceScatterOp>();
patterns.add<ConvertAtenSliceScatterOp>(typeConverter, context);
target.addIllegalOp<AtenScatterReduceOp>();
patterns.add<ConvertAtenScatterReduceOp>(typeConverter, context);
target.addIllegalOp<AtenViewAsComplexOp>();
patterns.add<ConvertAtenViewAsComplexOp>(typeConverter, context);
target.addIllegalOp<AtenViewAsRealOp>();
Expand Down
4 changes: 2 additions & 2 deletions test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ func.func @test_scatter_elements_with_duplicate_indices(%arg0: !torch.vtensor<[1
// CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[ONE]]
// CHECK: %[[CMP:.+]] = torch.aten.lt.Scalar %arg1, %[[ZERO]]
// CHECK: %[[WHERE:.+]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1
// CHECK: %[[STR:.*]] = torch.constant.str "add"
// CHECK: %[[STR:.*]] = torch.constant.str "sum"
// CHECK: torch.aten.scatter.reduce %arg0, %[[AXIS]], %[[WHERE]], %arg2, %str : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str -> !torch.vtensor<[1,5],f32>
%0 = torch.operator "onnx.ScatterElements"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64, torch.onnx.reduction = "add"} : (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32>
return %0 : !torch.vtensor<[1,5],f32>
Expand Down Expand Up @@ -301,7 +301,7 @@ func.func @test_scatter_elements_with_reduction_mul(%arg0: !torch.vtensor<[1,5],
// CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[ONE]]
// CHECK: %[[CMP:.+]] = torch.aten.lt.Scalar %arg1, %[[ZERO]]
// CHECK: %[[WHERE:.+]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1
// CHECK: %[[STR:.*]] = torch.constant.str "multiply"
// CHECK: %[[STR:.*]] = torch.constant.str "prod"
// CHECK: torch.aten.scatter.reduce %arg0, %[[AXIS]], %[[WHERE]], %arg2, %str : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str -> !torch.vtensor<[1,5],f32>
%0 = torch.operator "onnx.ScatterElements"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64, torch.onnx.reduction = "mul"} : (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32>
return %0 : !torch.vtensor<[1,5],f32>
Expand Down

0 comments on commit bcd5dd0

Please sign in to comment.