Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AutoBump] Merge with fixes of f03a5762 (Dec 12) (136) #526

Merged
merged 11 commits into from
Feb 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,25 @@

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

#include <memory>

namespace mlir {
namespace torch {

/// Collect a set of legal/illegal ops for converting Torch operations to Tosa
/// dialect.
void populateTorchToTosaConversionLegalOps(ConversionTarget &target);

/// Collect a set of patterns to convert Torch operations to Tosa dialect +
/// return the set of illegalOps
std::set<StringRef>
populateTorchToTosaConversionPatternsAndIllegalOps(TypeConverter &typeConverter,
RewritePatternSet &patterns);

std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTosaPass();
}
} // namespace torch
} // namespace mlir

#endif // TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H
9 changes: 9 additions & 0 deletions include/torch-mlir/Conversion/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,15 @@ Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc,
Value torchOptionalInt, Value builtinInt,
Value defaultValue, Value dimSize);

// Helper function to unsqueeze the input tensor at given dim.
// Returns the unsqueezed tensor or failure.
FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
Value input, int64_t dim);

// Helper function to squeeze the input tensor at given dim.
// Returns the squeezed tensor or failure.
FailureOr<Value> squeezeTensor(PatternRewriter &rewriter, Operation *op,
Value input, int64_t dim);
} // namespace Torch
} // namespace torch
} // namespace mlir
Expand Down
41 changes: 29 additions & 12 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1093,18 +1093,35 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
rewriter.replaceOp(binder.op, nllLoss);
return success();
});
patterns.onOp("NonZero", 13,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value operand;
if (binder.tensorOperand(operand) ||
binder.tensorResultType(resultType)) {
return failure();
}
rewriter.replaceOpWithNewOp<Torch::AtenNonzeroOp>(
binder.op, resultType, operand);
return success();
});
patterns.onOp(
    "NonZero", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Torch::ValueTensorType resultType;
      Value operand;
      if (binder.tensorOperand(operand) ||
          binder.tensorResultType(resultType)) {
        return failure();
      }
      // Dimension indices 0 and 1 for the transpose emitted below.
      Value zero = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
      Value one = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1));
      // Build the aten.nonzero result type with the ONNX result shape
      // reversed, since aten.nonzero's layout is the transpose of ONNX's
      // (see comment below).
      auto rawSize = resultType.getSizes();
      SmallVector<int64_t> torchResultSize(rawSize.rbegin(), rawSize.rend());
      auto torchResultType = rewriter.getType<Torch::ValueTensorType>(
          torchResultSize, resultType.getDtype());
      auto nonZero = rewriter.create<Torch::AtenNonzeroOp>(
          binder.getLoc(), torchResultType, operand);
      // The ONNX output tensor has shape (n, z), where n is the number of
      // dimensions in the input tensor and z is the number of non-zero
      // elements. This is the transpose of PyTorch's default behavior,
      // where aten.nonzero produces shape (z, n), so swap dims 0 and 1 to
      // match the expected ONNX result type.
      rewriter.replaceOpWithNewOp<Torch::AtenTransposeIntOp>(
          binder.op, resultType, nonZero, zero, one);
      return success();
    });
patterns.onOp(
"MaxPool", 12, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
std::string autoPad;
Expand Down
98 changes: 13 additions & 85 deletions lib/Conversion/TorchToLinalg/DataMovement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1642,69 +1642,18 @@ class ConvertAtenSqueezeDimOp : public OpConversionPattern<AtenSqueezeDimOp> {
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Value input = adaptor.getSelf();
auto inputType = cast<RankedTensorType>(input.getType());
int64_t inputRank = inputType.getRank();

if (inputRank == 0) {
return rewriter.notifyMatchFailure(
op, "zero input rank should have been handled by the folder");
}

int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(op, "dim must be constant");
dim = toPositiveDim(dim, inputRank);
if (!isValidDim(dim, inputRank))
return rewriter.notifyMatchFailure(op, "dim is statically invalid");

// assert dynamic squeeze dim size == 1
if (inputType.isDynamicDim(dim)) {
Value cstDim = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), dim);
Value dimVal = rewriter.create<tensor::DimOp>(op.getLoc(), input, cstDim);
Value cstOne = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 1);
Value cmp = rewriter.create<arith::CmpIOp>(
op.getLoc(), arith::CmpIPredicate::eq, dimVal, cstOne);
rewriter.create<cf::AssertOp>(
op.getLoc(), cmp,
rewriter.getStringAttr(
"Expected dynamic squeeze dim size to be statically 1"));
}

const TypeConverter *typeConverter = getTypeConverter();
auto resultType =
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
int64_t resultRank = resultType.getRank();

// If the dim(th) dimension of operand tensor type is not statically unit,
// `aten.squeeze` will behave as an identity operation.
if (inputType.getDimSize(dim) != 1 && !inputType.isDynamicDim(dim)) {
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, input);
return success();
auto squeezeTensorInfo =
squeezeTensor(rewriter, op, adaptor.getSelf(), dim);
if (failed(squeezeTensorInfo)) {
return rewriter.notifyMatchFailure(op,
"cannot generate unsqueeze tensor");
}

SmallVector<ReassociationIndices> reassociationMap(resultRank);
bool alreadyCrossedSqueezedDim = false;
for (int i = 0; i != resultRank; i++) {
if (alreadyCrossedSqueezedDim) {
reassociationMap[i].push_back(i + 1);
} else {
reassociationMap[i].push_back(i);
if (dim != 0 && i != dim - 1)
continue;

alreadyCrossedSqueezedDim = true;
if (dim == 0)
reassociationMap[0].push_back(1);
if (i == dim - 1)
reassociationMap[i].push_back(dim);
}
}
// Note: In case the operand tensor type is of unit rank and is statically
// shaped with unit dimension, the `reassociationMap` will be empty and the
// input will be collapsed to a 0-D tensor.
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(op, resultType, input,
reassociationMap);
rewriter.replaceOp(op, squeezeTensorInfo.value());
return success();
}
};
Expand All @@ -1722,36 +1671,15 @@ class ConvertAtenUnsqueezeOp : public OpConversionPattern<AtenUnsqueezeOp> {
int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(op, "dim must be constant");
auto inputRank =
cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
dim = toPositiveDim(dim, inputRank + 1);
if (!isValidDim(dim, inputRank + 1))
return rewriter.notifyMatchFailure(op, "dim is statically invalid");

SmallVector<ReassociationIndices> reassociationMap(inputRank);
// From the perspective of the reassociation map, the situation of
// unsqueezing before or after the last dimension is symmetrical.
// Normalize it to the "before" case.
// The 0 case is special here, since there is no last dimension to insert
// before -- we simply rely on the loop below iterating 0 times.
if (dim == inputRank && inputRank != 0)
dim = inputRank - 1;
bool alreadyCrossedExpandedDim = false;
for (int i = 0; i != inputRank; i++) {
if (alreadyCrossedExpandedDim) {
reassociationMap[i].push_back(i + 1);
} else {
reassociationMap[i].push_back(i);
if (i == dim) {
reassociationMap[i].push_back(i + 1);
alreadyCrossedExpandedDim = true;
}
}
auto unsqueezeTensorInfo =
unsqueezeTensor(rewriter, op, adaptor.getSelf(), dim);
if (failed(unsqueezeTensorInfo)) {
return rewriter.notifyMatchFailure(op,
"cannot generate unsqueeze tensor");
}
auto resultType = cast<RankedTensorType>(
getTypeConverter()->convertType(op->getResult(0).getType()));
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
op, resultType, adaptor.getSelf(), reassociationMap);

rewriter.replaceOp(op, unsqueezeTensorInfo.value());
return success();
}
};
Expand Down
72 changes: 64 additions & 8 deletions lib/Conversion/TorchToLinalg/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,48 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
return rewriter.notifyMatchFailure(op,
"only support constant int dilations");

// Checks for valid group size
int64_t numGroups;
if (!matchPattern(op.getGroups(), m_TorchConstantInt(&numGroups)))
return rewriter.notifyMatchFailure(op,
"only constant group size supported.");
Value groups = castIntToIndex(rewriter, loc, adaptor.getGroups());

// Adding support for 1d group convolution by converting the 1d-conv to
// 2d-conv.
// TODO: Replace this logic with the appropriate linalg op for 1-d group
// convolution once that support is added.
bool is1DGroupConv = (numSpatialDims == 1 && numGroups != 1);
if (is1DGroupConv) {
// Unsqueezing the last dim of input and weight. Also extending the
// dilation, stride, padding, and output padding lists.
auto unsqueezeInputInfo =
unsqueezeTensor(rewriter, op, input, /*dim=*/-1);
if (failed(unsqueezeInputInfo)) {
return rewriter.notifyMatchFailure(op,
"cannot generate unsqueeze tensor");
}
input = unsqueezeInputInfo.value();

auto unsqueezeWeightInfo =
unsqueezeTensor(rewriter, op, weight, /*dim=*/-1);
if (failed(unsqueezeWeightInfo)) {
return rewriter.notifyMatchFailure(op,
"cannot generate unsqueeze tensor");
}
weight = unsqueezeWeightInfo.value();

Value cstZero = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI64IntegerAttr(0));
paddingIntValues.push_back(cstZero);
outputPaddingIntValues.push_back(cstZero);
strideInts.push_back(1);
dilationInts.push_back(1);

inRank++;
numSpatialDims++;
}

Value inBatch = getDimOp(rewriter, loc, input, 0);
Value inChannels = getDimOp(rewriter, loc, input, 1);
SmallVector<Value> inDims;
Expand All @@ -861,13 +903,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
for (size_t i = 2; i < inRank; i++)
weightDims.push_back(getDimOp(rewriter, loc, weight, i));

// Checks for valid group size
int64_t numGroups;
if (!matchPattern(op.getGroups(), m_TorchConstantInt(&numGroups)))
return rewriter.notifyMatchFailure(op,
"only constant group size supported.");
Value groups = castIntToIndex(rewriter, loc, adaptor.getGroups());

auto validate = [&](Value toValidate, std::string err) {
Value c0 =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
Expand Down Expand Up @@ -1286,13 +1321,24 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
resultElementType);
}

if (is1DGroupConv) {
// Squeezing the last dim of the result of conv.
auto squeezeOutputInfo = squeezeTensor(rewriter, op, conv, /*dim=*/-1);
if (failed(squeezeOutputInfo)) {
return rewriter.notifyMatchFailure(op,
"cannot generate squeeze tensor");
}
conv = squeezeOutputInfo.value();
}

rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
return success();
}

if (numSpatialDims != 2)
return rewriter.notifyMatchFailure(
op, "unimplemented: only 2D grouped convolution supported");
op, "unimplemented: only 1D and 2D grouped convolution supported");

// Grouped case, use the grouped conv linalg op
auto expandGroups = [&](Value tensor, size_t dim) {
Expand Down Expand Up @@ -1377,6 +1423,16 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
resultElementType);
}

if (is1DGroupConv) {
// Squeezing the last dim of the result of conv.
auto squeezeOutputInfo = squeezeTensor(rewriter, op, conv, /*dim=*/-1);
if (failed(squeezeOutputInfo)) {
return rewriter.notifyMatchFailure(op,
"cannot generate squeeze tensor");
}
conv = squeezeOutputInfo.value();
}
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
return success();
}
Expand Down
Loading
Loading