Skip to content

Commit

Permalink
fix(ONNX): avoids resizing conventionally fixed dimensions
Browse files Browse the repository at this point in the history
  • Loading branch information
bjacobgordon committed Jan 8, 2025
1 parent 7d8d04b commit 6baa8d5
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 3 deletions.
4 changes: 4 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/TorchTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ class BaseTensorType : public Type {
/// Enable isa/dyn_cast for BaseTensorType.
static bool classof(Type type);

/// The element-wise comparison of each dimension/size in `that` tensor
ArrayRef<std::optional<bool>>
dimensionComparisonsAgainst(BaseTensorType that) const;

/// Return true if this type has the same sizes and dtype as the other.
bool hasSameSizesAndDtype(BaseTensorType other) const;

Expand Down
27 changes: 24 additions & 3 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2717,6 +2717,30 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
"round_prefer_floor") ||
binder.f32FloatAttr(cubic_coeff_a, "cubic_coeff_a", -0.75))
return failure();

Value inputTensor = operands[0];
Torch::ValueTensorType inputTensor_shape =
cast<Torch::ValueTensorType>(inputTensor.getType());

ArrayRef<std::optional<bool>> dimensionComparisons =
inputTensor_shape.dimensionComparisonsAgainst(outputTensor_shape);

// Comparisons of the dimensions assumed to carry the batch and channel
auto fixedDimensionComparisons = dimensionComparisons.take_front(2);

for (auto eachComparison : fixedDimensionComparisons) {
if (eachComparison == nullptr) {
return rewriter.notifyMatchFailure(
binder.op, "Sizes for batch and channel dimensions must be "
"statically defined");
}
if (eachComparison == false) {
return rewriter.notifyMatchFailure(
binder.op,
"Unexpected intent to resize the batch/channel dimensions");
}
};

if (antialias != 0) {
return rewriter.notifyMatchFailure(
binder.op,
Expand Down Expand Up @@ -2749,9 +2773,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
binder.op, "unimplemented: cubic coeff must be -0.75");
}

Value inputTensor = operands[0];
Torch::ValueTensorType inputTensor_shape =
cast<Torch::ValueTensorType>(inputTensor.getType());
ArrayRef<int64_t> inputTensor_dimensions = inputTensor_shape.getSizes();
unsigned rank = inputTensor_dimensions.size();

Expand Down
25 changes: 25 additions & 0 deletions lib/Dialect/Torch/IR/TorchTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,31 @@ static bool isValidTorchDtype(Type dtype) {
return false;
}

ArrayRef<std::optional<bool>>
BaseTensorType::dimensionComparisonsAgainst(BaseTensorType that) const {
auto this_dimensions = /**/ getSizes();
auto that_dimensions = that.getSizes();

auto this_rank = this_dimensions.size();
auto that_rank = that_dimensions.size();
assert((this_rank == that_rank) && "Ranks must match to compare dimensions");

SmallVector<std::optional<bool>> dimensionComparisons;
dimensionComparisons.reserve(this_rank);

auto dimensionPairs = llvm::zip(this_dimensions, that_dimensions);

for (auto [eachPair_lhs, eachPair_rhs] : dimensionPairs) {
if (eachPair_lhs == kUnknownSize || eachPair_rhs == kUnknownSize) {
dimensionComparisons.push_back(nullptr);
} else {
dimensionComparisons.push_back(eachPair_lhs == eachPair_rhs);
}
}

return dimensionComparisons;
}

bool BaseTensorType::hasSameSizesAndDtype(BaseTensorType other) const {
return getOptionalSizes() == other.getOptionalSizes() &&
getOptionalDtype() == other.getOptionalDtype();
Expand Down

0 comments on commit 6baa8d5

Please sign in to comment.