diff --git a/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h b/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h index a6d774a64db1..221745b1c26e 100644 --- a/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h +++ b/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h @@ -12,12 +12,25 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + #include namespace mlir { namespace torch { + +/// Collect a set of legal/illegal ops for converting Torch operations to Tosa +/// dialect. +void populateTorchToTosaConversionLegalOps(ConversionTarget &target); + +/// Collect a set of patterns to convert Torch operations to Tosa dialect + +/// return the set of illegalOps +std::set +populateTorchToTosaConversionPatternsAndIllegalOps(TypeConverter &typeConverter, + RewritePatternSet &patterns); + std::unique_ptr> createConvertTorchToTosaPass(); -} +} // namespace torch } // namespace mlir #endif // TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H diff --git a/include/torch-mlir/Conversion/Utils/Utils.h b/include/torch-mlir/Conversion/Utils/Utils.h index d21dd5504dcd..264fb4966d39 100644 --- a/include/torch-mlir/Conversion/Utils/Utils.h +++ b/include/torch-mlir/Conversion/Utils/Utils.h @@ -97,6 +97,15 @@ Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc, Value torchOptionalInt, Value builtinInt, Value defaultValue, Value dimSize); +// Helper function to unsqueeze the input tensor at given dim. +// Returns the unsqueezed tensor or failure. +FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op, + Value input, int64_t dim); + +// Helper function to squeeze the input tensor at given dim. +// Returns the squeezed tensor or failure. +FailureOr squeezeTensor(PatternRewriter &rewriter, Operation *op, + Value input, int64_t dim); } // namespace Torch } // namespace torch } // namespace mlir diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 7446b7faaa08..13f555c146b4 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -1093,18 +1093,35 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( rewriter.replaceOp(binder.op, nllLoss); return success(); }); - patterns.onOp("NonZero", 13, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value operand; - if (binder.tensorOperand(operand) || - binder.tensorResultType(resultType)) { - return failure(); - } - rewriter.replaceOpWithNewOp( - binder.op, resultType, operand); - return success(); - }); + patterns.onOp( + "NonZero", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) { + return failure(); + } + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + Value one = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1)); + auto rawSize = resultType.getSizes(); + SmallVector torchResultSize(rawSize.rbegin(), rawSize.rend()); + auto torchResultType = rewriter.getType( + torchResultSize, resultType.getDtype()); + auto nonZero = rewriter.create( + binder.getLoc(), torchResultType, operand); + // The output tensor has a shape of ((n, z)), where (n) is the + // number of dimensions in the input tensor and (z) is the + // number of non-zero elements2. This is different from + // PyTorch's default behavior, where the dimensions are + // reversed. + rewriter.replaceOpWithNewOp( + binder.op, resultType, nonZero, zero, one); + return success(); + }); patterns.onOp( "MaxPool", 12, [](OpBinder binder, ConversionPatternRewriter &rewriter) { std::string autoPad; diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index a18c0bae01fc..b8c20bc73f65 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -1642,69 +1642,18 @@ class ConvertAtenSqueezeDimOp : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); - Value input = adaptor.getSelf(); - auto inputType = cast(input.getType()); - int64_t inputRank = inputType.getRank(); - - if (inputRank == 0) { - return rewriter.notifyMatchFailure( - op, "zero input rank should have been handled by the folder"); - } - int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return rewriter.notifyMatchFailure(op, "dim must be constant"); - dim = toPositiveDim(dim, inputRank); - if (!isValidDim(dim, inputRank)) - return rewriter.notifyMatchFailure(op, "dim is statically invalid"); - - // assert dynamic squeeze dim size == 1 - if (inputType.isDynamicDim(dim)) { - Value cstDim = rewriter.create(op.getLoc(), dim); - Value dimVal = rewriter.create(op.getLoc(), input, cstDim); - Value cstOne = rewriter.create(op.getLoc(), 1); - Value cmp = rewriter.create( - op.getLoc(), arith::CmpIPredicate::eq, dimVal, cstOne); - rewriter.create( - op.getLoc(), cmp, - rewriter.getStringAttr( - "Expected dynamic squeeze dim size to be statically 1")); - } - - const TypeConverter *typeConverter = getTypeConverter(); - auto resultType = - cast(typeConverter->convertType(op.getType())); - int64_t resultRank = resultType.getRank(); - // If the dim(th) dimension of operand tensor type is not statically unit, - // `aten.squeeze` will behave as an identity operation. - if (inputType.getDimSize(dim) != 1 && !inputType.isDynamicDim(dim)) { - rewriter.replaceOpWithNewOp(op, resultType, input); - return success(); + auto squeezeTensorInfo = + squeezeTensor(rewriter, op, adaptor.getSelf(), dim); + if (failed(squeezeTensorInfo)) { + return rewriter.notifyMatchFailure(op, + "cannot generate unsqueeze tensor"); } - SmallVector reassociationMap(resultRank); - bool alreadyCrossedSqueezedDim = false; - for (int i = 0; i != resultRank; i++) { - if (alreadyCrossedSqueezedDim) { - reassociationMap[i].push_back(i + 1); - } else { - reassociationMap[i].push_back(i); - if (dim != 0 && i != dim - 1) - continue; - - alreadyCrossedSqueezedDim = true; - if (dim == 0) - reassociationMap[0].push_back(1); - if (i == dim - 1) - reassociationMap[i].push_back(dim); - } - } - // Note: In case the operand tensor type is of unit rank and is statically - // shaped with unit dimension, the `reassociationMap` will be empty and the - // input will be collapsed to a 0-D tensor. - rewriter.replaceOpWithNewOp(op, resultType, input, - reassociationMap); + rewriter.replaceOp(op, squeezeTensorInfo.value()); return success(); } }; @@ -1722,36 +1671,15 @@ class ConvertAtenUnsqueezeOp : public OpConversionPattern { int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return rewriter.notifyMatchFailure(op, "dim must be constant"); - auto inputRank = - cast(adaptor.getSelf().getType()).getRank(); - dim = toPositiveDim(dim, inputRank + 1); - if (!isValidDim(dim, inputRank + 1)) - return rewriter.notifyMatchFailure(op, "dim is statically invalid"); - SmallVector reassociationMap(inputRank); - // From the perspective of the reassociation map, the situation of - // unsqueezing before or after the last dimension is symmetrical. - // Normalize it to the "before" case. - // The 0 case is special here, since there is no last dimension to insert - // before -- we simply rely on the loop below iterating 0 times. - if (dim == inputRank && inputRank != 0) - dim = inputRank - 1; - bool alreadyCrossedExpandedDim = false; - for (int i = 0; i != inputRank; i++) { - if (alreadyCrossedExpandedDim) { - reassociationMap[i].push_back(i + 1); - } else { - reassociationMap[i].push_back(i); - if (i == dim) { - reassociationMap[i].push_back(i + 1); - alreadyCrossedExpandedDim = true; - } - } + auto unsqueezeTensorInfo = + unsqueezeTensor(rewriter, op, adaptor.getSelf(), dim); + if (failed(unsqueezeTensorInfo)) { + return rewriter.notifyMatchFailure(op, + "cannot generate unsqueeze tensor"); } - auto resultType = cast( - getTypeConverter()->convertType(op->getResult(0).getType())); - rewriter.replaceOpWithNewOp( - op, resultType, adaptor.getSelf(), reassociationMap); + + rewriter.replaceOp(op, unsqueezeTensorInfo.value()); return success(); } }; diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index ceedb1beb983..1c657faaf11d 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -850,6 +850,48 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "only support constant int dilations"); + // Checks for valid group size + int64_t numGroups; + if (!matchPattern(op.getGroups(), m_TorchConstantInt(&numGroups))) + return rewriter.notifyMatchFailure(op, + "only constant group size supported."); + Value groups = castIntToIndex(rewriter, loc, adaptor.getGroups()); + + // Adding support for 1d group convolution by converting the 1d-conv to + // 2d-conv. + // TODO: Replace this logic with the appropriate linalg op for 1-d group + // convolution once that support is added. + bool is1DGroupConv = (numSpatialDims == 1 && numGroups != 1); + if (is1DGroupConv) { + // Unsqueezing the last dim of input and weight. Also extending the + // dilation, stride, padding, and output padding lists. + auto unsqueezeInputInfo = + unsqueezeTensor(rewriter, op, input, /*dim=*/-1); + if (failed(unsqueezeInputInfo)) { + return rewriter.notifyMatchFailure(op, + "cannot generate unsqueeze tensor"); + } + input = unsqueezeInputInfo.value(); + + auto unsqueezeWeightInfo = + unsqueezeTensor(rewriter, op, weight, /*dim=*/-1); + if (failed(unsqueezeWeightInfo)) { + return rewriter.notifyMatchFailure(op, + "cannot generate unsqueeze tensor"); + } + weight = unsqueezeWeightInfo.value(); + + Value cstZero = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + paddingIntValues.push_back(cstZero); + outputPaddingIntValues.push_back(cstZero); + strideInts.push_back(1); + dilationInts.push_back(1); + + inRank++; + numSpatialDims++; + } + Value inBatch = getDimOp(rewriter, loc, input, 0); Value inChannels = getDimOp(rewriter, loc, input, 1); SmallVector inDims; @@ -861,13 +903,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { for (size_t i = 2; i < inRank; i++) weightDims.push_back(getDimOp(rewriter, loc, weight, i)); - // Checks for valid group size - int64_t numGroups; - if (!matchPattern(op.getGroups(), m_TorchConstantInt(&numGroups))) - return rewriter.notifyMatchFailure(op, - "only constant group size supported."); - Value groups = castIntToIndex(rewriter, loc, adaptor.getGroups()); - auto validate = [&](Value toValidate, std::string err) { Value c0 = rewriter.create(loc, rewriter.getIndexAttr(0)); @@ -1286,13 +1321,24 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv, resultElementType); } + + if (is1DGroupConv) { + // Squeezing the last dim of the result of conv. + auto squeezeOutputInfo = squeezeTensor(rewriter, op, conv, /*dim=*/-1); + if (failed(squeezeOutputInfo)) { + return rewriter.notifyMatchFailure(op, + "cannot generate squeeze tensor"); + } + conv = squeezeOutputInfo.value(); + } + rewriter.replaceOpWithNewOp(op, newResultType, conv); return success(); } if (numSpatialDims != 2) return rewriter.notifyMatchFailure( - op, "unimplemented: only 2D grouped convolution supported"); + op, "unimplemented: only 1D and 2D grouped convolution supported"); // Grouped case, use the grouped conv linalg op auto expandGroups = [&](Value tensor, size_t dim) { @@ -1377,6 +1423,16 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv, resultElementType); } + + if (is1DGroupConv) { + // Squeezing the last dim of the result of conv. + auto squeezeOutputInfo = squeezeTensor(rewriter, op, conv, /*dim=*/-1); + if (failed(squeezeOutputInfo)) { + return rewriter.notifyMatchFailure(op, + "cannot generate squeeze tensor"); + } + conv = squeezeOutputInfo.value(); + } rewriter.replaceOpWithNewOp(op, newResultType, conv); return success(); } diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index abd3160f1048..65bd3cba7159 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -8474,347 +8474,356 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { ConversionTarget target(*context); target.addLegalDialect(); + target.addIllegalDialect(); TypeConverter typeConverter; typeConverter.addConversion([](Type type) { return type; }); TorchConversion::setupBackendTypeConversion(target, typeConverter); - // The following ops are never the primary reason why lowering fails. - // The backend contract only allows functions to return tensors thus there - // is always another op using them. - // When we have a chain of torch.constant.int followed by a unsupported - // torch op, we want the pass to mention the unsupported torch op - // in the error message. - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - target.addIllegalDialect(); + populateTorchToTosaConversionLegalOps(target); RewritePatternSet patterns(context); + auto illegalOps = populateTorchToTosaConversionPatternsAndIllegalOps( + typeConverter, patterns); + + for (auto op : illegalOps) { + target.addIllegalOp(OperationName(op, context)); + } + + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + return signalPassFailure(); + } +}; +} // namespace + +void torch::populateTorchToTosaConversionLegalOps(ConversionTarget &target) { + // The following ops are never the primary reason why lowering fails. + // The backend contract only allows functions to return tensors thus there + // is always another op using them. + // When we have a chain of torch.constant.int followed by a unsupported + // torch op, we want the pass to mention the unsupported torch op + // in the error message. + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); +} + +std::set torch::populateTorchToTosaConversionPatternsAndIllegalOps( + TypeConverter &typeConverter, RewritePatternSet &patterns) { + + MLIRContext *context = patterns.getContext(); + std::set illegalOps; + #define INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenOp, TosaOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, \ context); - INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenLogOp, tosa::LogOp) - INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenExpOp, tosa::ExpOp) + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenLogOp, tosa::LogOp) + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenExpOp, tosa::ExpOp) #undef INSERT_UNARY_PROMOTE_TO_FP_PATTERN #define INSERT_UNARY_PATTERN(AtenOp, TosaOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_UNARY_PATTERN(AtenNegOp, tosa::NegateOp) - INSERT_UNARY_PATTERN(AtenFloorOp, tosa::FloorOp) - INSERT_UNARY_PATTERN(AtenRsqrtOp, tosa::RsqrtOp) - INSERT_UNARY_PATTERN(AtenBitwiseNotOp, tosa::BitwiseNotOp) - INSERT_UNARY_PATTERN(AtenCeilOp, tosa::CeilOp) - INSERT_UNARY_PATTERN(AtenReciprocalOp, tosa::ReciprocalOp) - INSERT_UNARY_PATTERN(AtenErfOp, tosa::ErfOp) - INSERT_UNARY_PATTERN(AtenCosOp, tosa::CosOp) - INSERT_UNARY_PATTERN(AtenSinOp, tosa::SinOp) - INSERT_UNARY_PATTERN(AtenLogicalNotOp, tosa::LogicalNotOp) + INSERT_UNARY_PATTERN(AtenNegOp, tosa::NegateOp) + INSERT_UNARY_PATTERN(AtenFloorOp, tosa::FloorOp) + INSERT_UNARY_PATTERN(AtenRsqrtOp, tosa::RsqrtOp) + INSERT_UNARY_PATTERN(AtenBitwiseNotOp, tosa::BitwiseNotOp) + INSERT_UNARY_PATTERN(AtenCeilOp, tosa::CeilOp) + INSERT_UNARY_PATTERN(AtenReciprocalOp, tosa::ReciprocalOp) + INSERT_UNARY_PATTERN(AtenCosOp, tosa::CosOp) + INSERT_UNARY_PATTERN(AtenSinOp, tosa::SinOp) + INSERT_UNARY_PATTERN(AtenLogicalNotOp, tosa::LogicalNotOp) #undef INSERT_UNARY_PATTERN #define INSERT_BINARY_PATTERN(AtenOp, TosaOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_BINARY_PATTERN(AtenMaximumOp, tosa::MaximumOp) - INSERT_BINARY_PATTERN(AtenMinimumOp, tosa::MinimumOp) - INSERT_BINARY_PATTERN(AtenLogicalAndOp, tosa::LogicalAndOp) - INSERT_BINARY_PATTERN(AtenLogicalOrOp, tosa::LogicalOrOp) - INSERT_BINARY_PATTERN(AtenLogicalXorOp, tosa::LogicalXorOp) - INSERT_BINARY_PATTERN(AtenLogicalAndOp, tosa::LogicalAndOp) - INSERT_BINARY_PATTERN(AtenBitwiseLeftShiftTensorOp, - tosa::LogicalLeftShiftOp) - INSERT_BINARY_PATTERN(AtenBitwiseRightShiftTensorOp, - tosa::ArithmeticRightShiftOp) + INSERT_BINARY_PATTERN(AtenMaximumOp, tosa::MaximumOp) + INSERT_BINARY_PATTERN(AtenMinimumOp, tosa::MinimumOp) + INSERT_BINARY_PATTERN(AtenLogicalOrOp, tosa::LogicalOrOp) + INSERT_BINARY_PATTERN(AtenLogicalXorOp, tosa::LogicalXorOp) + INSERT_BINARY_PATTERN(AtenLogicalAndOp, tosa::LogicalAndOp) + INSERT_BINARY_PATTERN(AtenBitwiseLeftShiftTensorOp, tosa::LogicalLeftShiftOp) + INSERT_BINARY_PATTERN(AtenBitwiseRightShiftTensorOp, + tosa::ArithmeticRightShiftOp) #undef INSERT_BINARY_PATTERN #define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, TosaOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, tosa::AddOp) - INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, tosa::AddOp) - INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, tosa::SubOp) - INSERT_BINARY_ADDSUB_PATTERN(AtenSubScalarOp, tosa::SubOp) - INSERT_BINARY_ADDSUB_PATTERN(AtenRsubScalarOp, tosa::SubOp) + INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, tosa::AddOp) + INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, tosa::AddOp) + INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, tosa::SubOp) + INSERT_BINARY_ADDSUB_PATTERN(AtenSubScalarOp, tosa::SubOp) #undef INSERT_BINARY_ADDSUB_PATTERN #define INSERT_BINARY_COMPARE_PATTERN(AtenOp, TosaOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp, tosa::GreaterOp) - INSERT_BINARY_COMPARE_PATTERN(AtenGeScalarOp, tosa::GreaterEqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenGeTensorOp, tosa::GreaterEqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenLeScalarOp, tosa::GreaterEqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenLeTensorOp, tosa::GreaterEqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp, tosa::GreaterOp) - INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp, tosa::GreaterOp) - INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp, tosa::GreaterOp) - INSERT_BINARY_COMPARE_PATTERN(AtenLeTensorOp, tosa::GreaterEqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenLeScalarOp, tosa::GreaterEqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp, tosa::EqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenEqScalarOp, tosa::EqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenNeTensorOp, tosa::EqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenNeScalarOp, tosa::EqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndTensorOp, tosa::BitwiseAndOp) - INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndScalarOp, tosa::BitwiseAndOp) - INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseOrTensorOp, tosa::BitwiseOrOp) - INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseXorTensorOp, tosa::BitwiseXorOp) + INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp, tosa::GreaterOp) + INSERT_BINARY_COMPARE_PATTERN(AtenGeScalarOp, tosa::GreaterEqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenGeTensorOp, tosa::GreaterEqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp, tosa::GreaterOp) + INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp, tosa::GreaterOp) + INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp, tosa::GreaterOp) + INSERT_BINARY_COMPARE_PATTERN(AtenLeTensorOp, tosa::GreaterEqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenLeScalarOp, tosa::GreaterEqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp, tosa::EqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenEqScalarOp, tosa::EqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenNeTensorOp, tosa::EqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenNeScalarOp, tosa::EqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndTensorOp, tosa::BitwiseAndOp) + INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndScalarOp, tosa::BitwiseAndOp) + INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseOrTensorOp, tosa::BitwiseOrOp) + INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseXorTensorOp, tosa::BitwiseXorOp) #undef INSERT_BINARY_COMPARE_PATTERN #define INSERT_BINARY_MUL_PATTERN(AtenOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_BINARY_MUL_PATTERN(AtenMulTensorOp); - INSERT_BINARY_MUL_PATTERN(AtenMulScalarOp); + INSERT_BINARY_MUL_PATTERN(AtenMulTensorOp); + INSERT_BINARY_MUL_PATTERN(AtenMulScalarOp); #undef INSERT_BINARY_MUL_PATTERN #define INSERT_BINARY_DIV_PATTERN(AtenOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_BINARY_DIV_PATTERN(AtenDivTensorOp); - INSERT_BINARY_DIV_PATTERN(AtenDivScalarOp); - INSERT_BINARY_DIV_PATTERN(AtenDivTensorModeOp); - INSERT_BINARY_DIV_PATTERN(AtenDivScalarModeOp); + INSERT_BINARY_DIV_PATTERN(AtenDivTensorOp); + INSERT_BINARY_DIV_PATTERN(AtenDivScalarOp); + INSERT_BINARY_DIV_PATTERN(AtenDivTensorModeOp); + INSERT_BINARY_DIV_PATTERN(AtenDivScalarModeOp); #undef INSERT_BINARY_DIV_PATTERN #define INSERT_REMAINDER_FMOD_OP_PATTERN(AtenOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderScalarOp); - INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderTensorOp); - INSERT_REMAINDER_FMOD_OP_PATTERN(AtenFmodScalarOp); - INSERT_REMAINDER_FMOD_OP_PATTERN(AtenFmodTensorOp); + INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderScalarOp); + INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderTensorOp); + INSERT_REMAINDER_FMOD_OP_PATTERN(AtenFmodScalarOp); + INSERT_REMAINDER_FMOD_OP_PATTERN(AtenFmodTensorOp); #undef INSERT_REMAINDER_FMOD_OP_PATTERN #define INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>( \ typeConverter, context); - INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenMeanDimOp, - mlir::tosa::convertReduceMeanOp) - INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenSumDimIntListOp, - mlir::tosa::convertReduceSumOp) - INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenLinalgVectorNormOp, - mlir::tosa::convertLinalgVectorNormOp) + INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenMeanDimOp, + mlir::tosa::convertReduceMeanOp) + INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenSumDimIntListOp, + mlir::tosa::convertReduceSumOp) + INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenLinalgVectorNormOp, + mlir::tosa::convertLinalgVectorNormOp) #undef INSERT_NDIMS_REDUCTION_OP_PATTERN #define INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>( \ typeConverter, context); - INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAnyDimOp, - mlir::tosa::convertReduceAnyOp) - INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAllDimOp, - mlir::tosa::convertReduceAllOp) - INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenProdDimIntOp, - mlir::tosa::convertReduceProdOp) + INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAnyDimOp, + mlir::tosa::convertReduceAnyOp) + INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAllDimOp, + mlir::tosa::convertReduceAllOp) + INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenProdDimIntOp, + mlir::tosa::convertReduceProdOp) #undef INSERT_ONEDIM_REDUCTION_OP_PATTERN #define INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>( \ typeConverter, context); - INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAllOp, - mlir::tosa::convertReduceAllOp) - INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAnyOp, - mlir::tosa::convertReduceAnyOp) - INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenSumOp, - mlir::tosa::convertReduceSumOp) - INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMaxOp, - mlir::tosa::convertReduceMaxOp) - INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMinOp, - mlir::tosa::convertReduceMinOp) - INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenProdOp, - mlir::tosa::convertReduceProdOp) + INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAllOp, mlir::tosa::convertReduceAllOp) + INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAnyOp, mlir::tosa::convertReduceAnyOp) + INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenSumOp, mlir::tosa::convertReduceSumOp) + INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMaxOp, mlir::tosa::convertReduceMaxOp) + INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMinOp, mlir::tosa::convertReduceMinOp) + INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenProdOp, + mlir::tosa::convertReduceProdOp) #undef INSERT_ALLDIMS_REDUCTION_OP_PATTERN #define INSERT_INDICES_REDUCTION_OP_PATTERN(AtenOp, TosaOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMaxDimOp, tosa::ReduceMaxOp); - INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMinDimOp, tosa::ReduceMinOp); + INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMaxDimOp, tosa::ReduceMaxOp); + INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMinDimOp, tosa::ReduceMinOp); #undef INSERT_INDICES_REDUCTION_OP_PATTERN #define INSERT_SQUEEZE_OP_PATTERN(AtenOp, TemplateForm) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_SQUEEZE_OP_PATTERN(AtenSqueezeOp, ConvertAtenSqueezeAllDimsOp) - INSERT_SQUEEZE_OP_PATTERN(AtenSqueezeDimOp, ConvertAtenSqueezeOneDimOp) + INSERT_SQUEEZE_OP_PATTERN(AtenSqueezeOp, ConvertAtenSqueezeAllDimsOp) + INSERT_SQUEEZE_OP_PATTERN(AtenSqueezeDimOp, ConvertAtenSqueezeOneDimOp) #undef INSERT_SQUEEZE_OP_PATTERN #define INSERT_MATMUL_ATENOP_PATTERN(AtenOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_MATMUL_ATENOP_PATTERN(AtenMatmulOp); + INSERT_MATMUL_ATENOP_PATTERN(AtenMatmulOp); #undef INSERT_MATMUL_ATEMOP_PATTERN #define INSERT_MM_ATENOP_PATTERN(AtenOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_MM_ATENOP_PATTERN(AtenMmOp); - INSERT_MM_ATENOP_PATTERN(AtenBmmOp); + INSERT_MM_ATENOP_PATTERN(AtenMmOp); + INSERT_MM_ATENOP_PATTERN(AtenBmmOp); #undef INSERT_MM_ATEMOP_PATTERN #define INSERT_LINEAR_ATENOP_PATTERN(AtenOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_LINEAR_ATENOP_PATTERN(AtenLinearOp); + INSERT_LINEAR_ATENOP_PATTERN(AtenLinearOp); #undef INSERT_LINEAR_ATEMOP_PATTERN #define INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN(AtenOp, TosaOpT) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, \ context); - INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN(AtenAdaptiveAvgPool2dOp, - tosa::AvgPool2dOp); + INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN(AtenAdaptiveAvgPool2dOp, + tosa::AvgPool2dOp); #undef INSERT_ADAPTIVE_POOLING_ATEMOP_PATTERN - target.addIllegalOp(); - patterns.add(typeConverter, context); + illegalOps.insert(AtenMaxPool2dOp::getOperationName()); + patterns.add(typeConverter, context); - target.addIllegalOp(); - patterns.add(typeConverter, context); + illegalOps.insert(AtenMaxPool1dOp::getOperationName()); + patterns.add(typeConverter, context); - target.addIllegalOp(); - patterns.add(typeConverter, context); + illegalOps.insert(AtenAvgPool2dOp::getOperationName()); + patterns.add(typeConverter, context); - target.addIllegalOp(); - patterns.add(typeConverter, context); + illegalOps.insert(AtenAvgPool1dOp::getOperationName()); + patterns.add(typeConverter, context); #define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, \ context); - INSERT_CONSTANT_FILL_PATTERN(AtenOnesOp, 1); - INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0); - INSERT_CONSTANT_FILL_PATTERN(AtenEmptyMemoryFormatOp, 0); + INSERT_CONSTANT_FILL_PATTERN(AtenOnesOp, 1); + INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0); + INSERT_CONSTANT_FILL_PATTERN(AtenEmptyMemoryFormatOp, 0); #undef INSERT_CONSTANT_FILL_PATTERN #define INSERT_FILL_PATTERN(AtenOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_FILL_PATTERN(AtenFill_ScalarOp); - INSERT_FILL_PATTERN(AtenFillScalarOp); - INSERT_FILL_PATTERN(AtenFillTensorOp); + INSERT_FILL_PATTERN(AtenFill_ScalarOp); + INSERT_FILL_PATTERN(AtenFillScalarOp); + INSERT_FILL_PATTERN(AtenFillTensorOp); #undef INSERT_FILL_PATTERN #define INSERT_MASKED_FILL_PATTERN(AtenOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_MASKED_FILL_PATTERN(AtenMaskedFillScalarOp); - INSERT_MASKED_FILL_PATTERN(AtenMaskedFillTensorOp); + INSERT_MASKED_FILL_PATTERN(AtenMaskedFillScalarOp); + INSERT_MASKED_FILL_PATTERN(AtenMaskedFillTensorOp); #undef INSERT_MASKED_FILL_PATTERN #define INSERT_POW_OP_PATTERN(AtenOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_POW_OP_PATTERN(AtenPowTensorScalarOp); - INSERT_POW_OP_PATTERN(AtenPowTensorTensorOp); - INSERT_POW_OP_PATTERN(AtenPowScalarOp); + INSERT_POW_OP_PATTERN(AtenPowTensorScalarOp); + INSERT_POW_OP_PATTERN(AtenPowTensorTensorOp); + INSERT_POW_OP_PATTERN(AtenPowScalarOp); #undef INSERT_POW_OP_PATTERN +#define INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN(AtenOp) \ + illegalOps.insert(AtenOp::getOperationName()); \ + patterns.add>(typeConverter, context); + INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN(AtenUpsampleNearest2dOp); + INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN(AtenUpsampleNearest2dVecOp); +#undef INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN + #define INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenOp, TosaOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, \ context); - INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenTanhOp, tosa::TanhOp); - INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenSigmoidOp, tosa::SigmoidOp); - INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenErfOp, tosa::ErfOp); + INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenTanhOp, tosa::TanhOp); + INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenSigmoidOp, tosa::SigmoidOp); + INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenErfOp, tosa::ErfOp); #undef INSERT_ACTIVATION_FUNCITON_OP_PATTERN -#define INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN(AtenOp) \ - target.addIllegalOp(); \ - patterns.add>(typeConverter, context); - INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN(AtenUpsampleNearest2dOp); - INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN(AtenUpsampleNearest2dVecOp); -#undef INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN - #define INSERT_ATENOP_PATTERN(AtenOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_ATENOP_PATTERN(AtenHardtanhBackwardOp); - INSERT_ATENOP_PATTERN(AtenReluOp); - INSERT_ATENOP_PATTERN(AtenLeakyReluOp); - INSERT_ATENOP_PATTERN(AtenArgmaxOp); - INSERT_ATENOP_PATTERN(AtenRsubScalarOp); - INSERT_ATENOP_PATTERN(AtenConvolutionOp); - INSERT_ATENOP_PATTERN(ValueTensorLiteralOp); - INSERT_ATENOP_PATTERN(AtenReshapeOp); - INSERT_ATENOP_PATTERN(AtenBatchNormOp); - INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp); - INSERT_ATENOP_PATTERN(AtenFlattenUsingIntsOp); - INSERT_ATENOP_PATTERN(AtenUnflattenIntOp); - INSERT_ATENOP_PATTERN(AtenPermuteOp); - INSERT_ATENOP_PATTERN(AtenLog2Op); - INSERT_ATENOP_PATTERN(AtenThresholdOp); - INSERT_ATENOP_PATTERN(AtenUnsqueezeOp); - INSERT_ATENOP_PATTERN(AtenContiguousOp); - INSERT_ATENOP_PATTERN(AtenDropoutOp); - INSERT_ATENOP_PATTERN(AtenViewOp); - INSERT_ATENOP_PATTERN(AtenGeluOp); - INSERT_ATENOP_PATTERN(AtenGeluBackwardOp); - INSERT_ATENOP_PATTERN(AtenEmbeddingOp); - INSERT_ATENOP_PATTERN(AtenTransposeIntOp); - INSERT_ATENOP_PATTERN(AtenSliceTensorOp); - INSERT_ATENOP_PATTERN(AtenBroadcastToOp); - INSERT_ATENOP_PATTERN(AtenGatherOp); - INSERT_ATENOP_PATTERN(AtenIndexPutHackedTwinOp); - INSERT_ATENOP_PATTERN(AtenIndexTensorHackedTwinOp); - INSERT_ATENOP_PATTERN(AtenAbsOp); - INSERT_ATENOP_PATTERN(AtenWhereSelfOp); - INSERT_ATENOP_PATTERN(AtenClampOp); - INSERT_ATENOP_PATTERN(AtenArangeStartStepOp); - INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp); - INSERT_ATENOP_PATTERN(AtenCopyOp); - INSERT_ATENOP_PATTERN(AtenToDtypeOp); - INSERT_ATENOP_PATTERN(AtenConstantPadNdOp); - INSERT_ATENOP_PATTERN(AtenCatOp); - INSERT_ATENOP_PATTERN(AtenSqrtOp); - INSERT_ATENOP_PATTERN(AtenIscloseOp); - INSERT_ATENOP_PATTERN(Aten__InterpolateSizeListScaleListOp); - INSERT_ATENOP_PATTERN(AtenTrilOp); - INSERT_ATENOP_PATTERN(AtenDiagonalOp); - INSERT_ATENOP_PATTERN(AtenIndexSelectOp); - INSERT_ATENOP_PATTERN(AtenFlipOp); - INSERT_ATENOP_PATTERN(AtenRoundOp); - INSERT_ATENOP_PATTERN(AtenScatterSrcOp); - INSERT_ATENOP_PATTERN(AtenSliceScatterOp); - INSERT_ATENOP_PATTERN(AtenDiagEmbedOp); - INSERT_ATENOP_PATTERN(AtenUniformOp); - INSERT_ATENOP_PATTERN(AtenThresholdBackwardOp); - INSERT_ATENOP_PATTERN(AtenAsStridedOp); - INSERT_ATENOP_PATTERN(AtenClampTensorOp); - INSERT_ATENOP_PATTERN(PrimsCollapseOp); - INSERT_ATENOP_PATTERN(AtenReflectionPad1dOp); - INSERT_ATENOP_PATTERN(AtenReflectionPad2dOp); - INSERT_ATENOP_PATTERN(AtenReplicationPad2dOp); - INSERT_ATENOP_PATTERN(PrimsSplitDimOp); - INSERT_ATENOP_PATTERN(AtenOuterOp); - INSERT_ATENOP_PATTERN(AtenLogitOp); - INSERT_ATENOP_PATTERN(AtenLog1pOp); - INSERT_ATENOP_PATTERN(AtenLog10Op); - INSERT_ATENOP_PATTERN(AtenTanOp); + INSERT_ATENOP_PATTERN(AtenHardtanhBackwardOp); + INSERT_ATENOP_PATTERN(AtenReluOp); + INSERT_ATENOP_PATTERN(AtenLeakyReluOp); + INSERT_ATENOP_PATTERN(AtenArgmaxOp); + INSERT_ATENOP_PATTERN(AtenRsubScalarOp); + INSERT_ATENOP_PATTERN(AtenConvolutionOp); + INSERT_ATENOP_PATTERN(ValueTensorLiteralOp); + INSERT_ATENOP_PATTERN(AtenReshapeOp); + INSERT_ATENOP_PATTERN(AtenBatchNormOp); + INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp); + INSERT_ATENOP_PATTERN(AtenFlattenUsingIntsOp); + INSERT_ATENOP_PATTERN(AtenUnflattenIntOp); + INSERT_ATENOP_PATTERN(AtenPermuteOp); + INSERT_ATENOP_PATTERN(AtenLog2Op); + INSERT_ATENOP_PATTERN(AtenThresholdOp); + INSERT_ATENOP_PATTERN(AtenUnsqueezeOp); + INSERT_ATENOP_PATTERN(AtenContiguousOp); + INSERT_ATENOP_PATTERN(AtenDropoutOp); + INSERT_ATENOP_PATTERN(AtenViewOp); + INSERT_ATENOP_PATTERN(AtenGeluOp); + INSERT_ATENOP_PATTERN(AtenGeluBackwardOp); + INSERT_ATENOP_PATTERN(AtenEmbeddingOp); + INSERT_ATENOP_PATTERN(AtenTransposeIntOp); + INSERT_ATENOP_PATTERN(AtenSliceTensorOp); + INSERT_ATENOP_PATTERN(AtenBroadcastToOp); + INSERT_ATENOP_PATTERN(AtenGatherOp); + INSERT_ATENOP_PATTERN(AtenIndexPutHackedTwinOp); + INSERT_ATENOP_PATTERN(AtenIndexTensorHackedTwinOp); + INSERT_ATENOP_PATTERN(AtenAbsOp); + INSERT_ATENOP_PATTERN(AtenWhereSelfOp); + INSERT_ATENOP_PATTERN(AtenClampOp); + INSERT_ATENOP_PATTERN(AtenArangeStartStepOp); + INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp); + INSERT_ATENOP_PATTERN(AtenCopyOp); + INSERT_ATENOP_PATTERN(AtenToDtypeOp); + INSERT_ATENOP_PATTERN(AtenConstantPadNdOp); + INSERT_ATENOP_PATTERN(AtenCatOp); + INSERT_ATENOP_PATTERN(AtenSqrtOp); + INSERT_ATENOP_PATTERN(AtenIscloseOp); + INSERT_ATENOP_PATTERN(Aten__InterpolateSizeListScaleListOp); + INSERT_ATENOP_PATTERN(AtenTrilOp); + INSERT_ATENOP_PATTERN(AtenDiagonalOp); + INSERT_ATENOP_PATTERN(AtenIndexSelectOp); + INSERT_ATENOP_PATTERN(AtenFlipOp); + INSERT_ATENOP_PATTERN(AtenRoundOp); + INSERT_ATENOP_PATTERN(AtenScatterSrcOp); + INSERT_ATENOP_PATTERN(AtenSliceScatterOp); + INSERT_ATENOP_PATTERN(AtenDiagEmbedOp); + INSERT_ATENOP_PATTERN(AtenUniformOp); + INSERT_ATENOP_PATTERN(AtenThresholdBackwardOp); + INSERT_ATENOP_PATTERN(AtenAsStridedOp); + INSERT_ATENOP_PATTERN(AtenClampTensorOp); + INSERT_ATENOP_PATTERN(PrimsCollapseOp); + INSERT_ATENOP_PATTERN(AtenReflectionPad1dOp); + INSERT_ATENOP_PATTERN(AtenReflectionPad2dOp); + INSERT_ATENOP_PATTERN(AtenReplicationPad2dOp); + INSERT_ATENOP_PATTERN(PrimsSplitDimOp); + INSERT_ATENOP_PATTERN(AtenOuterOp); + INSERT_ATENOP_PATTERN(AtenLogitOp); + INSERT_ATENOP_PATTERN(AtenLog1pOp); + INSERT_ATENOP_PATTERN(AtenLog10Op); + INSERT_ATENOP_PATTERN(AtenTanOp); #undef INSERT_ATENOP_PATTERN #define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_CLONE_ATENOP_PATTERN(AtenCloneOp); + INSERT_CLONE_ATENOP_PATTERN(AtenCloneOp); #undef INSERT_CLONE_ATENOP_PATTERN - if (failed(applyPartialConversion(getOperation(), target, - std::move(patterns)))) - return signalPassFailure(); - } -}; -} // namespace + return illegalOps; +} std::unique_ptr> mlir::torch::createConvertTorchToTosaPass() { diff --git a/lib/Conversion/Utils/Utils.cpp b/lib/Conversion/Utils/Utils.cpp index e3f5b6d0299a..72217e5f4afd 100644 --- a/lib/Conversion/Utils/Utils.cpp +++ b/lib/Conversion/Utils/Utils.cpp @@ -447,6 +447,119 @@ Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc, return castIntToIndex(rewriter, loc, boundedByDimSize); } +// Helper function to unsqueeze the input tensor at given dim. +// Returns the unsqueezed tensor or failure. +FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op, + Value input, int64_t dim) { + auto inputType = cast(input.getType()); + int64_t inputRank = inputType.getRank(); + ArrayRef inputShape = inputType.getShape(); + + // `input` has a reduced rank. Hence add 1. + int64_t unsqueezedRank = inputShape.size() + 1; + dim = toPositiveDim(dim, unsqueezedRank); + if (!isValidDim(dim, unsqueezedRank)) { + return rewriter.notifyMatchFailure(op, "dim is not a valid dim"); + } + + SmallVector unsqueezedShape{inputShape}; + unsqueezedShape.insert(unsqueezedShape.begin() + dim, 1); + Type unsqueezedType = + RankedTensorType::get(unsqueezedShape, inputType.getElementType()); + + SmallVector reassociationMap(inputRank); + // From the perspective of the reassociation map, the situation of + // unsqueezing before or after the last dimension is symmetrical. + // Normalize it to the "before" case. + // The 0 case is special here, since there is no last dimension to insert + // before -- we simply rely on the loop below iterating 0 times. + if (dim == inputRank && inputRank != 0) + dim = inputRank - 1; + bool alreadyCrossedExpandedDim = false; + for (int i = 0; i != inputRank; i++) { + if (alreadyCrossedExpandedDim) { + reassociationMap[i].push_back(i + 1); + } else { + reassociationMap[i].push_back(i); + if (i == dim) { + reassociationMap[i].push_back(i + 1); + alreadyCrossedExpandedDim = true; + } + } + } + Value unsqueezed = rewriter.create( + op->getLoc(), unsqueezedType, input, reassociationMap); + return unsqueezed; +} + +// Helper function to squeeze the input tensor at given dim. +// Returns the squeezed tensor or failure. +FailureOr squeezeTensor(PatternRewriter &rewriter, Operation *op, + Value input, int64_t dim) { + Location loc = op->getLoc(); + auto inputType = cast(input.getType()); + int64_t inputRank = inputType.getRank(); + + // No scope for squeezing the input. + if (inputRank == 0) + return input; + + dim = toPositiveDim(dim, inputRank); + if (!isValidDim(dim, inputRank)) + return rewriter.notifyMatchFailure(op, "dim is statically invalid"); + + // assert dynamic squeeze dim size == 1 + if (inputType.isDynamicDim(dim)) { + Value cstDim = rewriter.create(loc, dim); + Value dimVal = rewriter.create(loc, input, cstDim); + Value cstOne = rewriter.create(loc, 1); + Value cmp = rewriter.create(loc, arith::CmpIPredicate::eq, + dimVal, cstOne); + rewriter.create( + loc, cmp, + rewriter.getStringAttr( + "Expected dynamic squeeze dim size to be statically 1")); + } + + ArrayRef inputShape = inputType.getShape(); + SmallVector squeezedShape; + squeezedShape.append(inputShape.begin(), inputShape.begin() + dim); + squeezedShape.append(inputShape.begin() + dim + 1, inputShape.end()); + int64_t squeezedRank = inputRank - 1; + Type squeezedType = + RankedTensorType::get(squeezedShape, inputType.getElementType()); + + // If the dim(th) dimension of operand tensor type is not statically unit, + // squeeze will behave as an identity operation. + if (inputType.getDimSize(dim) != 1 && !inputType.isDynamicDim(dim)) { + return input; + } + + SmallVector reassociationMap(squeezedRank); + bool alreadyCrossedSqueezedDim = false; + for (int i = 0; i != squeezedRank; i++) { + if (alreadyCrossedSqueezedDim) { + reassociationMap[i].push_back(i + 1); + } else { + reassociationMap[i].push_back(i); + if (dim != 0 && i != dim - 1) + continue; + + alreadyCrossedSqueezedDim = true; + if (dim == 0) + reassociationMap[0].push_back(1); + if (i == dim - 1) + reassociationMap[i].push_back(dim); + } + } + // Note: In case the operand tensor type is of unit rank and is statically + // shaped with unit dimension, the `reassociationMap` will be empty and the + // input will be collapsed to a 0-D tensor. + Value squeezed = rewriter.create( + op->getLoc(), squeezedType, input, reassociationMap); + return squeezed; +} + } // namespace Torch } // namespace torch } // namespace mlir diff --git a/lib/InitAll.cpp b/lib/InitAll.cpp index 79347fbeb0d6..d9d7ef1a0cd4 100644 --- a/lib/InitAll.cpp +++ b/lib/InitAll.cpp @@ -53,8 +53,8 @@ void mlir::torch::registerOptionalInputDialects( mlir::DialectRegistry ®istry) { registry.insert(); + scf::SCFDialect, sparse_tensor::SparseTensorDialect, + tensor::TensorDialect, tosa::TosaDialect>(); } void mlir::torch::registerAllPasses() { diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index e85d0c719aef..cc1ca5a04110 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -417,7 +417,6 @@ "CeilFloatModule_basic", "ContainsIntList_False", "ContainsIntList_True", - "Conv1dNoPaddingGroupModule_basic", "ConvTbcModule_basic", "ConvolutionBackwardModule2DPadded_basic", "ConvolutionBackwardModule2DStrided_basic", @@ -2826,6 +2825,7 @@ "ElementwiseBitwiseAndScalarInt64Module_basic", "ElementwiseBitwiseAndScalarInt32Module_basic", "ElementwiseBitwiseAndScalarInt8Module_basic", + "Conv1dGroupModule_basic", "Conv2dQInt8Module_basic", "Conv2dQInt8Module_depthwise", "Conv2dQInt8Module_grouped", @@ -2981,6 +2981,7 @@ "Conv1dModule_basic", "Conv1dWithSamePaddingModule_basic", "Conv1dWithValidPaddingModule_basic", + "Conv1dGroupModule_basic", "Conv2dBiasNoPaddingModule_basic", "Conv2dModule_basic", "Conv2dNoPaddingModule_basic", @@ -3450,7 +3451,6 @@ "Unfold_Module_Rank_Zero_Size_Zero_basic", "Unfold_Module_Dynamic_basic", "ViewDtypeStaticModule_basic", - "Conv1dNoPaddingGroupModule_basic", "IndexPutImpl2DNoneIndexBroadcastStaticModule_basic", "PowIntIntModule_basic", "PrimsSumFloatModule_basic", @@ -3695,6 +3695,7 @@ "Conv2dWithPaddingModule_basic", "Conv1dWithSamePaddingModule_basic", "Conv1dWithValidPaddingModule_basic", + "Conv1dGroupModule_basic", "Conv2dQInt8Module_basic", "Conv2dQInt8Module_depthwise", "Conv2dQInt8Module_grouped", @@ -4380,6 +4381,7 @@ "Conv1dWithSamePaddingModule_basic", "Conv1dWithValidPaddingModule_basic", "Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic", + "Conv1dGroupModule_basic", "Conv2dBiasNoPaddingModule_basic", "Conv2dModule_basic", "Conv2dNoPaddingModule_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py index adc37393f393..7cf6226890f6 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py @@ -1316,6 +1316,33 @@ def Conv1dWithValidPaddingModule_basic(module, tu: TestUtils): module.forward(inputVec, weight, bias) +class Conv1dGroupModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ([-1], torch.float32, True), + ] + ) + def forward(self, inputVec, weight, bias): + return torch.ops.aten.conv1d( + inputVec, weight, bias=bias, stride=[1], padding=[0], dilation=[1], groups=2 + ) + + +@register_test_case(module_factory=lambda: Conv1dGroupModule()) +def Conv1dGroupModule_basic(module, tu: TestUtils): + inputVec = tu.rand(2, 4, 6) + weight = torch.randn(8, 2, 3) + bias = torch.randn(8) + module.forward(inputVec, weight, bias) + + class Conv2dModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 5a5fb83d5fc0..7f1e63d83ccd 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -1580,12 +1580,14 @@ func.func @test_nllloss_iii_reduction_none_ignore_negative(%arg0: !torch.vtensor // ----- -// CHECK-LABEL: func.func @test_nonzero - func.func @test_nonzero(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: torch.aten.nonzero %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],si64> - %0 = torch.operator "onnx.NonZero"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],si64> - return %0 : !torch.vtensor<[3,4,5],si64> - } +func.func @test_nonzero(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[1,?],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[ZERO:.*]] = torch.constant.int 0 + // CHECK: %[[ONE:.*]] = torch.constant.int 1 + // CHECK: %[[NONZERO:.*]] = torch.aten.nonzero %arg0 : !torch.vtensor<[?],f32> -> !torch.vtensor<[?,1],si64> + // CHECK: %[[TRANSPOSE:.*]] = torch.aten.transpose.int %[[NONZERO]], %[[ZERO]], %[[ONE]] : !torch.vtensor<[?,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,?],si64> + %0 = torch.operator "onnx.NonZero"(%arg0) : (!torch.vtensor<[?],f32>) -> !torch.vtensor<[1,?],si64> + return %0 : !torch.vtensor<[1,?],si64> +} // -----