Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ZIR-301: Optimize predicate generation #135

Merged
merged 4 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29,100 changes: 18,402 additions & 10,698 deletions risc0/circuit/rv32im-v2-sys/kernels/cxx/steps.cpp

Large diffs are not rendered by default.

667 changes: 502 additions & 165 deletions risc0/circuit/rv32im-v2-sys/kernels/cxx/steps.h

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions zirgen/Dialect/ZHLT/IR/TypeUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,10 @@ class StructBuilder {

if (memberName == "@super") {
// TODO: Why do we require @super to be first? Document this requirement.
members.insert(members.begin(), {memberName, value.getType()});
members.insert(members.begin(), {.name = memberName, .type = value.getType()});
memberValues.insert(memberValues.begin(), value);
} else {
members.push_back({memberName, value.getType()});
members.push_back({.name = memberName, .type = value.getType()});
memberValues.push_back(value);
}
}
Expand All @@ -65,7 +65,7 @@ class StructBuilder {
assert(value);
assert(value.getType());
assert(memberName != "@super");
members.push_back({memberName, value.getType(), .isPrivate = true});
members.push_back({.name = memberName, .type = value.getType(), .isPrivate = true});
memberValues.push_back(value);
}

Expand Down
1 change: 1 addition & 0 deletions zirgen/Dialect/Zll/Transforms/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ cc_library(
"MakePolynomial.cpp",
"MultiplyToIf.cpp",
"PassDetail.h",
"SortForReproducibility.cpp",
"SplitStage.cpp",
],
hdrs = [
Expand Down
1 change: 1 addition & 0 deletions zirgen/Dialect/Zll/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> createAddReductionsPass
std::unique_ptr<mlir::Pass> createIfToMultiplyPass();
std::unique_ptr<mlir::Pass> createMultiplyToIfPass();
std::unique_ptr<mlir::Pass> createBalancedSplitPass(size_t maxOps = 1000);
std::unique_ptr<mlir::Pass> createSortForReproducibilityPass();

// Generate the code for registering passes.
#define GEN_PASS_REGISTRATION
Expand Down
11 changes: 11 additions & 0 deletions zirgen/Dialect/Zll/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,15 @@ def BalancedSplit : Pass<"balanced-split"> {
];
}

def SortForReproducibility : Pass<"sort-for-reproducibility"> {
let summary = "Attempt to make operation order consistent between compilers and compilations";
let description = [{
HACK: Evaluation order of arguments is unspecified in c++ and
different compilers do it differently. We want our circuit
compilations to be reproducible, so sortForReproducibility tries
to undo the varation in argument order.
}];
let constructor = "zirgen::Zll::createSortForReproducibilityPass()";
}

#endif // ZLL_TRANSFORM_PASSES
203 changes: 203 additions & 0 deletions zirgen/Dialect/Zll/Transforms/SortForReproducibility.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
// Copyright 2024 RISC Zero, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"

#include "zirgen/Dialect/Zll/IR/IR.h"
#include "zirgen/Dialect/Zll/Transforms/PassDetail.h"
#include "zirgen/Dialect/Zll/Transforms/Passes.h"

using namespace mlir;

namespace zirgen::Zll {

namespace {

struct RemoveEqualZero : public OpRewritePattern<EqualZeroOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(EqualZeroOp op, PatternRewriter& rewriter) const override {
rewriter.eraseOp(op);
return success();
}
};

struct SortForReproducibilityPass : public SortForReproducibilityBase<SortForReproducibilityPass> {
void runOnOperation() override {
StringSet<llvm::BumpPtrAllocator> propStorage;
DenseMap<Operation*, StringRef> propInfo;
DenseMap<Operation*, size_t> argPositions;

// Save some some things we can use to deterministically sort operations.
getOperation()->walk([&](Operation* op) {
for (auto [i, arg] : llvm::enumerate(op->getOperands())) {
auto definer = arg.getDefiningOp();
if (definer) {
auto& pos = argPositions[definer];
pos += i;
}
}
});

std::function<StringRef(Operation*)> getPropInfo = [&](Operation* op) -> StringRef {
auto propIt = propInfo.find(op);
if (propIt != propInfo.end())
return propIt->second;

std::string props = op->getName().getStringRef().str();
llvm::raw_string_ostream os(props);
op->getPropertiesAsAttribute().print(os);

auto [it, didInsert] = propStorage.insert(props);
return propInfo[op] = it->first();
};

std::function<bool(Operation & a, Operation & b)> operationCompare, cachedOperationCompare;
operationCompare = [&](Operation& a, Operation& b) {
if (&a == &b)
return false;

StringRef aInfo = getPropInfo(&a);
StringRef bInfo = getPropInfo(&b);
if (aInfo != bInfo) {
return aInfo < bInfo;
}

size_t aPos = argPositions.lookup(&a);
size_t bPos = argPositions.lookup(&b);
if (aPos != bPos) {
return aPos < bPos;
}

if (a.getNumOperands() != b.getNumOperands()) {
return a.getNumOperands() < b.getNumOperands();
}

// If nothing else works, recursively compare operands
for (auto [aOperand, bOperand] : llvm::zip_equal(a.getOperands(), b.getOperands())) {
auto aDefiner = llvm::dyn_cast<OpResult>(aOperand);
auto bDefiner = llvm::dyn_cast<OpResult>(bOperand);
if (bool(aDefiner) != bool(bDefiner))
return bool(aDefiner) < bool(bDefiner);

if (aDefiner && bDefiner) {
if (aDefiner.getResultNumber() != bDefiner.getResultNumber())
return aDefiner.getResultNumber() < bDefiner.getResultNumber();

if (cachedOperationCompare(*aDefiner.getOwner(), *bDefiner.getOwner()))
return true;
if (cachedOperationCompare(*bDefiner.getOwner(), *aDefiner.getOwner()))
return false;
}

auto aArg = llvm::dyn_cast<BlockArgument>(aOperand);
auto bArg = llvm::dyn_cast<BlockArgument>(bOperand);

if (bool(aArg) != bool(bArg))
return bool(aArg) < bool(bArg);

if (aArg && bArg) {
if (aArg.getArgNumber() != bArg.getArgNumber())
return aArg.getArgNumber() < bArg.getArgNumber();
}
}

return false;
};

llvm::DenseMap<std::pair<Operation*, Operation*>, bool> compareResults;
cachedOperationCompare = [&](Operation& a, Operation& b) {
auto [it, didInsert] = compareResults.try_emplace(std::make_pair(&a, &b), false);
if (didInsert) {
bool lessThan = operationCompare(a, b);
// We can't use the original iterator since it may have been invalidated.
compareResults[std::make_pair(&a, &b)] = lessThan;
return lessThan;
} else {
return it->second;
}
};

getOperation()->walk([&](Block* inner) {
auto shouldSort = [&](Operation* op) {
if (op->hasTrait<OpTrait::IsTerminator>())
return false;
if (isPure(op))
return true;
return false;
};

Block doneBlock;

// Find runs of pure operations all in a row, and sort them.
auto it = inner->getOperations().begin();
while (it != inner->getOperations().end()) {
if (!shouldSort(&*it)) {
++it;
continue;
}

// Move the section of operations we're not sorting into doneBlock.
doneBlock.getOperations().splice(doneBlock.getOperations().end(),
inner->getOperations(),
inner->getOperations().begin(),
it);
it = inner->getOperations().begin();

while (it != inner->getOperations().end() && shouldSort(&*it)) {
++it;
}

Block sortBlock;
// Move the sections of operations we *are* sorting into sortBlock
sortBlock.getOperations().splice(sortBlock.getOperations().end(),
inner->getOperations(),
inner->getOperations().begin(),
it);
sortBlock.getOperations().sort(cachedOperationCompare);

// After we sort reproducibly, we have to topologically sort to make sure we don't screw up
// SSA dominance.
mlir::sortTopologically(&sortBlock);

// Append the sorted section to doneBlock.
doneBlock.getOperations().splice(doneBlock.getOperations().end(),
sortBlock.getOperations(),
sortBlock.getOperations().begin(),
sortBlock.getOperations().end());
it = inner->getOperations().begin();
}

// Move all the operations we've processed out from doneBlock into the original block.
inner->getOperations().splice(inner->getOperations().begin(),
doneBlock.getOperations(),
doneBlock.getOperations().begin(),
doneBlock.getOperations().end());
});
}
};

} // End namespace

std::unique_ptr<Pass> createSortForReproducibilityPass() {
return std::make_unique<SortForReproducibilityPass>();
}

} // namespace zirgen::Zll
Loading
Loading