diff --git a/zirgen/circuit/rv32im/v2/dsl/decode.zir b/zirgen/circuit/rv32im/v2/dsl/decode.zir index 070b0771..96913de2 100644 --- a/zirgen/circuit/rv32im/v2/dsl/decode.zir +++ b/zirgen/circuit/rv32im/v2/dsl/decode.zir @@ -26,7 +26,7 @@ component Decoder(inst: ValU32) { _f3_01 := NondetTwitReg((inst.low & 0x3000) / 0x1000); _rd_34 := NondetTwitReg((inst.low & 0x0C00) / 0x0400); _rd_12 := NondetTwitReg((inst.low & 0x0300) / 0x0100); - _rd_0 := NondetTwitReg((inst.low & 0x0080) / 0x0080); + _rd_0 := NondetBitReg((inst.low & 0x0080) / 0x0080); // The opcode is special and is unconstrained. // This implies the for the decoding to be fully correct, some later diff --git a/zirgen/circuit/rv32im/v2/dsl/inst_div.zir b/zirgen/circuit/rv32im/v2/dsl/inst_div.zir index afddc0b9..5f0cb9a1 100644 --- a/zirgen/circuit/rv32im/v2/dsl/inst_div.zir +++ b/zirgen/circuit/rv32im/v2/dsl/inst_div.zir @@ -57,13 +57,22 @@ component DoDiv(numer: ValU32, denom: ValU32, signed: Val, ones_comp: Val) { rem := ValU32(rem_low, rem_high); // Either all signed, or nothing signed settings := MultiplySettings(signed, signed, signed); - // Do the acuumulate + // Do the accumulate mul := MultiplyAccumulate(quot, denom, rem, settings); // Check the main result (numer = quot * denom + rem AssertEqU32(mul.outLow, numer); // The top bits should all be 0 or all be 1 topBitType := NondetBitReg(1 - Isz(mul.outHigh.low)); AssertEqU32(mul.outHigh, ValU32(0xffff * topBitType, 0xffff * topBitType)); + // Check if denom is zero + isZero := IsZero(denom.low + denom.high); + // If non-zero, make sure 0 <= rem < denom + if (isZero) { + AssertEqU32(rem, numer); + } else { + cmp := CmpLessThanUnsigned(rem, denom); + cmp.is_less_than = 1; + }; DivideReturn(quot, rem) } diff --git a/zirgen/circuit/rv32im/v2/dsl/inst_ecall.zir b/zirgen/circuit/rv32im/v2/dsl/inst_ecall.zir index 8de1eeee..053f7753 100644 --- a/zirgen/circuit/rv32im/v2/dsl/inst_ecall.zir +++ b/zirgen/circuit/rv32im/v2/dsl/inst_ecall.zir @@ -54,8 +54,9 @@ component ECallTerminate(cycle: Reg, input: InstInput) { component DecomposeLow2(len: Val) { // We split len into a multiple of 4, and the low 2 bits as one hot - public high := NondetReg((len & 0xfffc) / 4); + public high := NondetU16Reg((len & 0xfffc) / 4); public low2 := NondetReg(len & 0x3); + len = 4*high + low2; public low2Hot := OneHot<4>(low2); public highZero := IsZero(high); public isZero := Reg(highZero * low2Hot[0]); diff --git a/zirgen/circuit/rv32im/v2/dsl/inst_p2.zir b/zirgen/circuit/rv32im/v2/dsl/inst_p2.zir index 0a8e05da..637b69b0 100644 --- a/zirgen/circuit/rv32im/v2/dsl/inst_p2.zir +++ b/zirgen/circuit/rv32im/v2/dsl/inst_p2.zir @@ -276,12 +276,32 @@ component PoseidonCheckOut(cycle: Reg, prev: PoseidonState) { PoseidonState(GetDef(prev), nextState, 0, 0, 0, prev.mode, prev.inner, MakeExt(0)) } +component FieldToWord(val: Val) { + // Decompose a field element into two u16s + public low := NondetU16Reg(val & 0xffff); + public high := U16Reg((val - low) / 65536); + // Check decomposition is unique + // If low == 0, high must be < 30720, otherwise high must be <= 30719 + // Guess if low is zero + lowIsZero := NondetBitReg(Isz(low)); + // Now check results: Technically, prover could set low-is-zero to false even if + // low was zero, but this only results in a stricter check of high, so it's pointless + if (lowIsZero) { + low = 0; + U16Reg(30720 - high); + } else { + U16Reg(30719 - high); + }; + // Return as u32 + public ret := ValU32(low, high); +} + component PoseidonStoreOut(cycle: Reg, prev: PoseidonState) { for i : 0..8 { - val := prev.inner[i]; - low := NondetU16Reg(val & 0xffff); - high := U16Reg((val - low) / 65536); - MemoryWrite(cycle, prev.bufOutAddr + i, ValU32(low, high)); + ftw := FieldToWord(prev.inner[i]); + mw := MemoryWrite(cycle, prev.bufOutAddr + i, ftw.ret); + AliasLayout!(mw.io.newTxn.dataLow, ftw.low.arg.val); + AliasLayout!(mw.io.newTxn.dataHigh, ftw.high.arg.val); }; isNormal := IsZero(prev.loadTxType - TxKindRead()); outState := isNormal * StateDecode() + (1 - isNormal) * StatePoseidonPaging(); @@ -298,10 +318,10 @@ component PoseidonDoOut(cycle: Reg, prev: PoseidonState) { component PoseidonStoreState(cycle: Reg, prev: PoseidonState) { for i : 0..8 { - val := prev.inner[16 + i]; - low := NondetU16Reg(val & 0xffff); - high := U16Reg((val - low) / 65536); - MemoryWrite(cycle, prev.stateAddr+ i, ValU32(low, high)); + ftw := FieldToWord(prev.inner[16 + i]); + mw := MemoryWrite(cycle, prev.stateAddr+ i, ftw.ret); + AliasLayout!(mw.io.newTxn.dataLow, ftw.low.arg.val); + AliasLayout!(mw.io.newTxn.dataHigh, ftw.high.arg.val); }; PoseidonState(GetDef(prev), StateDecode(), 0, 0, 0, prev.mode, prev.inner, MakeExt(0)) } diff --git a/zirgen/circuit/rv32im/v2/dsl/lookups.zir b/zirgen/circuit/rv32im/v2/dsl/lookups.zir index 7f5065fd..98c09abd 100644 --- a/zirgen/circuit/rv32im/v2/dsl/lookups.zir +++ b/zirgen/circuit/rv32im/v2/dsl/lookups.zir @@ -21,7 +21,7 @@ component NondetU8Reg(val: Val) { component U8Reg(val: Val) { ret := NondetU8Reg(val); ret = val; - val + ret } argument ArgU16(count: Val, val: Val) { @@ -32,7 +32,7 @@ argument ArgU16(count: Val, val: Val) { // Set a register nodeterministically, and then verify it is a U16 component NondetU16Reg(val: Val) { - arg := ArgU16(1, val); + public arg := ArgU16(1, val); arg.count = 1; arg.val } @@ -41,7 +41,7 @@ component NondetU16Reg(val: Val) { component U16Reg(val: Val) { ret := NondetU16Reg(val); ret = val; - val + ret } // TODO: Tests diff --git a/zirgen/circuit/rv32im/v2/dsl/mem.zir b/zirgen/circuit/rv32im/v2/dsl/mem.zir index b9729a8f..a5b07963 100644 --- a/zirgen/circuit/rv32im/v2/dsl/mem.zir +++ b/zirgen/circuit/rv32im/v2/dsl/mem.zir @@ -93,7 +93,7 @@ component MemoryRead(cycle: Reg, addr: Val) { // A normal memory write component MemoryWrite(cycle: Reg, addr: Val, data: ValU32) { - io := MemoryIO(2*cycle + 1, addr); + public io := MemoryIO(2*cycle + 1, addr); IsForward(io); io.newTxn.dataLow = data.low; io.newTxn.dataHigh = data.high;