fix(ONNX): avoids resizing unsupported dimensions #3945

Open · wants to merge 48 commits into base: main
48 commits
84a4146
refactor(ONNX): prefers assignment to precede first usage in onnx.resize
bjacobgordon Jan 8, 2025
f2c3e16
refactor(ONNX): extracts `loc` within onnx.resize
bjacobgordon Jan 9, 2025
15a0eb5
refactor(ONNX): moves `rank` closer to first usage in onnx.resize
bjacobgordon Jan 10, 2025
5f4768f
refactor(ONNX): forces cast of operand in onnx.resize
bjacobgordon Jan 8, 2025
a46d47d
refactor(ONNX): loosens downcast in onnx.resize
bjacobgordon Jan 15, 2025
146b53a
refactor(ONNX): extracts `inputTensor` within onnx.resize
bjacobgordon Jan 8, 2025
2feea14
refactor(ONNX): extracts `typeOfInputTensor` from rank derivation in …
bjacobgordon Jan 8, 2025
58e3b1a
refactor(ONNX): extracts `sizesOfInputTensor` from rank derivation in…
bjacobgordon Jan 8, 2025
262b944
refactor(ONNX): renames `rank` to `rankOfInputTensor` in onnx.resize
bjacobgordon Jan 10, 2025
be47c8d
refactor(ONNX): renames `resultType` to `typeOfOutputTensor`
bjacobgordon Jan 8, 2025
deaaa4b
refactor(ONNX): enforces declaration-usage adjacency for value lists …
bjacobgordon Jan 14, 2025
3f23816
refactor(ONNX): removes redundant nulling assignment in onnx.resize
bjacobgordon Jan 14, 2025
70627dd
refactor(ONNX): renames `sizesValueList` to `filteredSizesAsOp` in on…
bjacobgordon Jan 15, 2025
59983ca
refactor(ONNX): renames `scalesValueList` to `filteredScaleFactorsAsO…
bjacobgordon Jan 15, 2025
68bae32
refactor(ONNX): renames transform helper method
bjacobgordon Jan 14, 2025
d63d92a
refactor(ONNX): renames `scaleOperand` to `proposedScaleFactorsAsOp` …
bjacobgordon Jan 14, 2025
71c804e
refactor(ONNX): renames `sizeOperand` to `proposedSizesAsOp` in onnx.…
bjacobgordon Jan 14, 2025
6d58e68
refactor(ONNX): enforces declaration-usage adjacency for `context` in…
bjacobgordon Jan 14, 2025
c0da0b0
refactor(ONNX): enforces declaration-usage adjacency for `itemList` i…
bjacobgordon Jan 14, 2025
87045db
refactor(ONNX): enforces declaration-usage adjacency for `zero` in tr…
bjacobgordon Jan 14, 2025
4565928
refactor(ONNX): extracts `loc` within transforms filter
bjacobgordon Jan 14, 2025
21d0d48
refactor(ONNX): simplifies zero op in transforms filter
bjacobgordon Jan 15, 2025
8340beb
refactor(ONNX): extracts `typeOfEveryFilteredTransformation` within t…
bjacobgordon Jan 14, 2025
7a2f746
refactor(ONNX): consolidates duplicate conditional fragments
bjacobgordon Jan 14, 2025
ed705df
refactor(ONNX): prefers direct return over declaration and assignment
bjacobgordon Jan 14, 2025
8ff7d8f
refactor(ONNX): dissolves `extract` into sole call site within transf…
bjacobgordon Jan 16, 2025
f477cd8
refactor(ONNX): mimics conditional structure in transforms filter
bjacobgordon Jan 14, 2025
e240129
refactor(ONNX): prefers braces around conditional block
bjacobgordon Jan 16, 2025
7cbf317
refactor(ONNX): removes redundant `xTy` declaration in transforms filter
bjacobgordon Jan 15, 2025
3604d8d
refactor(ONNX): merges identical conditional structures
bjacobgordon Jan 14, 2025
2e89f93
refactor(ONNX): leverages `operandType` when declaring `sizes` in tra…
bjacobgordon Jan 14, 2025
c2e9805
refactor(ONNX): inlines size when getting type of selection in transf…
bjacobgordon Jan 14, 2025
acbbb37
refactor(ONNX): simplifies `selectIndex` op in transforms filter
bjacobgordon Jan 16, 2025
41b1657
refactor(ONNX): removes redundant `auto` annotation in transforms filter
bjacobgordon Jan 15, 2025
87d6082
refactor(ONNX): captures magic number in transforms filter
bjacobgordon Jan 14, 2025
ea4e45c
refactor(ONNX): extracts `numberOfTransformableDimensions` within tra…
bjacobgordon Jan 14, 2025
357fccd
refactor(ONNX): renames `sizes` to `sizesOfTransformationVector`
bjacobgordon Jan 14, 2025
b5a743a
refactor(ONNX): renames `operandType` to `typeOfTransformationVector`…
bjacobgordon Jan 14, 2025
9728f97
refactor(ONNX): renames `operand` to `givenTransformationVector` in t…
bjacobgordon Jan 14, 2025
361fd35
refactor(ONNX): renames `i` to `eachDimension` in transforms filter
bjacobgordon Jan 21, 2025
c2bd963
refactor(ONNX): renames `selectIndex` to `eachDimensionAsOp` in trans…
bjacobgordon Jan 14, 2025
cecabb9
refactor(ONNX): renames `itemList` to `filteredTransformations` in tr…
bjacobgordon Jan 14, 2025
492151f
refactor(ONNX): renames `item` to `eachTransformation` in transforms …
bjacobgordon Jan 14, 2025
c93d81b
refactor(ONNX): renames `ext` to `selectionFromEachTransformationVect…
bjacobgordon Jan 14, 2025
0a9f919
refactor(ONNX): renames `selectResultType` to `typeOfSelectionFromTra…
bjacobgordon Jan 14, 2025
0526142
refactor(ONNX): renames `extractTy` to `typeOfEveryTransformation` in…
bjacobgordon Jan 14, 2025
f6e3d3d
refactor(ONNX): renames `zero` to `zeroAsOp` in transforms filter
bjacobgordon Jan 15, 2025
a3daf2a
fix(ONNX): avoids resizing unsupported dimensions
bjacobgordon Jan 15, 2025
288 changes: 215 additions & 73 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -180,53 +180,126 @@ LogicalResult reduceOpImpl(OpBinder binder, ConversionPatternRewriter &rewriter,
return success();
}

Value getValueList(OpBinder binder, ConversionPatternRewriter &rewriter,
Value operand) {
SmallVector<Value> itemList;
auto sizes = dyn_cast<Torch::ValueTensorType>(operand.getType()).getSizes();
Torch::BaseTensorType operandType =
cast<Torch::BaseTensorType>(operand.getType());

SmallVector<int64_t> selectSizes;
selectSizes.push_back(1);
Type selectResultType = operandType.getWithSizesAndDtype(
llvm::ArrayRef(selectSizes), operandType.getOptionalDtype());

auto extract = [&rewriter, &binder](Value x, Value v) {
auto xTy = cast<Torch::ValueTensorType>(x.getType());
Type extractTy = rewriter.getType<Torch::FloatType>();
if (isa<IntegerType>(xTy.getDtype()))
extractTy = rewriter.getType<Torch::IntType>();

return rewriter.create<Torch::AtenItemOp>(binder.getLoc(), extractTy, v);
};

Value zero = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
Value scaleIdentityComparisonOpForFactorAtDimensionIn(
Value givenScaleFactors, int64_t givenDimension, OpBinder binder,
ConversionPatternRewriter &rewriter) {
auto typeOfScaleFactors =
cast<Torch::BaseTensorType>(givenScaleFactors.getType());

Type typeOfSelectionFromScaleFactors =
typeOfScaleFactors.getWithSizesAndDtype(
ArrayRef<int64_t>{1}, typeOfScaleFactors.getOptionalDtype());

auto loc = binder.getLoc();

Value zeroAsOp =
rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));

Value scaleIdentityAsOp = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getF64FloatAttr(1.0));

Value givenDimensionAsOp = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(givenDimension));

Type typeOfScaleFactor = rewriter.getType<Torch::FloatType>();

Value selectionFromScaleFactorsAsOp = rewriter.create<Torch::AtenSelectIntOp>(
loc, typeOfSelectionFromScaleFactors, givenScaleFactors, zeroAsOp,
givenDimensionAsOp);

Value scaleFactorAsOp = rewriter.create<Torch::AtenItemOp>(
loc, typeOfScaleFactor, selectionFromScaleFactorsAsOp);

Type typeOfComparisonResult = rewriter.getType<Torch::BoolType>();

return rewriter.create<Torch::AtenEqFloatOp>(
loc, typeOfComparisonResult, scaleFactorAsOp, scaleIdentityAsOp);
}

Value originalSizeComparisonOpForSizeAtDimensionIn(
Value givenTargetSizes, Value givenOriginalTensor, int64_t givenDimension,
OpBinder binder, ConversionPatternRewriter &rewriter) {
auto typeOfTargetSizes =
cast<Torch::BaseTensorType>(givenTargetSizes.getType());

Type typeOfSelectionFromTargetSizes = typeOfTargetSizes.getWithSizesAndDtype(
ArrayRef<int64_t>{1}, typeOfTargetSizes.getOptionalDtype());

auto loc = binder.getLoc();

Value zeroAsOp =
rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));

Type typeOfTargetSize = rewriter.getType<Torch::IntType>();

Value givenDimensionAsOp = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(givenDimension));

Value selectionFromTargetSizesAsOp = rewriter.create<Torch::AtenSelectIntOp>(
loc, typeOfSelectionFromTargetSizes, givenTargetSizes, zeroAsOp,
givenDimensionAsOp);

Value targetSizeAsOp = rewriter.create<Torch::AtenItemOp>(
loc, typeOfTargetSize, selectionFromTargetSizesAsOp);

Value originalSizeAsOp = rewriter.create<Torch::AtenSizeIntOp>(
loc, givenOriginalTensor, givenDimensionAsOp);

Type typeOfComparisonResult = rewriter.getType<Torch::BoolType>();

return rewriter.create<Torch::AtenEqIntOp>(loc, typeOfComparisonResult,
targetSizeAsOp, originalSizeAsOp);
}
Comment on lines +183 to +252 (Collaborator):
medium structural: These two helper functions have quite a bit of overlap. If you wanted to cut out a bit of code duplication, I'd recommend passing the extractIndex and expected value (constant 1.0 for "scales" and the size op result for "sizes") as an input, and determining which of AtenEqIntOp or AtenEqFloatOp is appropriate based on the sourceValue.getDtype().

This would, in my opinion, also make the logic a bit more transparent from within the resize conversion. E.g. if you called the helper function something like extractAndCompare, the assert creation loops might look something like:

// in the sizes case
for (const auto &dim : nonResizableDims) {
    Value cstDim = rewriter.create<Torch::ConstantIntOp>(loc, dim);
    Value expectedSize = rewriter.create<AtenSizeIntOp>(loc, inputTensor, cstDim,...);
    Value comparison = extractAndCompare(sizesTensor, /*extractIndex=*/dim, expectedSize, ...);
    // create assert ops
}

This change would make it easier to see which two things are being compared from the scope of the resize conversion, without needing to look carefully through each helper function individually. 
...

// for scales
Value expectedScale = rewriter.create<Torch::ConstantFloatOp>(...); // constant one float.
for (const auto &dim : nonResizableDims) {
    Value comparison = extractAndCompare(scalesTensor, /*extractIndex=*/dim, expectedScale,...);
    // create assert op
}
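
As a minimal sketch of that consolidation, reusing only ops already present in this patch (the name extractAndCompare and the exact parameter list are illustrative assumptions, not part of the PR):

Value extractAndCompare(Value sourceTensor, int64_t extractIndex, Value expected,
                        OpBinder binder, ConversionPatternRewriter &rewriter) {
  // Select element `extractIndex` from the rank-1 source tensor and unwrap it.
  auto loc = binder.getLoc();
  auto sourceType = cast<Torch::BaseTensorType>(sourceTensor.getType());
  Type selectionType = sourceType.getWithSizesAndDtype(
      ArrayRef<int64_t>{1}, sourceType.getOptionalDtype());

  Value zero = rewriter.create<Torch::ConstantIntOp>(
      loc, rewriter.getI64IntegerAttr(0));
  Value index = rewriter.create<Torch::ConstantIntOp>(
      loc, rewriter.getI64IntegerAttr(extractIndex));
  Value selection = rewriter.create<Torch::AtenSelectIntOp>(
      loc, selectionType, sourceTensor, zero, index);

  // Pick the scalar type and comparison op from the source dtype.
  Type boolType = rewriter.getType<Torch::BoolType>();
  if (isa<IntegerType>(sourceType.getDtype())) {
    Value item = rewriter.create<Torch::AtenItemOp>(
        loc, rewriter.getType<Torch::IntType>(), selection);
    return rewriter.create<Torch::AtenEqIntOp>(loc, boolType, item, expected);
  }
  Value item = rewriter.create<Torch::AtenItemOp>(
      loc, rewriter.getType<Torch::FloatType>(), selection);
  return rewriter.create<Torch::AtenEqFloatOp>(loc, boolType, item, expected);
}

With the dtype dispatch inside the helper, the sizes and scales assert loops would differ only in the expected value they construct.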


Value withUnsupportedDimensionsFilteredOut(
Value givenTransformationVector, OpBinder binder,
ConversionPatternRewriter &rewriter) {
Comment on lines +254 to +256 (Collaborator):
important nit: This name is a bit inscrutable to me. Perhaps something like getSupportedResizeValueList is of a similar length, but a bit more direct? Not sold on my own suggestion, but I definitely think this name should be modified.

Reply (Collaborator):
I think a later comment addresses this a bit more, since this function doesn't need to be specific to the resize op.

auto typeOfTransformationVector =
Comment (Collaborator):
medium nit:

I don't particularly like transformation as a generic replacement for size || scale, and I would like the variable names here to be simplified considerably.

The external purpose of the first input isn't useful for understanding what this function does in isolation. You could call this input sourceTensor or sourceVector since all we are doing here is extracting values from this source and assembling them into a list construct op.

This function might also end up being useful elsewhere, and no one would be able to discern that through the overly-specific naming.

cast<Torch::BaseTensorType>(givenTransformationVector.getType());
auto sizesOfTransformationVector = typeOfTransformationVector.getSizes();
auto numberOfTransformableDimensions = sizesOfTransformationVector[0];

Type typeOfSelectionFromTransformationVector =
typeOfTransformationVector.getWithSizesAndDtype(
ArrayRef<int64_t>{1}, typeOfTransformationVector.getOptionalDtype());

auto loc = binder.getLoc();

Value zeroAsOp =
rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));

Type typeOfEveryTransformation;
Type typeOfEveryFilteredTransformation;
MLIRContext *context = binder.op->getContext();
for (int i = 2; i < sizes[0]; i++) {
Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
Value ext = rewriter.create<Torch::AtenSelectIntOp>(
binder.getLoc(), selectResultType, operand, zero, selectIndex);
Value item = extract(operand, ext);
itemList.push_back(item);
}
auto xTy = cast<Torch::ValueTensorType>(operand.getType());
Value ValueList;
if (isa<IntegerType>(xTy.getDtype())) {
ValueList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(), Torch::ListType::get(Torch::IntType::get(context)),
itemList);

if (isa<IntegerType>(typeOfTransformationVector.getDtype())) {
typeOfEveryTransformation = rewriter.getType<Torch::IntType>();
typeOfEveryFilteredTransformation = Torch::IntType::get(context);
} else {
ValueList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(), Torch::ListType::get(Torch::FloatType::get(context)),
itemList);
typeOfEveryTransformation = rewriter.getType<Torch::FloatType>();
typeOfEveryFilteredTransformation = Torch::FloatType::get(context);
}

SmallVector<Value> filteredTransformations;

auto numberOfUnsupportedDimensionsFromFront = 2;
Comment (Collaborator):
Since it might improve the extensibility of this function, you could perhaps allow passing this starting index for value extraction as an input to the function.

With something like this, a name like

Value convert1DTensorToList(
          Value sourceTensor, 
          int64_t startIndexForExtract, ...)

would make for a rather useful function in many places (maybe warranting a move to Utils).

Reply (Collaborator):
In fact, this is likely useful in many places in TorchOnnxToTorch because ONNX doesn't have lists of scalars (so they use rank-1 tensors), and many torch ops take lists of scalars as inputs.
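
A rough sketch of that generalized helper, assuming it keeps the extraction pattern already used in this file (the name convert1DTensorToList and the startIndex parameter come from the suggestion above; everything else is illustrative):

Value convert1DTensorToList(Value sourceTensor, int64_t startIndex,
                            OpBinder binder,
                            ConversionPatternRewriter &rewriter) {
  auto loc = binder.getLoc();
  auto sourceType = cast<Torch::BaseTensorType>(sourceTensor.getType());
  int64_t numElements = sourceType.getSizes()[0];

  Type selectionType = sourceType.getWithSizesAndDtype(
      ArrayRef<int64_t>{1}, sourceType.getOptionalDtype());

  // ONNX encodes scalar lists as rank-1 tensors, so unwrap each element.
  Type itemType;
  if (isa<IntegerType>(sourceType.getDtype()))
    itemType = rewriter.getType<Torch::IntType>();
  else
    itemType = rewriter.getType<Torch::FloatType>();

  Value zero = rewriter.create<Torch::ConstantIntOp>(
      loc, rewriter.getI64IntegerAttr(0));

  SmallVector<Value> items;
  for (int64_t i = startIndex; i < numElements; i++) {
    Value index = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(i));
    Value selection = rewriter.create<Torch::AtenSelectIntOp>(
        loc, selectionType, sourceTensor, zero, index);
    items.push_back(
        rewriter.create<Torch::AtenItemOp>(loc, itemType, selection));
  }
  return rewriter.create<Torch::PrimListConstructOp>(
      loc, Torch::ListType::get(itemType), items);
}

As far as the diff shows, both existing call sites (Resize and Upsample) skip the first two entries, so each would pass startIndex = 2.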


for (int eachDimension = numberOfUnsupportedDimensionsFromFront;
eachDimension < numberOfTransformableDimensions; eachDimension++) {
Value eachDimensionAsOp = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(eachDimension));
Value selectionFromEachTransformationVector =
rewriter.create<Torch::AtenSelectIntOp>(
loc, typeOfSelectionFromTransformationVector,
givenTransformationVector, zeroAsOp, eachDimensionAsOp);
Value eachTransformation = rewriter.create<Torch::AtenItemOp>(
loc, typeOfEveryTransformation, selectionFromEachTransformationVector);
filteredTransformations.push_back(eachTransformation);
}
return ValueList;

return rewriter.create<Torch::PrimListConstructOp>(
loc, Torch::ListType::get(typeOfEveryFilteredTransformation),
filteredTransformations);
}
} // namespace

@@ -2686,12 +2759,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
});
patterns.onOp(
"Resize", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Torch::ValueTensorType typeOfOutputTensor;
Comment (Collaborator):
nit:

Torch::ValueTensorType typeOfOutputTensor; is quite redundant. I know it's the type of a tensor because it's declared as a Torch::ValueTensorType. Additionally, Of can be removed simply by reordering. resultType is already perfectly fine, and if you prefer output over result, then why not use outputType?

When you are contributing to existing code with existing naming conventions, I think it is only reasonable to change existing variable names if you have a very good reason to do so. I don't think this instance is one of those cases.

llvm::SmallVector<Value> operands;
std::string mode, nearest_mode, coordTfMode;
int64_t antialias, exclude_outside;
float extrapolation_value, cubic_coeff_a;
Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());

if (auto attr = binder.op->getAttr("torch.onnx.axes")) {
return rewriter.notifyMatchFailure(
@@ -2706,7 +2778,7 @@
}

if (binder.tensorOperandsList(operands) ||
binder.tensorResultType(resultType) ||
binder.tensorResultType(typeOfOutputTensor) ||
binder.customOpNameStringAttr(mode, "mode", "nearest") ||
binder.customOpNameStringAttr(
coordTfMode, "coordinate_transformation_mode", "half_pixel") ||
Expand All @@ -2718,6 +2790,43 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
"round_prefer_floor") ||
binder.f32FloatAttr(cubic_coeff_a, "cubic_coeff_a", -0.75))
return failure();

Value inputTensor = operands[0];
auto typeOfInputTensor =
cast<Torch::BaseTensorType>(inputTensor.getType());

auto sizesOfInputTensor = typeOfInputTensor.getSizes();
ArrayRef<int64_t> sizesOfOutputTensor = typeOfOutputTensor.getSizes();

int64_t const dimensionAssumedToBeBatch = 0;
int64_t const dimensionAssumedToBeChannel = 1;
int64_t nonResizableDimensions[] = {
Comment (Collaborator):
possibly important nit:

There might be a better way to leverage the llvm array-like structures. E.g.

SmallVector<int64_t, 2> nonResizableDims{dimensionAssumedToBeBatch, dimensionAssumedToBeChannel};

This pre-allocates memory for two int64_t's and has a few other benefits. See https://llvm.org/docs/ProgrammersManual.html#llvm-adt-smallvector-h

dimensionAssumedToBeBatch,
dimensionAssumedToBeChannel,
};

auto unknownSize = Torch::kUnknownSize;

// Compile-time check for dimensions of static size
for (auto eachDimension : nonResizableDimensions) {
Comment (Collaborator):
nit:

I'm not really sure about the choice to use each for the iteration variable. It might make reading the initial for statement a bit more natural in English, but it does make the body quite awkward to read. I've never seen anyone use this convention (in any repository), so I'm a bit curious what the motivation is?

In my opinion, the reading comprehension difficulty is not in the setup of the for loop, but within the body, so it makes sense to try and maximize readability there instead. E.g., currDim or iterDim would be preferable in my opinion to eachDim, since eachDim is like an incomplete universal quantifier and not an individual "thing". My top preference: if there isn't some other loop-independent dim that we need to disambiguate with, then why not just use dim as the loop variable?

auto eachSizeOfInputTensor = sizesOfInputTensor[eachDimension];
auto eachSizeOfOutputTensor = sizesOfOutputTensor[eachDimension];

if (eachSizeOfInputTensor == unknownSize ||
eachSizeOfOutputTensor == unknownSize) {
continue;
} else if (eachSizeOfInputTensor == eachSizeOfOutputTensor) {
continue;
}

auto scalingIntentErrorMessage =
"unsupported: non-trivial intent to scale dimension: " +
std::to_string(eachDimension);

return rewriter.notifyMatchFailure(binder.op,
scalingIntentErrorMessage);
};

if (antialias != 0) {
return rewriter.notifyMatchFailure(
binder.op,
@@ -2750,35 +2859,31 @@
binder.op, "unimplemented: cubic coeff must be -0.75");
}

unsigned rank = dyn_cast<Torch::ValueTensorType>(operands[0].getType())
.getSizes()
.size();
auto loc = binder.getLoc();
Comment (Collaborator):
low-priority: I'm commenting here because we've had a bit of a discussion on loc: I'm certainly amenable to using mlir::Location loc = binder.getLoc(); if you think the explicit typing would be helpful for future devs. That way they can just google mlir Location and it might be easier to figure out what this thing is. The use of auto for loc is pretty widespread in llvm/mlir, but there is a bit of disagreement over which typing convention is preferred, so I'm happy to leave this up to your discretion.


Value cstFalse =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
Value cstTrue =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
Value modeStrValue;

Value scalesValueList = noneVal;
Value sizesValueList = noneVal;
Value alignCorners =
coordTfMode == "align_corners" ? cstTrue : cstFalse;
if (mode == "cubic") {
std::string modeStr = "cubic";
if (coordTfMode != "half_pixel")
modeStr = modeStr + "_" + coordTfMode;
modeStrValue =
rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), modeStr);
modeStrValue = rewriter.create<Torch::ConstantStrOp>(loc, modeStr);
}

unsigned rankOfInputTensor = sizesOfInputTensor.size();
Comment (Collaborator):
nits:

  1. I typically recommend using explicit bitwidths for integers, e.g., uint64_t over unsigned. It has the same number of characters, but is more specific.
  2. There is not much ambiguity in rank, since the input and output tensors should have the same rank. It's fine either way, but I would prefer rank.


// supported modes:
// bilinear (half_pixel), bilinear with align_corners,
// bilinear_pytorch_half_pixel, bilinear_asymmetric nearest
// (asymmetric), nearest with align_corners, nearest_half_pixel,
// nearest_pytorch_half_pixel
if (mode == "linear") {
std::string modeStr;
switch (rank) {
switch (rankOfInputTensor) {
case 3:
modeStr = "linear";
break;
@@ -2795,8 +2900,7 @@
// mode is apparently half_pixel, NOT pytorch_half_pixel
if (coordTfMode != "half_pixel" && coordTfMode != "align_corners")
modeStr = (modeStr + "_") + coordTfMode;
modeStrValue =
rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), modeStr);
modeStrValue = rewriter.create<Torch::ConstantStrOp>(loc, modeStr);
}
if (mode == "nearest") {
std::string modeStr = "nearest";
Expand All @@ -2806,26 +2910,63 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
modeStr = (modeStr + "_") + coordTfMode;
if (nearest_mode != "floor" && nearest_mode != "")
modeStr = modeStr + "," + nearest_mode;
modeStrValue =
rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), modeStr);
modeStrValue = rewriter.create<Torch::ConstantStrOp>(loc, modeStr);
}

Value noneVal = rewriter.create<Torch::ConstantNoneOp>(loc);
Value filteredScaleFactorsAsOp = noneVal;
Value filteredSizesAsOp = noneVal;

if (operands.size() < 4) {
Value scaleOperand = operands[2];
scalesValueList = getValueList(binder, rewriter, scaleOperand);
sizesValueList = noneVal;
Value proposedScaleFactorsAsOp = operands[2];

// run-time scale factor check for dynamic sizes
for (auto eachDimension : nonResizableDimensions) {
Comment (Collaborator):
nit: Although relatively minor since it is a very small loop, it would be best to avoid an implicit copy constructor as in https://llvm.org/docs/CodingStandards.html#beware-unnecessary-copies-with-auto
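
A small illustration of what that guideline might look like applied here (a sketch only, not part of the PR):

// Name the element type explicitly (or bind by const reference for larger
// element types) rather than taking a by-value `auto` copy:
for (int64_t dim : nonResizableDimensions) {
  // ... build the comparison and runtime assert for `dim`
}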

auto eachScaleIdentityComparisonAsOp =
scaleIdentityComparisonOpForFactorAtDimensionIn(
proposedScaleFactorsAsOp, eachDimension, binder, rewriter);

auto eachErrorMessage =
"Unsupported: non-trivial scale factor for dimension " +
std::to_string(eachDimension);

rewriter.create<Torch::RuntimeAssertOp>(
loc, eachScaleIdentityComparisonAsOp,
rewriter.getStringAttr(eachErrorMessage));
};

filteredScaleFactorsAsOp = withUnsupportedDimensionsFilteredOut(
proposedScaleFactorsAsOp, binder, rewriter);
} else {
Value sizeOperand = operands[3];
scalesValueList = noneVal;
sizesValueList = getValueList(binder, rewriter, sizeOperand);
Value proposedSizesAsOp = operands[3];

// run-time target size check for dynamic sizes
for (auto eachDimension : nonResizableDimensions) {
auto eachSizeComparisonAsOp =
Comment (Collaborator):
nit: I'd prefer Value here instead of auto. I think the comment in https://llvm.org/docs/CodingStandards.html#use-auto-type-deduction-to-make-code-more-readable is a helpful one.

originalSizeComparisonOpForSizeAtDimensionIn(
proposedSizesAsOp, inputTensor, eachDimension, binder,
rewriter);

auto eachErrorMessage =
"Unsupported: non-trivial resizing of dimension " +
std::to_string(eachDimension);

rewriter.create<Torch::RuntimeAssertOp>(
loc, eachSizeComparisonAsOp,
rewriter.getStringAttr(eachErrorMessage));
};

filteredSizesAsOp = withUnsupportedDimensionsFilteredOut(
proposedSizesAsOp, binder, rewriter);
}
if (isa<Torch::NoneType>(scalesValueList.getType()) &&
isa<Torch::NoneType>(sizesValueList.getType())) {
if (isa<Torch::NoneType>(filteredScaleFactorsAsOp.getType()) &&
isa<Torch::NoneType>(filteredSizesAsOp.getType())) {
return rewriter.notifyMatchFailure(binder.op, "unknown scaling mode");
}
rewriter
.replaceOpWithNewOp<Torch::Aten__InterpolateSizeListScaleListOp>(
binder.op, resultType, operands[0], sizesValueList,
scalesValueList, modeStrValue,
binder.op, typeOfOutputTensor, inputTensor, filteredSizesAsOp,
filteredScaleFactorsAsOp, modeStrValue,
/* AnyTorchOptionalBoolType:$align_corners */ alignCorners,
/* AnyTorchOptionalBoolType:$recompute_scale_factor */ noneVal,
/*Torch_BoolType:$antialias*/ cstFalse);
@@ -3339,7 +3480,8 @@
return rewriter.notifyMatchFailure(
binder.op, "supports upto 3d upsampling only");

Value scalesValueList = getValueList(binder, rewriter, scales);
Value scalesValueList =
withUnsupportedDimensionsFilteredOut(scales, binder, rewriter);
if (mode == "linear") {
if (resultRank == 4)
mode = "bilinear";