Skip to content

Commit

Permalink
[core] fix a couple of bugs with loop normalization and unrolling (#2618
Browse files Browse the repository at this point in the history
)

* Remove the change to the invariant loop test.

Add const to LoopComponents methods.

When loops are already in what appears to be normalized form,
loop normalization takes no action. It is therefore necessary
for loop unrolling to always compute the number of iterations
that may be executed for the loop. This adds the algorithm
used to compute the number of iterations to loop unrolling
so that loops that do not have iterations at all are properly
detected.

Add tests for always true and always false extreme conditions.

Patch LoopNormalize, now in a new file.

Finish merging in the changes.

Signed-off-by: Eric Schweitz <[email protected]>

* Use switches throughout.

Signed-off-by: Eric Schweitz <[email protected]>

---------

Signed-off-by: Eric Schweitz <[email protected]>
  • Loading branch information
schweitzpgi authored Feb 15, 2025
1 parent 3994807 commit 76ffe23
Show file tree
Hide file tree
Showing 12 changed files with 695 additions and 139 deletions.
2 changes: 1 addition & 1 deletion lib/Optimizer/Transforms/LiftArrayAllocPatterns.inc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/*******************************************************************************
/****************************************************************-*- C++ -*-****
* Copyright (c) 2022 - 2025 NVIDIA Corporation & Affiliates. *
* All rights reserved. *
* *
Expand Down
304 changes: 289 additions & 15 deletions lib/Optimizer/Transforms/LoopAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
******************************************************************************/

#include "LoopAnalysis.h"
#include "cudaq/Optimizer/Builder/Factory.h"
#include "mlir/IR/Dominance.h"

using namespace mlir;
Expand Down Expand Up @@ -73,14 +74,6 @@ static bool isaConstantOf(Value v, std::int64_t hasVal) {
return false;
}

static bool isNegativeConstant(Value v) {
v = peelCastOps(v);
if (auto c = v.getDefiningOp<arith::ConstantOp>())
if (auto ia = dyn_cast<IntegerAttr>(c.getValue()))
return ia.getInt() < 0;
return false;
}

static bool isClosedIntervalForm(arith::CmpIPredicate p) {
return p == arith::CmpIPredicate::ule || p == arith::CmpIPredicate::sle;
}
Expand Down Expand Up @@ -210,6 +203,10 @@ static BlockArgument getLinearExpr(Value expr,
return scaledIteration(expr);
}

static unsigned bitWidth(Value val) {
return cast<IntegerType>(val.getType()).getWidth();
}

namespace cudaq {

bool opt::isSemiOpenPredicate(arith::CmpIPredicate p) {
Expand All @@ -223,6 +220,11 @@ bool opt::isUnsignedPredicate(arith::CmpIPredicate p) {
p == arith::CmpIPredicate::ugt || p == arith::CmpIPredicate::uge;
}

bool opt::isSignedPredicate(arith::CmpIPredicate p) {
return p == arith::CmpIPredicate::slt || p == arith::CmpIPredicate::sle ||
p == arith::CmpIPredicate::sgt || p == arith::CmpIPredicate::sge;
}

// We expect the loop control value to have the following form.
//
// %final = cc.loop while ((%iter = %initial) -> (iN)) {
Expand Down Expand Up @@ -282,7 +284,7 @@ bool opt::isaMonotonicLoop(Operation *op, bool allowEarlyExit,

bool opt::isaInvariantLoop(const LoopComponents &c, bool allowClosedInterval) {
if (isaConstantOf(c.initialValue, 0) && isaConstantOf(c.stepValue, 1) &&
isa<arith::AddIOp>(c.stepOp) && !isNegativeConstant(c.compareValue)) {
isa<arith::AddIOp>(c.stepOp)) {
auto cmp = cast<arith::CmpIOp>(c.compareOp);
return validCountedLoopIntervalForm(cmp, allowClosedInterval);
}
Expand Down Expand Up @@ -314,26 +316,296 @@ bool opt::isaConstantUpperBoundLoop(cc::LoopOp loop, bool allowClosedInterval) {
isaConstant(c.compareValue);
}

Value opt::LoopComponents::getCompareInduction() {
Value opt::LoopComponents::getCompareInduction() const {
auto cmpOp = cast<arith::CmpIOp>(compareOp);
return cmpOp.getLhs() == compareValue ? cmpOp.getRhs() : cmpOp.getLhs();
}

bool opt::LoopComponents::stepIsAnAddOp() { return isa<arith::AddIOp>(stepOp); }
bool opt::LoopComponents::stepIsAnAddOp() const {
return isa<arith::AddIOp>(stepOp);
}

bool opt::LoopComponents::shouldCommuteStepOp() {
bool opt::LoopComponents::shouldCommuteStepOp() const {
if (auto addOp = dyn_cast_or_null<arith::AddIOp>(stepOp))
return addOp.getRhs() == stepRegion->front().getArgument(induction);
// Note: we don't allow induction on lhs of subtraction.
return false;
}

bool opt::LoopComponents::isClosedIntervalForm() {
bool opt::LoopComponents::isClosedIntervalForm() const {
auto cmp = cast<arith::CmpIOp>(compareOp);
return ::isClosedIntervalForm(cmp.getPredicate());
}

bool opt::LoopComponents::isLinearExpr() { return addendValue || scaleValue; }
bool opt::LoopComponents::isLinearExpr() const {
return addendValue || scaleValue;
}

std::int64_t opt::LoopComponents::extendValue(unsigned width,
std::size_t val) const {
const bool signExt =
isSignedPredicate(cast<arith::CmpIOp>(compareOp).getPredicate());
std::int64_t result = val;
switch (width) {
case 8:
if (signExt) {
std::int8_t v = val & 0xFF;
result = v;
} else {
std::uint8_t v = val & 0xFF;
result = v;
}
break;
case 16:
if (signExt) {
std::int16_t v = val & 0xFFFF;
result = v;
} else {
std::uint16_t v = val & 0xFFFF;
result = v;
}
break;
case 32:
if (signExt) {
std::int32_t v = val & 0xFFFFFFFF;
result = v;
} else {
std::uint32_t v = val & 0xFFFFFFFF;
result = v;
}
break;
default:
break;
}
return result;
}

bool opt::LoopComponents::hasAlwaysTrueCondition() const {
auto cmpValOpt = factory::maybeValueOfIntConstant(compareValue);
if (!cmpValOpt)
return false;
auto width = bitWidth(compareValue);
std::int64_t cmpVal = *cmpValOpt;
auto pred = cast<arith::CmpIOp>(compareOp).getPredicate();
switch (width) {
case 8: {
switch (pred) {
case arith::CmpIPredicate::sge:
return static_cast<std::int8_t>(cmpVal) ==
std::numeric_limits<std::int8_t>::min();
case arith::CmpIPredicate::sle:
return static_cast<std::int8_t>(cmpVal) ==
std::numeric_limits<std::int8_t>::max();
case arith::CmpIPredicate::uge:
return static_cast<std::uint8_t>(cmpVal) ==
std::numeric_limits<std::uint8_t>::min();
case arith::CmpIPredicate::ule:
return static_cast<std::uint8_t>(cmpVal) ==
std::numeric_limits<std::uint8_t>::max();
default:
break;
}
} break;
case 16: {
switch (pred) {
case arith::CmpIPredicate::sge:
return static_cast<std::int16_t>(cmpVal) ==
std::numeric_limits<std::int16_t>::min();
case arith::CmpIPredicate::sle:
return static_cast<std::int16_t>(cmpVal) ==
std::numeric_limits<std::int16_t>::max();
case arith::CmpIPredicate::uge:
return static_cast<std::uint16_t>(cmpVal) ==
std::numeric_limits<std::uint16_t>::min();
case arith::CmpIPredicate::ule:
return static_cast<std::uint16_t>(cmpVal) ==
std::numeric_limits<std::uint16_t>::max();
default:
break;
}
} break;
case 32: {
switch (pred) {
case arith::CmpIPredicate::sge:
return static_cast<std::int32_t>(cmpVal) ==
std::numeric_limits<std::int32_t>::min();
case arith::CmpIPredicate::sle:
return static_cast<std::int32_t>(cmpVal) ==
std::numeric_limits<std::int32_t>::max();
case arith::CmpIPredicate::uge:
return static_cast<std::uint32_t>(cmpVal) ==
std::numeric_limits<std::uint32_t>::min();
case arith::CmpIPredicate::ule:
return static_cast<std::uint32_t>(cmpVal) ==
std::numeric_limits<std::uint32_t>::max();
default:
break;
}
} break;
case 64: {
switch (pred) {
case arith::CmpIPredicate::sge:
return static_cast<std::int64_t>(cmpVal) ==
std::numeric_limits<std::int64_t>::min();
case arith::CmpIPredicate::sle:
return static_cast<std::int64_t>(cmpVal) ==
std::numeric_limits<std::int64_t>::max();
case arith::CmpIPredicate::uge:
return static_cast<std::uint64_t>(cmpVal) ==
std::numeric_limits<std::uint64_t>::min();
case arith::CmpIPredicate::ule:
return static_cast<std::uint64_t>(cmpVal) ==
std::numeric_limits<std::uint64_t>::max();
default:
break;
}
} break;
default:
break;
}
return false;
}

bool opt::LoopComponents::hasAlwaysFalseCondition() const {
auto cmpValOpt = factory::maybeValueOfIntConstant(compareValue);
if (!cmpValOpt)
return false;
auto width = bitWidth(compareValue);
std::int64_t cmpVal = *cmpValOpt;
auto pred = cast<arith::CmpIOp>(compareOp).getPredicate();
switch (width) {
case 8: {
switch (pred) {
case arith::CmpIPredicate::slt:
return static_cast<std::int8_t>(cmpVal) ==
std::numeric_limits<std::int8_t>::min();
case arith::CmpIPredicate::sgt:
return static_cast<std::int8_t>(cmpVal) ==
std::numeric_limits<std::int8_t>::max();
case arith::CmpIPredicate::ult:
return static_cast<std::uint8_t>(cmpVal) ==
std::numeric_limits<std::uint8_t>::min();
case arith::CmpIPredicate::ugt:
return static_cast<std::uint8_t>(cmpVal) ==
std::numeric_limits<std::uint8_t>::max();
default:
break;
}
} break;
case 16: {
switch (pred) {
case arith::CmpIPredicate::slt:
return static_cast<std::int16_t>(cmpVal) ==
std::numeric_limits<std::int16_t>::min();
case arith::CmpIPredicate::sgt:
return static_cast<std::int16_t>(cmpVal) ==
std::numeric_limits<std::int16_t>::max();
case arith::CmpIPredicate::ult:
return static_cast<std::uint16_t>(cmpVal) ==
std::numeric_limits<std::uint16_t>::min();
case arith::CmpIPredicate::ugt:
return static_cast<std::uint16_t>(cmpVal) ==
std::numeric_limits<std::uint16_t>::max();
default:
break;
}
} break;
case 32: {
switch (pred) {
case arith::CmpIPredicate::slt:
return static_cast<std::int32_t>(cmpVal) ==
std::numeric_limits<std::int32_t>::min();
case arith::CmpIPredicate::sgt:
return static_cast<std::int32_t>(cmpVal) ==
std::numeric_limits<std::int32_t>::max();
case arith::CmpIPredicate::ult:
return static_cast<std::uint32_t>(cmpVal) ==
std::numeric_limits<std::uint32_t>::min();
case arith::CmpIPredicate::ugt:
return static_cast<std::uint32_t>(cmpVal) ==
std::numeric_limits<std::uint32_t>::max();
default:
break;
}
} break;
case 64: {
switch (pred) {
case arith::CmpIPredicate::slt:
return static_cast<std::int64_t>(cmpVal) ==
std::numeric_limits<std::int64_t>::min();
case arith::CmpIPredicate::sgt:
return static_cast<std::int64_t>(cmpVal) ==
std::numeric_limits<std::int64_t>::max();
case arith::CmpIPredicate::ult:
return static_cast<std::uint64_t>(cmpVal) ==
std::numeric_limits<std::uint64_t>::min();
case arith::CmpIPredicate::ugt:
return static_cast<std::uint64_t>(cmpVal) ==
std::numeric_limits<std::uint64_t>::max();
default:
break;
}
} break;
default:
break;
}
return false;
}

std::optional<std::size_t> opt::LoopComponents::getIterationsConstant() const {
auto initValOpt = factory::maybeValueOfIntConstant(initialValue);
if (!initValOpt)
return std::nullopt;
std::int64_t initVal = extendValue(bitWidth(initialValue), *initValOpt);
auto endValOpt = factory::maybeValueOfIntConstant(compareValue);
if (!endValOpt)
return std::nullopt;
std::int64_t endVal = extendValue(bitWidth(compareValue), *endValOpt);
auto stepValOpt = factory::maybeValueOfIntConstant(stepValue);
if (!stepValOpt)
return std::nullopt;
std::int64_t stepVal = extendValue(bitWidth(stepValue), *stepValOpt);
if (!stepIsAnAddOp())
stepVal = -stepVal;
if (isLinearExpr()) {
if (addendValue) {
auto addendOpt = factory::maybeValueOfIntConstant(addendValue);
if (!addendOpt)
return std::nullopt;
std::int64_t addend = extendValue(bitWidth(addendValue), *addendOpt);
if (negatedAddend)
endVal += addend;
else
endVal -= addend;
}
if (minusOneMult) {
initVal = -initVal;
stepVal = -stepVal;
}
if (scaleValue) {
auto scaleValOpt = factory::maybeValueOfIntConstant(scaleValue);
if (!scaleValOpt)
return std::nullopt;
std::int64_t scaleVal = extendValue(bitWidth(scaleValue), *scaleValOpt);
if (reciprocalScale) {
endVal *= scaleVal;
} else {
endVal *= scaleVal;
stepVal *= scaleVal;
}
}
}
if (!isClosedIntervalForm()) {
if (stepVal < 0)
endVal += 1;
else
endVal -= 1;
}
std::int64_t result = (endVal - initVal + stepVal) / stepVal;
if (result < 0)
result = 0;
return {result};
}

template <typename T>
constexpr int computeArgsOffset() {
Expand All @@ -350,7 +622,9 @@ std::optional<opt::LoopComponents> opt::getLoopComponents(cc::LoopOp loop) {
auto &whileEntry = whileRegion.front();
auto condOp = cast<cc::ConditionOp>(whileRegion.back().back());
result.compareOp = condOp.getCondition().getDefiningOp();
auto cmpOp = cast<arith::CmpIOp>(result.compareOp);
auto cmpOp = dyn_cast<arith::CmpIOp>(result.compareOp);
if (!cmpOp)
return {};

auto argumentToCompare = [&](unsigned idx) -> bool {
return (getLinearExpr(cmpOp.getLhs(), result, loop) ==
Expand Down
Loading

0 comments on commit 76ffe23

Please sign in to comment.