Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a nondet inv op #6

Merged
merged 7 commits into from
Aug 20, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Format
tzerrell committed Aug 16, 2024
commit eb314b710024d703eb0c6031cf9d5346b31de77c
13 changes: 7 additions & 6 deletions zirgen/Dialect/BigInt/IR/Eval.cpp
Original file line number Diff line number Diff line change
@@ -106,9 +106,10 @@ BytePoly nondetInvMod(const BytePoly& lhs, const BytePoly& rhs, size_t coeffs) {
auto lhsInt = toAPInt(lhs);
auto rhsInt = toAPInt(rhs);
size_t maxSize = rhsInt.getBitWidth();
APInt inv(2 * maxSize, 1); // Initialize inverse to zero, twice the width of `prime` to allow multiplication
APInt sqr(lhsInt); // Will be repeatedly squared
APInt position(2 * maxSize, 1); // Bit at `idx` will be 1, other bits will be 0
APInt inv(2 * maxSize,
1); // Initialize inverse to zero, twice the width of `prime` to allow multiplication
APInt sqr(lhsInt); // Will be repeatedly squared
APInt position(2 * maxSize, 1); // Bit at `idx` will be 1, other bits will be 0
sqr = sqr.zext(2 * maxSize);
rhsInt = rhsInt.zext(2 * maxSize);
APInt exp = rhsInt - 2;
@@ -117,10 +118,10 @@ BytePoly nondetInvMod(const BytePoly& lhs, const BytePoly& rhs, size_t coeffs) {
// multiply in the current power of n (i.e., n^(2^idx))
inv = (inv * sqr).urem(rhsInt);
}
position <<= 1; // increment the bit position to test in `exp`
sqr = (sqr * sqr).urem(rhsInt); // square `sqr` to increment to `n^(2^(idx+1))`
position <<= 1; // increment the bit position to test in `exp`
sqr = (sqr * sqr).urem(rhsInt); // square `sqr` to increment to `n^(2^(idx+1))`
}
inv = inv.trunc(maxSize); // We don't need the extra space used as multiply buffer
inv = inv.trunc(maxSize); // We don't need the extra space used as multiply buffer
LLVM_DEBUG({ dbgs() << "inv (mod " << rhsInt << "): " << inv << "\n"; });
return fromAPInt(inv, coeffs);
}
6 changes: 3 additions & 3 deletions zirgen/Dialect/BigInt/IR/Ops.cpp
Original file line number Diff line number Diff line change
@@ -150,9 +150,9 @@ LogicalResult NondetInvModOp::inferReturnTypes(MLIRContext* ctx,
}

LogicalResult ModularInvOp::inferReturnTypes(MLIRContext* ctx,
std::optional<Location> loc,
Adaptor adaptor,
SmallVectorImpl<Type>& out) {
std::optional<Location> loc,
Adaptor adaptor,
SmallVectorImpl<Type>& out) {
auto rhsType = adaptor.getRhs().getType().cast<BigIntType>();
size_t coeffsWidth = ceilDiv(rhsType.getMaxBits(), kBitsPerCoeff);
out.push_back(BigIntType::get(ctx,
4 changes: 2 additions & 2 deletions zirgen/Dialect/BigInt/Transforms/LowerModularInv.cpp
Original file line number Diff line number Diff line change
@@ -21,8 +21,8 @@ struct ReplaceModularInv : public OpRewritePattern<ModularInvOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ModularInvOp op, PatternRewriter& rewriter) const override {
// Construct the constant 1
mlir::Type oneType = rewriter.getIntegerType(1); // a `1` is bitwidth 1
auto oneAttr = rewriter.getIntegerAttr(oneType, 1); // value 1
mlir::Type oneType = rewriter.getIntegerType(1); // a `1` is bitwidth 1
auto oneAttr = rewriter.getIntegerAttr(oneType, 1); // value 1
auto one = rewriter.create<ConstOp>(op.getLoc(), oneAttr);

auto inv = rewriter.create<NondetInvModOp>(op.getLoc(), op.getLhs(), op.getRhs());
6 changes: 3 additions & 3 deletions zirgen/circuit/bigint/op_tests.cpp
Original file line number Diff line number Diff line change
@@ -130,12 +130,12 @@ void makeReduceTest(mlir::OpBuilder builder, mlir::Location loc, size_t bits) {

void makeNondetInvTest(mlir::OpBuilder builder, mlir::Location loc, size_t bits) {
auto inp = builder.create<BigInt::DefOp>(loc, bits, 0, true);
auto prime = builder.create<BigInt::DefOp>(loc, bits, 1, true, bits - 1); // TODO: Set to 131 if we need an actual number
auto prime = builder.create<BigInt::DefOp>(loc, bits, 1, true, bits - 1);
auto expected = builder.create<BigInt::DefOp>(loc, bits, 2, true);

// Construct constants
mlir::Type oneType = builder.getIntegerType(1); // a `1` is bitwidth 1
auto oneAttr = builder.getIntegerAttr(oneType, 1); // value 1
mlir::Type oneType = builder.getIntegerType(1); // a `1` is bitwidth 1
auto oneAttr = builder.getIntegerAttr(oneType, 1); // value 1
auto one = builder.create<BigInt::ConstOp>(loc, oneAttr);

auto inv = builder.create<BigInt::NondetInvModOp>(loc, inp, prime);