diff --git a/zirgen/circuit/rv32im/v2/dsl/arr.zir b/zirgen/circuit/rv32im/v2/dsl/arr.zir new file mode 100644 index 00000000..e8dae002 --- /dev/null +++ b/zirgen/circuit/rv32im/v2/dsl/arr.zir @@ -0,0 +1,45 @@ +// RUN: zirgen --test %s + +// Vector / List functions + +// Shifts + Rotates +component RotateLeft(in: Array, n: Val) { + for i : 0..SIZE { + if (InRange(0, i - n, SIZE)) { in[i - n] } else { in[SIZE + i - n] } + } +} + +component RotateRight(in: Array, n: Val) { + for i : 0..SIZE { + if (InRange(0, i + n, SIZE)) { in[i + n] } else { in[i + n - SIZE] } + } +} + +component ShiftLeft(in: Array, n: Val) { + for i : 0..SIZE { + if (InRange(0, i - n, SIZE)) { in[i - n] } else { 0 } + } +} + +component ShiftRight(in: Array, n: Val) { + for i : 0..SIZE { + if (InRange(0, i + n, SIZE)) { in[i + n] } else { 0 } + } +} + +component EqArr(a: Array, b: Array) { + for i : 0..SIZE { + a[i] = b[i]; + } +} + +// Tests.... + +test ShiftAndRotate { + // Remember: array entry 0 is the low bit, so there seem backwards + EqArr<8>(ShiftRight<8>([3, 1, 5, 0, 2, 0, 0, 0], 2), [5, 0, 2, 0, 0, 0, 0, 0]); + EqArr<8>(ShiftLeft<8>([1, 4, 2, 0, 6, 0, 0, 0], 2), [0, 0, 1, 4, 2, 0, 6, 0]); + EqArr<8>(RotateRight<8>([7, 6, 1, 0, 2, 0, 0, 0], 2), [1, 0, 2, 0, 0, 0, 7, 6]); + EqArr<8>(RotateLeft<8>([4, 5, 1, 0, 1, 0, 0, 3], 2), [0, 3, 4, 5, 1, 0, 1, 0]); +} + diff --git a/zirgen/circuit/rv32im/v2/dsl/bits.zir b/zirgen/circuit/rv32im/v2/dsl/bits.zir index 60babe42..dcf9231d 100644 --- a/zirgen/circuit/rv32im/v2/dsl/bits.zir +++ b/zirgen/circuit/rv32im/v2/dsl/bits.zir @@ -35,12 +35,17 @@ function AssertTwit(val: Val) { val * (1 - val) * (2 - val) * (3 - val) = 0; } +// Simple bit ops component BitAnd(a: Val, b: Val) { - Reg(a * b) + a * b } component BitOr(a: Val, b: Val) { - Reg(1 - (1 - a) * (1 - b)) + a + b - a * b +} + +component BitXor(a: Val, b: Val) { + a + b - 2 * a * b } // Set a register nodeterministically, and then verify it is a twit @@ -81,4 +86,3 @@ test TwitInRange{ test_fails TwitOutOfRange { AssertTwit(4); } - diff --git a/zirgen/circuit/rv32im/v2/dsl/consts.zir b/zirgen/circuit/rv32im/v2/dsl/consts.zir index 4d78ba6e..be1aedfd 100644 --- a/zirgen/circuit/rv32im/v2/dsl/consts.zir +++ b/zirgen/circuit/rv32im/v2/dsl/consts.zir @@ -39,11 +39,18 @@ component StatePoseidonStoreState() { 23 } component StatePoseidonExtRound() { 24 } component StatePoseidonIntRounds() { 25 } -component StateDecode() { 32 } +component StateShaEcall() { 32 } +component StateShaLoadState() { 33 } +component StateShaLoadData() { 34 } +component StateShaMix() { 35 } +component StateShaStoreState() { 36 } + +component StateDecode() { 40 } component RegA0() { 10 } component RegA1() { 11 } component RegA2() { 12 } component RegA3() { 13 } +component RegA4() { 14 } component RegA7() { 17 } diff --git a/zirgen/circuit/rv32im/v2/dsl/inst_div.zir b/zirgen/circuit/rv32im/v2/dsl/inst_div.zir index 5f0cb9a1..c71a70c6 100644 --- a/zirgen/circuit/rv32im/v2/dsl/inst_div.zir +++ b/zirgen/circuit/rv32im/v2/dsl/inst_div.zir @@ -59,19 +59,36 @@ component DoDiv(numer: ValU32, denom: ValU32, signed: Val, ones_comp: Val) { settings := MultiplySettings(signed, signed, signed); // Do the accumulate mul := MultiplyAccumulate(quot, denom, rem, settings); - // Check the main result (numer = quot * denom + rem + // Check the main result (numer = quot * denom + rem) AssertEqU32(mul.outLow, numer); // The top bits should all be 0 or all be 1 topBitType := NondetBitReg(1 - Isz(mul.outHigh.low)); AssertEqU32(mul.outHigh, ValU32(0xffff * topBitType, 0xffff * topBitType)); // Check if denom is zero isZero := IsZero(denom.low + denom.high); + // Get top bit of numerator + topNum := NondetBitReg((numer.high & 0x8000) / 0x8000); + // Verify we got it right + U16Reg((numer.high - 0x8000 * topNum) * 2); + numNeg := topNum * signed; + // Get the absolute value of the denominator + denomNeg := mul.bNeg; + denomAbs := NormalizeU32(DenormedValU32( + denomNeg * (0x10000 - denom.low) + (1 - denomNeg) * denom.low, + denomNeg * (0xffff - denom.high) + (1 - denomNeg) * denom.high + )); + // Flip the sign of the remainder if the numerator is negative + remNormal := NormalizeU32(DenormedValU32( + numNeg * (0x10000 - rem.low) + (1 - numNeg) * rem.low, + numNeg * (0xffff - rem.high) + (1 - numNeg) * rem.high + )); + // Decide if we need to swap order of // If non-zero, make sure 0 <= rem < denom if (isZero) { AssertEqU32(rem, numer); } else { - cmp := CmpLessThanUnsigned(rem, denom); - cmp.is_less_than = 1; + lt := CmpLessThanUnsigned(remNormal, denomAbs); + lt.is_less_than = 1; }; DivideReturn(quot, rem) } diff --git a/zirgen/circuit/rv32im/v2/dsl/inst_ecall.zir b/zirgen/circuit/rv32im/v2/dsl/inst_ecall.zir index 053f7753..54ce2df7 100644 --- a/zirgen/circuit/rv32im/v2/dsl/inst_ecall.zir +++ b/zirgen/circuit/rv32im/v2/dsl/inst_ecall.zir @@ -3,7 +3,6 @@ import inst; import consts; - // Prepare to read a certain length, maybe return a smaller one extern HostReadPrepare(fd: Val, len: Val): Val; @@ -30,12 +29,13 @@ component MachineECall(cycle: Reg, input: InstInput, pc_addr: Val) { input.mode = 1; dispatch_idx := MemoryRead(cycle, MachineRegBase() + RegA7()); dispatch_idx.high = 0; - dispatch := OneHot<4>(dispatch_idx.low); + dispatch := OneHot<5>(dispatch_idx.low); state := dispatch -> ( StateTerminate(), StateHostReadSetup(), StateHostWrite(), - StatePoseidonEcall() + StatePoseidonEcall(), + StateShaEcall() ); ECallOutput(state, 0, 0, 0) } @@ -86,8 +86,8 @@ component ECallHostReadSetup(cycle: Reg, input: InstInput) { lenDecomp := DecomposeLow2(newLen); // Check if length is exactly 1, 2, or 3 len123 := Reg(lenDecomp.highZero * lenDecomp.low2Nonzero); - // Check if things are 'uneven' - uneven := Reg(len123 * ptrDecomp.low2Nonzero); + // Check if things are 'uneven' (this is an 'or') + uneven := Reg(len123 + ptrDecomp.low2Nonzero - len123 * ptrDecomp.low2Nonzero); // Now pick the next cycle nextCycle := // If length == 0, go back to decoding @@ -117,11 +117,54 @@ component ECallHostWrite(cycle: Reg, input: InstInput) { ECallOutput(StateDecode(), 0, 0, 0) } -component ECallHostReadBytes(cycle: Reg, input: InstInput) { - // TODO +component ECallHostReadBytes(cycle: Reg, input: InstInput, ptrWord: Val, ptrLow2: Val, len: Val) { input.state = StateHostReadBytes(); - 0 = 1; - ECallOutput(16, 0, 0, 0) + // Decompose next len + lenDecomp := DecomposeLow2(len - 1); + // Check if length is exactly 1, 2, or 3 + len123 := Reg(lenDecomp.highZero * lenDecomp.low2Nonzero); + // Check is next pointer is even (this can only happen if Low2 == 3) + nextPtrEven := IsZero(ptrLow2 - 3); + nextPtrUneven := 1 - nextPtrEven; + nextPtrWord := nextPtrEven * (ptrWord + 1) + nextPtrUneven * ptrWord; + nextPtrLow2 := nextPtrUneven * (ptrLow2 + 1); + // Check if things are 'uneven' (this is an 'or') + uneven := Reg(len123 + nextPtrUneven - len123 * nextPtrUneven); + // Check is length is exactly zero + lenZero := IsZero(len - 1); + // Split low bits into parts + low0 := NondetBitReg(ptrLow2 & 1); + low1 := BitReg((ptrLow2 - low0) / 2); + // Load the original word + origWord := MemoryRead(cycle, ptrWord); + // Write the answer + io := MemoryWriteUnconstrained(cycle, ptrWord).io; + // Make the non-specified half matches + if (low1) { + origWord.low = io.newTxn.dataLow; + } else { + origWord.high = io.newTxn.dataHigh; + }; + // Get the half that changed + oldHalf := low1 * origWord.high + (1 - low1) * origWord.low; + newHalf := low1 * io.newTxn.dataHigh + (1 - low1) * io.newTxn.dataLow; + // Split both into bytes + oldBytes := SplitWord(oldHalf); + newBytes := SplitWord(newHalf); + // Make sure the non-specified bytes matchs + if (low0) { + oldBytes.byte0 = newBytes.byte0; + } else { + oldBytes.byte1 = newBytes.byte1; + }; + nextCycle := + // If length == 0, go back to decoding + lenZero * StateDecode() + + // If length != 0 and uneven, do bytes + (1 - lenZero) * uneven * StateHostReadBytes() + + // If length != 0 and even, more words + (1 - lenZero) * (1 - uneven) * StateHostReadWords(); + ECallOutput(nextCycle, nextPtrWord, nextPtrLow2, len - 1) } component ECallHostReadWords(cycle: Reg, input: InstInput, ptrWord: Val, len: Val) { @@ -129,9 +172,9 @@ component ECallHostReadWords(cycle: Reg, input: InstInput, ptrWord: Val, len: Va lenDecomp := DecomposeLow2(len); wordsDecomp := DecomposeLow2(lenDecomp.high); doWord := [ - wordsDecomp.low2Hot[1] * wordsDecomp.highZero, - wordsDecomp.low2Hot[2] * wordsDecomp.highZero, - wordsDecomp.low2Hot[3] * wordsDecomp.highZero, + (wordsDecomp.low2Hot[1] + wordsDecomp.low2Hot[2] + wordsDecomp.low2Hot[3]) * wordsDecomp.highZero + (1 - wordsDecomp.highZero), + (wordsDecomp.low2Hot[2] + wordsDecomp.low2Hot[3])* wordsDecomp.highZero + (1 - wordsDecomp.highZero), + (wordsDecomp.low2Hot[3]) * wordsDecomp.highZero + (1 - wordsDecomp.highZero), (1 - wordsDecomp.highZero) ]; count := reduce doWord init 0 with Add; @@ -139,14 +182,15 @@ component ECallHostReadWords(cycle: Reg, input: InstInput, ptrWord: Val, len: Va addr := Reg(doWord[i] * (ptrWord + i) + (1 - doWord[i]) * SafeWriteWord()); MemoryWriteUnconstrained(cycle, addr); }; - lenZero := IsZero(len - 4 * count); + newLenHighZero := IsZero(lenDecomp.high - count); + lenZero := Reg(newLenHighZero * (1 - lenDecomp.low2Nonzero)); nextCycle := // If length == 0, go back to decoding lenZero * StateDecode() + // If length != 0 and uneven, do bytes - (1 - lenZero) * (lenDecomp.low2Nonzero) * StateHostReadBytes() + - // If lengtj != 0 and even, more words - (1 - lenZero) * (1 - lenDecomp.low2Nonzero) * StateHostReadWords(); + (1 - lenZero) * newLenHighZero * StateHostReadBytes() + + // If length != 0 and even, more words + (1 - lenZero) * (1 - newLenHighZero) * StateHostReadWords(); ECallOutput(nextCycle, ptrWord + count, 0, len - count * 4) } @@ -162,7 +206,7 @@ component ECall0(cycle: Reg, inst_input: InstInput) { ECallTerminate(cycle, inst_input), ECallHostReadSetup(cycle, inst_input), ECallHostWrite(cycle, inst_input), - ECallHostReadBytes(cycle, inst_input), + ECallHostReadBytes(cycle, inst_input, s0@1, s1@1, s2@1), ECallHostReadWords(cycle, inst_input, s0@1, s2@1), IllegalECall(), IllegalECall() @@ -172,6 +216,7 @@ component ECall0(cycle: Reg, inst_input: InstInput) { s2 := Reg(output.s2); isDecode := IsZero(output.state - StateDecode()); isP2Entry := IsZero(output.state - StatePoseidonEcall()); - addPC := NormalizeU32(AddU32(inst_input.pc_u32, ValU32((isDecode + isP2Entry) * 4, 0))); + isShaEcall := IsZero(output.state - StateShaEcall()); + addPC := NormalizeU32(AddU32(inst_input.pc_u32, ValU32((isDecode + isP2Entry + isShaEcall) * 4, 0))); InstOutput(addPC, output.state, 1) } diff --git a/zirgen/circuit/rv32im/v2/dsl/inst_p2.zir b/zirgen/circuit/rv32im/v2/dsl/inst_p2.zir index 637b69b0..46b23108 100644 --- a/zirgen/circuit/rv32im/v2/dsl/inst_p2.zir +++ b/zirgen/circuit/rv32im/v2/dsl/inst_p2.zir @@ -464,6 +464,7 @@ component PoseidonPaging(cycle: Reg, mode: Val, prev: PoseidonState) { component Poseidon0(cycle:Reg, inst_input: InstInput) { DoCycleTable(cycle); + inst_input.state = StatePoseidonEntry() + inst_input.minor; state : PoseidonState; state := inst_input.minor_onehot -> ( PoseidonEntry(cycle, inst_input.pc_u32, inst_input.mode), @@ -480,6 +481,7 @@ component Poseidon0(cycle:Reg, inst_input: InstInput) { component Poseidon1(cycle:Reg, inst_input: InstInput) { DoCycleTable(cycle); + inst_input.state = StatePoseidonExtRound() + inst_input.minor; state : PoseidonState; state := inst_input.minor_onehot -> ( PoseidonExtRound(state@1), diff --git a/zirgen/circuit/rv32im/v2/dsl/inst_sha.zir b/zirgen/circuit/rv32im/v2/dsl/inst_sha.zir new file mode 100644 index 00000000..7078ac50 --- /dev/null +++ b/zirgen/circuit/rv32im/v2/dsl/inst_sha.zir @@ -0,0 +1,240 @@ +// RUN: true + +import consts; +import inst; +import inst_p2; // To get ReadAddr, maybe should move that somewhere else +import sha2; + +component ShaState( + a: Array, + e: Array, + w: Array, + stateInAddr: Val, + stateOutAddr: Val, + dataAddr: Val, + count: Val, + kAddr: Val, + round: Val, + nextState: Val) +{ + public stateInAddr := Reg(stateInAddr); + public stateOutAddr := Reg(stateOutAddr); + public dataAddr := Reg(dataAddr); + public count := Reg(count); + public kAddr := Reg(kAddr); + public round := Reg(round); + public nextState := Reg(nextState); + public a := for b : a { NondetReg(b) }; + public e := for b : e { NondetReg(b) }; + public w := for b : w { NondetReg(b) }; +} + +component ShaInvalid() { + 0 = 1; + ShaState( + for i : 0..32 { 0 }, + for i : 0..32 { 0 }, + for i : 0..32 { 0 }, + 0, 0, 0, 0, 0, 0, + StateDecode() + ) +} + +component ShaEcall(cycle: Reg) { + Log("SHA ECALL"); + // Load values from registers + stateInAddr := ReadAddr(cycle, RegA0()); + stateOutAddr := ReadAddr(cycle, RegA1()); + dataAddr := ReadAddr(cycle, RegA2()); + Log("Data Addr: ", dataAddr); + count := MemoryRead(cycle, MachineRegBase() + RegA3()).low; + kAddr := ReadAddr(cycle, RegA4()); + ShaState( + for i : 0..32 { 0 }, + for i : 0..32 { 0 }, + for i : 0..32 { 0 }, + stateInAddr, + stateOutAddr, + dataAddr, + count, + kAddr, + 0, + StateShaLoadState() + ) +} + +component UnpackU32NondetLE(val: ValU32) { + UnpackNondet<32, 16>([val.low, val.high]); +} + +component UnpackU32NondetBE(val: ValU32) { + unpacked := UnpackNondet<32, 16>([val.low, val.high]); + for o : 0..32 { + j := o & 7; + i := (o - j) / 8; + unpacked[(3 - i) * 8 + j] + } +} + +component VerifyUnpackU32LE(unpacked: Array, orig: ValU32) { + packed := Pack<32, 16>(unpacked); + for i : 0..32 { AssertBit(unpacked[i]); }; + orig.low = packed[0]; + orig.high = packed[1]; +} + +component VerifyUnpackU32BE(unpacked: Array, orig: ValU32) { + packed := Pack<32, 8>(unpacked); + for i : 0..32 { AssertBit(unpacked[i]); }; + orig.low = packed[2] * 256 + packed[3]; + orig.high = packed[0] * 256 + packed[1]; +} + +component BitsToBE(unpacked: Array) { + packed := Pack<32, 8>(unpacked); + ValU32(packed[2] * 256 + packed[3], packed[0] * 256 + packed[1]) +} + +component ShaLoadState(cycle: Reg, prev: ShaState) { + lastRound := IsZero(3 - prev.round); + countZero := IsZero(prev.count); + nextState := if (lastRound) { + if (countZero) { + StateDecode() + } else { + StateShaLoadData() + } + } else { + StateShaLoadState() + }; + a32 := MemoryRead(cycle, prev.stateInAddr + 3 - prev.round); + e32 := MemoryRead(cycle, prev.stateInAddr + 7 - prev.round); + MemoryWrite(cycle, prev.stateOutAddr + 3 - prev.round, a32); + MemoryWrite(cycle, prev.stateOutAddr + 7 - prev.round, e32); + out := ShaState( + UnpackU32NondetBE(a32), + UnpackU32NondetBE(e32), + for i : 0..32 { 0 }, + prev.stateInAddr, + prev.stateOutAddr, + prev.dataAddr, + prev.count, + prev.kAddr, + (1 - lastRound) * (prev.round + 1), + nextState + ); + VerifyUnpackU32BE(out.a, a32); + VerifyUnpackU32BE(out.e, e32); + for i : 0..32 { out.w[i] = 0; }; + out +} + +component ShaLoadData(cycle: Reg, prev: ShaState, p2: ShaState, p3: ShaState, p4: ShaState) { + lastRound := IsZero(15 - prev.round); + k := MemoryRead(cycle, prev.kAddr + prev.round); + wMem := MemoryRead(cycle, prev.dataAddr); + wNondet := UnpackU32NondetBE(wMem); + wBits := for i : 0..32 { NondetReg(wNondet[i]) }; + VerifyUnpackU32BE(wBits, wMem); + ae := ComputeAE([prev.a, p2.a, p3.a, p4.a], [prev.e, p2.e, p3.e, p4.e], wBits, [k.low, k.high]); + Log("a = ", ae.rawA[0], ae.rawA[1]); + Log("e = ", ae.rawE[0], ae.rawE[1]); + a := CarryAndExpand(ae.rawA); + e := CarryAndExpand(ae.rawE); + out := ShaState( + a, + e, + wBits, + prev.stateInAddr, + prev.stateOutAddr, + prev.dataAddr + 1, + prev.count, + prev.kAddr, + (1 - lastRound) * (prev.round + 1), + lastRound * StateShaMix() + (1 - lastRound) * StateShaLoadData() + ); + AliasLayout!(a, out.a); + AliasLayout!(e, out.e); + AliasLayout!(wBits, out.w); + out +} + +component ShaMix(cycle: Reg, prev: ShaState, p2: ShaState, p3: ShaState, p4: ShaState, p7: ShaState, p15: ShaState, p16: ShaState) { + lastRound := IsZero(47 - prev.round); + k := MemoryRead(cycle, prev.kAddr + 16 + prev.round); + wRaw := ComputeWBack(p2.w, p7.w, p15.w, p16.w); + wBits := CarryAndExpand(wRaw); + ae := ComputeAE([prev.a, p2.a, p3.a, p4.a], [prev.e, p2.e, p3.e, p4.e], wBits, [k.low, k.high]); + Log("a = ", ae.rawA[0], ae.rawA[1]); + Log("e = ", ae.rawE[0], ae.rawE[1]); + a := CarryAndExpand(ae.rawA); + e := CarryAndExpand(ae.rawE); + out := ShaState( + a, + e, + wBits, + prev.stateInAddr, + prev.stateOutAddr, + prev.dataAddr, + prev.count, + prev.kAddr, + (1 - lastRound) * (prev.round + 1), + lastRound * StateShaStoreState() + (1 - lastRound) * StateShaMix() + ); + AliasLayout!(a, out.a); + AliasLayout!(e, out.e); + AliasLayout!(wBits, out.w); + out +} + +component ShaStoreState(cycle: Reg, prev: ShaState, p4: ShaState, p68: ShaState) { + lastRound := IsZero(3 - prev.round); + newCount := prev.count - lastRound; + countZero := IsZero(newCount); + nextState := if (countZero) { + StateDecode() + } else { + if (lastRound) { + StateShaLoadData() + } else { + StateShaStoreState() + } + }; + a := CarryAndExpand(Add2(Pack32(p4.a), Pack32(p68.a))); + e := CarryAndExpand(Add2(Pack32(p4.e), Pack32(p68.e))); + out := ShaState( + a, + e, + for i : 0..32 { 0 }, + prev.stateInAddr, + prev.stateOutAddr, + prev.dataAddr, + newCount, + prev.kAddr, + (1 - lastRound) * (prev.round + 1), + nextState + ); + AliasLayout!(a, out.a); + AliasLayout!(e, out.e); + MemoryWrite(cycle, prev.stateOutAddr + 3 - prev.round, BitsToBE(a)); + MemoryWrite(cycle, prev.stateOutAddr + 7 - prev.round, BitsToBE(e)); + out +} + +component Sha0(cycle:Reg, inst_input: InstInput) { + DoCycleTable(cycle); + inst_input.state = StateShaEcall() + inst_input.minor; + state : ShaState; + state := inst_input.minor_onehot -> ( + ShaEcall(cycle), + ShaLoadState(cycle, state@1), // 4 cycles, load A/E from state input + ShaLoadData(cycle, state@1, state@2, state@3, state@4), // 16 cycles, load data + ShaMix(cycle, state@1, state@2, state@3, state@4, state@7, state@15, state@16), // 48 cycles, do internal mixing + ShaStoreState(cycle, state@1, state@4, state@68), // 4 cycles, update state (in place) + ShaInvalid(), + ShaInvalid(), + ShaInvalid() + ); + InstOutput(inst_input.pc_u32, state.nextState, inst_input.mode) +} + diff --git a/zirgen/circuit/rv32im/v2/dsl/mem.zir b/zirgen/circuit/rv32im/v2/dsl/mem.zir index a5b07963..e0923b9e 100644 --- a/zirgen/circuit/rv32im/v2/dsl/mem.zir +++ b/zirgen/circuit/rv32im/v2/dsl/mem.zir @@ -101,7 +101,7 @@ component MemoryWrite(cycle: Reg, addr: Val, data: ValU32) { // Let the host write anythings (used in host read words) component MemoryWriteUnconstrained(cycle: Reg, addr: Val) { - io := MemoryIO(2*cycle + 1, addr); + public io := MemoryIO(2*cycle + 1, addr); IsForward(io); } diff --git a/zirgen/circuit/rv32im/v2/dsl/mult.zir b/zirgen/circuit/rv32im/v2/dsl/mult.zir index eed61631..da6ba6eb 100644 --- a/zirgen/circuit/rv32im/v2/dsl/mult.zir +++ b/zirgen/circuit/rv32im/v2/dsl/mult.zir @@ -151,6 +151,8 @@ component MultiplyAccumulate(a: ValU32, b: ValU32, c: ValU32, settings: Multiply s3Carry := FakeTwitReg((s3Tot - s3Out) / 0x10000); public outLow := ValU32(s0.out, s1.out); public outHigh := ValU32(s2.out, s3Out); + public aNeg := ax.neg; + public bNeg := bx.neg; } component MultiplyTestCase(a: ValU32, b: ValU32, c: ValU32, settings: MultiplySettings, ol: ValU32, oh: ValU32) { diff --git a/zirgen/circuit/rv32im/v2/dsl/pack.zir b/zirgen/circuit/rv32im/v2/dsl/pack.zir new file mode 100644 index 00000000..3165419c --- /dev/null +++ b/zirgen/circuit/rv32im/v2/dsl/pack.zir @@ -0,0 +1,45 @@ +// RUN: zirgen -I %S --test %s +// Bit packing and unpacking logic + +import bits; +import arr; +import u32; + +// We have 3 functions here: +// 1) Pack an array of N bits into N/P elements +// 2) Unpack an array of N bits from N/P elements *without* verifying +// 3) Same as above, but registerize and verify +// We don't handle 'uneven' packings, and rely on external code to make +// sure the 'parts' eveny divide the whole. + +// Pack N bits into parts of P bits each +component Pack(in : Array) { + N % P = 0; + for i : 0..(N / P) { + reduce for j : 0..P { Po2(j) * in[i*P+ j] } init 0 with Add + } +} + +component UnpackNondet(in: Array) { + N % P = 0; + inv := Inv(P); + for n : 0..N { + j := n % P; + i := (n - j) * inv; + (in[i] & Po2(j)) / Po2(j) + } +} + +component UnpackReg(in: Array) { + bitVals := UnpackNondet(in); + bits := for n : 0..N { NondetBitReg(bitVals[n]) }; + EqArr(Pack(bits), in); + bits +} + +test PackUnpack { + bits := UnpackReg<16, 4>([1, 15, 5, 10]); + EqArr<16>(bits, [1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1]); + oneVal := Pack<16, 16>(bits); + oneVal[0] = 0xa5f1; +} diff --git a/zirgen/circuit/rv32im/v2/dsl/po2.zir b/zirgen/circuit/rv32im/v2/dsl/po2.zir index bb144b0c..cf3b6029 100644 --- a/zirgen/circuit/rv32im/v2/dsl/po2.zir +++ b/zirgen/circuit/rv32im/v2/dsl/po2.zir @@ -4,25 +4,16 @@ import lookups; -// TODO: This is a lame workaround to the fact that map/reduce has issues with constants +// The max we can pack into one BB element is 30 bits component Po2(n: Val) { - arr := [ - 0x1, - 0x2, - 0x4, - 0x8, - 0x10, - 0x20, - 0x40, - 0x80, - 0x100, - 0x200, - 0x400, - 0x800, - 0x1000, - 0x2000, - 0x4000, - 0x8000 + arr := [ 0x00000001, 0x00000002, 0x00000004, 0x00000008, + 0x00000010, 0x00000020, 0x00000040, 0x00000080, + 0x00000100, 0x00000200, 0x00000400, 0x00000800, + 0x00001000, 0x00002000, 0x00004000, 0x00008000, + 0x00010000, 0x00020000, 0x00040000, 0x00080000, + 0x00100000, 0x00200000, 0x00400000, 0x00800000, + 0x01000000, 0x02000000, 0x04000000, 0x08000000, + 0x10000000, 0x20000000 ]; arr[n] } diff --git a/zirgen/circuit/rv32im/v2/dsl/sha2.zir b/zirgen/circuit/rv32im/v2/dsl/sha2.zir new file mode 100644 index 00000000..dba762d6 --- /dev/null +++ b/zirgen/circuit/rv32im/v2/dsl/sha2.zir @@ -0,0 +1,182 @@ +// RUN: zirgen -I %S --test %s + +import arr; +import bits; +import one_hot; +import pack; + +component XorU32(a: Array, b: Array) { + for i : 0..32 { + BitXor(a[i], b[i]) + } +} + +component MajU32(a: Array, b: Array, c: Array) { + for i : 0..32 { + a[i] * b[i] * (1 - c[i]) + + a[i] * (1 - b[i]) * c[i] + + (1 - a[i]) * b[i] * c[i] + + a[i] * b[i] * c[i] + } +} + +component ChU32(a: Array, b: Array, c: Array) { + for i : 0..32 { + a[i] * b[i] + (1 - a[i]) * c[i] + } +} + +component Add2(a: Array, b: Array) { + for i : 0..2 { a[i] + b[i] } +} + +component CarryExtract(in: Val) { + bit0 := NondetBitReg(((in & 0xf0000) / 0x10000) & 1); + bit1 := NondetBitReg((((in & 0xf0000) / 0x10000) & 2) / 2); + bit2 := NondetBitReg((((in & 0xf0000) / 0x10000) & 4) / 4); + public carry := bit2 * 4 + bit1 * 2 + bit0; + public out := in - carry * 0x10000; +} + +component CarryAndExpand(in: Array) { + lowCarry := CarryExtract(in[0]); + highCarry := CarryExtract(in[1] + lowCarry.carry); + out := UnpackReg<32, 16>([lowCarry.out, highCarry.out]); + out +} + +component ExpandBE(in: Array) { + original := UnpackNondet<32, 16>(in); + // Swap bytes, maintain bits, registerize + out := for i : 0..32 { + bit := i & 7; + byte := (i - bit) * Inv(8); + NondetBitReg(original[(3 - byte) * 8 + bit]) + }; + // 'Pack' into bytes + bytes := Pack<32, 8>(out); + // Verify byteswapping it back matches + in[0] = bytes[2] * 256 + bytes[3]; + in[1] = bytes[0] * 256 + bytes[1]; + // Return output + out +} + +component PushFront(in: Array, lst: Array, SIZE>) { + for i : 0..SIZE { + if (Isz(i)) { in } + else { lst[i - 1] } + } +} + +component Pack32(x: Array) { Pack<32, 16>(x) } + +// Given old Ws, produce new 'raw' W +component ComputeW(ow : Array, 16>) { + s0 := XorU32(RotateRight<32>(ow[14], 7), XorU32(RotateRight<32>(ow[14], 18), ShiftRight<32>(ow[14], 3))); + s1 := XorU32(RotateRight<32>(ow[1], 17), XorU32(RotateRight<32>(ow[1], 19), ShiftRight<32>(ow[1], 10))); + rawW := Add2(Pack32(s0), Add2(Pack32(s1), Add2(Pack32(ow[15]), Pack32(ow[6])))); + rawW +} + +// Same as above, but compute directly from specific 'backs' +component ComputeWBack(w2 : Array, w7 : Array, w15 : Array, w16: Array) { + s0 := XorU32(RotateRight<32>(w15, 7), XorU32(RotateRight<32>(w15, 18), ShiftRight<32>(w15, 3))); + s1 := XorU32(RotateRight<32>(w2, 17), XorU32(RotateRight<32>(w2, 19), ShiftRight<32>(w2, 10))); + rawW := Add2(Pack32(s0), Add2(Pack32(s1), Add2(Pack32(w16), Pack32(w7)))); + rawW +} +// component PX(x: Array) { Pack<32, 16>(x)[0] } + +// Given old A/E, new W, and k, produce new A/E +component ComputeAE(oa: Array, 4>, oe: Array, 4>, w: Array, k: Array) { + a := oa[0]; b := oa[1]; c := oa[2]; d := oa[3]; + e := oe[0]; f := oe[1]; g := oe[2]; h := oe[3]; + // Log("----", PX(a), PX(b), PX(c), PX(d), PX(e), PX(f), PX(g), PX(h)); + s0 := XorU32(RotateRight<32>(a, 2), XorU32(RotateRight<32>(a, 13), RotateRight<32>(a, 22))); + s1 := XorU32(RotateRight<32>(e, 6), XorU32(RotateRight<32>(e, 11), RotateRight<32>(e, 25))); + stage1 := Add2(Pack32(w), Add2(k, Add2(Pack32(h), Add2(Pack32(ChU32(e, f, g)), Pack32(s1))))); + public rawA := Add2(stage1, Add2(Pack32(MajU32(a, b, c)), Pack32(s0))); + public rawE := Add2(stage1, Pack32(d)); +} + +component TableK() { + [ + [0x2f98, 0x428a], [0x4491, 0x7137], [0xfbcf, 0xb5c0], [0xdba5, 0xe9b5], + [0xc25b, 0x3956], [0x11f1, 0x59f1], [0x82a4, 0x923f], [0x5ed5, 0xab1c], + [0xaa98, 0xd807], [0x5b01, 0x1283], [0x85be, 0x2431], [0x7dc3, 0x550c], + [0x5d74, 0x72be], [0xb1fe, 0x80de], [0x06a7, 0x9bdc], [0xf174, 0xc19b], + [0x69c1, 0xe49b], [0x4786, 0xefbe], [0x9dc6, 0x0fc1], [0xa1cc, 0x240c], + [0x2c6f, 0x2de9], [0x84aa, 0x4a74], [0xa9dc, 0x5cb0], [0x88da, 0x76f9], + [0x5152, 0x983e], [0xc66d, 0xa831], [0x27c8, 0xb003], [0x7fc7, 0xbf59], + [0x0bf3, 0xc6e0], [0x9147, 0xd5a7], [0x6351, 0x06ca], [0x2967, 0x1429], + [0x0a85, 0x27b7], [0x2138, 0x2e1b], [0x6dfc, 0x4d2c], [0x0d13, 0x5338], + [0x7354, 0x650a], [0x0abb, 0x766a], [0xc92e, 0x81c2], [0x2c85, 0x9272], + [0xe8a1, 0xa2bf], [0x664b, 0xa81a], [0x8b70, 0xc24b], [0x51a3, 0xc76c], + [0xe819, 0xd192], [0x0624, 0xd699], [0x3585, 0xf40e], [0xa070, 0x106a], + [0xc116, 0x19a4], [0x6c08, 0x1e37], [0x774c, 0x2748], [0xbcb5, 0x34b0], + [0x0cb3, 0x391c], [0xaa4a, 0x4ed8], [0xca4f, 0x5b9c], [0x6ff3, 0x682e], + [0x82ee, 0x748f], [0x636f, 0x78a5], [0x7814, 0x84c8], [0x0208, 0x8cc7], + [0xfffa, 0x90be], [0x6ceb, 0xa450], [0xa3f7, 0xbef9], [0x78f2, 0xc671] + ] +} + +component InitA() { + [ + [0xe667, 0x6a09], + [0xae85, 0xbb67], + [0xf372, 0x3c6e], + [0xf53a, 0xa54f] + ] +} + +component InitE() { + [ + [0x527f, 0x510e], + [0x688c, 0x9b05], + [0xd9ab, 0x1f83], + [0xcd19, 0x5be0] + ] +} + +// A version of state used in testing SHA256 +component TestState(a: Array, 4>, e: Array, 4>, w: Array, 16>) { + public a := a; + public e := e; + public w := w; + flatA := Pack<32, 16>(a[0]); + flatE := Pack<32, 16>(e[0]); + Log("a = %x %x, e = %x %x", flatA[1], flatA[0], flatE[1], flatE[0]); +} + +component GetK(round: Val) { + oneHot := OneHot<64>(round); + table := TableK(); + for i : 0..2 { + reduce for j : 0..64 { oneHot[j] * table[j][i] } init 0 with Add + } +} + +component DoTestStepLoad(in: TestState, round: Val) { + comp := ComputeAE(in.a, in.e, in.w[round], GetK(round)); + a := CarryAndExpand(comp.rawA); + e := CarryAndExpand(comp.rawE); + TestState(PushFront<4>(a, in.a), PushFront<4>(e, in.e), in.w) +} + +component DoTestStepMix(in: TestState, round: Val) { + w := CarryAndExpand(ComputeW(in.w)); + comp := ComputeAE(in.a, in.e, w, GetK(16 + round)); + a := CarryAndExpand(comp.rawA); + e := CarryAndExpand(comp.rawE); + TestState(PushFront<4>(a, in.a), PushFront<4>(e, in.e), PushFront<16>(w, in.w)) +} + +test TestVector { + initState := TestState( + for i : 0..4 { UnpackReg<32, 16>(InitA()[i]) }, + for i : 0..4 { UnpackReg<32, 16>(InitE()[i]) }, + for i : 0..16 { for j : 0..32 { Reg(0) } }); + afterLoads := reduce 0..16 init initState with DoTestStepLoad; + finalState := reduce 0..48 init afterLoads with DoTestStepMix; +} diff --git a/zirgen/circuit/rv32im/v2/dsl/top.zir b/zirgen/circuit/rv32im/v2/dsl/top.zir index e9657663..4f91fc68 100644 --- a/zirgen/circuit/rv32im/v2/dsl/top.zir +++ b/zirgen/circuit/rv32im/v2/dsl/top.zir @@ -1,9 +1,5 @@ // RUN: true -// TODO: Now that the v2 circuit uses an extern to compute major/minor it no -// longer makes sense to do rv32im conformance testing here. Make sure -// integration tests are covering this. - import inst_div; import inst_misc; import inst_mul; @@ -11,6 +7,7 @@ import inst_mem; import inst_control; import inst_ecall; import inst_p2; +import inst_sha; import mem; import one_hot; @@ -70,7 +67,7 @@ component Top() { // Make a nice input to all the instructions inst_input := InstInput(major, minor, pc_u32, state, machine_mode); // Now we split on major - major_onehot := OneHot<11>(major); + major_onehot := OneHot<12>(major); inst_result := major_onehot ->! ( Misc0(cycle, inst_input), Misc1(cycle, inst_input), @@ -82,7 +79,8 @@ component Top() { Control0(cycle, inst_input), ECall0(cycle, inst_input), Poseidon0(cycle, inst_input), - Poseidon1(cycle, inst_input) + Poseidon1(cycle, inst_input), + Sha0(cycle, inst_input) ); // Compute next PC pc_word := inst_result.new_pc.low / 4 + inst_result.new_pc.high * 16384; diff --git a/zirgen/circuit/rv32im/v2/emu/BUILD.bazel b/zirgen/circuit/rv32im/v2/emu/BUILD.bazel index 11c45cf3..af28db8c 100644 --- a/zirgen/circuit/rv32im/v2/emu/BUILD.bazel +++ b/zirgen/circuit/rv32im/v2/emu/BUILD.bazel @@ -18,6 +18,7 @@ cc_library( "paging.h", "preflight.h", "r0vm.h", + "sha.h", "trace.h", ], deps = [ diff --git a/zirgen/circuit/rv32im/v2/emu/exec.cpp b/zirgen/circuit/rv32im/v2/emu/exec.cpp index af08dd49..7af2ec27 100644 --- a/zirgen/circuit/rv32im/v2/emu/exec.cpp +++ b/zirgen/circuit/rv32im/v2/emu/exec.cpp @@ -56,6 +56,12 @@ struct ExecContext { } physCycles++; } + void shaCycle(uint32_t cur, const ShaState& state) { + if (debug) { + std::cout << "sha: " << state.nextState << "\n"; + } + physCycles++; + } void trapRewind() {} void trap(TrapCause cause) {} diff --git a/zirgen/circuit/rv32im/v2/emu/preflight.cpp b/zirgen/circuit/rv32im/v2/emu/preflight.cpp index dd5e7c23..2f217ed6 100644 --- a/zirgen/circuit/rv32im/v2/emu/preflight.cpp +++ b/zirgen/circuit/rv32im/v2/emu/preflight.cpp @@ -135,6 +135,14 @@ struct PreflightContext { cycleCompleteSpecial(curState, p2.nextState, pc); physCycles++; } + void shaCycle(uint32_t curState, ShaState sha) { + if (debug) { + std::cout << trace.cycles.size() << " shaCycle\n"; + } + sha.write(trace.extra); + cycleCompleteSpecial(curState, sha.nextState, pc); + physCycles++; + } void trapRewind() { trace.txns.resize(memCycle); @@ -213,6 +221,7 @@ struct PreflightContext { } size_t rlen = segment.readRecord[curRead].size(); memcpy(data, segment.readRecord[curRead].data(), rlen); + curRead++; return rlen; } @@ -319,7 +328,7 @@ PreflightTrace preflightSegment(const Segment& in, size_t segmentSize) { // Now, go back and update memory transactions to wrap around for (auto& txn : ret.txns) { - if (txn.prevCycle == -1) { + if (static_cast(txn.prevCycle) == -1) { // If first cycle for word, set to 'prevCycle' to final cycle txn.prevCycle = preflightContext.prevCycle[txn.word]; } else { diff --git a/zirgen/circuit/rv32im/v2/emu/r0vm.h b/zirgen/circuit/rv32im/v2/emu/r0vm.h index 51d5fc0d..84bec4b8 100644 --- a/zirgen/circuit/rv32im/v2/emu/r0vm.h +++ b/zirgen/circuit/rv32im/v2/emu/r0vm.h @@ -26,6 +26,7 @@ #include "zirgen/compiler/zkp/poseidon2.h" #include "zirgen/circuit/rv32im/v2/emu/p2.h" +#include "zirgen/circuit/rv32im/v2/emu/sha.h" namespace zirgen::rv32im_v2 { @@ -172,18 +173,22 @@ template struct R0Context { std::vector bytes(len); rlen = context.read(fd, bytes.data(), len); storeReg(REG_A0, rlen); + uint32_t i = 0; if (rlen == 0) { context.pc += 4; } context.ecallCycle(curState, nextState(ptr, rlen), ptr / 4, ptr % 4, rlen); curState = nextState(ptr, rlen); - uint32_t i = 0; while (rlen > 0 && ptr % 4 != 0) { writeByte(ptr, bytes[i]); - // context.hostReadBytes(ptr); ptr++; i++; rlen--; + if (rlen == 0) { + context.pc += 4; + } + context.ecallCycle(curState, nextState(ptr, rlen), ptr / 4, ptr % 4, rlen); + curState = nextState(ptr, rlen); } while (rlen >= 4) { uint32_t words = std::min(rlen / 4, uint32_t(4)); @@ -194,12 +199,12 @@ template struct R0Context { word |= bytes[i + k] << (8 * k); } storeMem(ptr / 4, word); + ptr += 4; + i += 4; + rlen -= 4; } else { storeMem(SAFE_WRITE_WORD, 0); } - ptr += words; - i += words; - rlen -= words; } if (rlen == 0) { context.pc += 4; @@ -207,12 +212,16 @@ template struct R0Context { context.ecallCycle(curState, nextState(ptr, rlen), ptr / 4, ptr % 4, rlen); curState = nextState(ptr, rlen); } - while (rlen > 0 && ptr % 4 != 0) { + while (rlen > 0) { writeByte(ptr, bytes[i]); - // context.hostReadBytes(ptr); ptr++; i++; rlen--; + if (rlen == 0) { + context.pc += 4; + } + context.ecallCycle(curState, nextState(ptr, rlen), ptr / 4, ptr % 4, rlen); + curState = nextState(ptr, rlen); } return false; } @@ -250,6 +259,14 @@ template struct R0Context { return false; } + bool doSha2() { + // Bump PC + context.pc += 4; + context.ecallCycle(STATE_MACHINE_ECALL, STATE_SHA_ECALL, 0, 0, 0); + ShaECall(context); + return false; + } + // Machine mode ECALL, allow for overrides in subclasses bool doMachineECALL() { switch (loadReg(REG_A7)) { @@ -261,6 +278,8 @@ template struct R0Context { return doHostWrite(); case HOST_ECALL_POSEIDON2: return doPoseidon2(); + case HOST_ECALL_SHA2: + return doSha2(); default: throw std::runtime_error("unimplemented machine ECALL"); } diff --git a/zirgen/circuit/rv32im/v2/emu/sha.h b/zirgen/circuit/rv32im/v2/emu/sha.h new file mode 100644 index 00000000..7b8a62ee --- /dev/null +++ b/zirgen/circuit/rv32im/v2/emu/sha.h @@ -0,0 +1,169 @@ +// Copyright 2024 RISC Zero, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +namespace zirgen::rv32im_v2 { + +// 1 to 1 state from inst_sha +struct ShaState { + static constexpr size_t FpCount = 7; // Number of Fp values + static constexpr size_t U32Count = 3; // Number of U32 value + uint32_t stateInAddr; + uint32_t stateOutAddr; + uint32_t dataAddr; + uint32_t count; + uint32_t kAddr; + uint32_t round; + uint32_t nextState; + uint32_t a; + uint32_t e; + uint32_t w; + + void write(std::vector& out) { + const uint32_t* data = reinterpret_cast(this); + for (size_t i = 0; i < sizeof(ShaState) / 4; i++) { + out.push_back(data[i]); + } + } + + void read(const uint32_t* in, size_t count) { + assert(count == sizeof(ShaState) / 4); + uint32_t* data = reinterpret_cast(this); + for (size_t i = 0; i < count; i++) { + data[i] = in[i]; + } + } +}; + +template struct RingBuffer { + std::array buf; + uint32_t cur = 0; + uint32_t back(size_t i) const { return buf[(size + cur - i) % size]; } + void push(uint32_t val) { + buf[cur] = val; + cur++; + cur %= size; + } +}; + +#define ROTRIGHT(a, b) (((a) >> (b)) | ((a) << (32 - (b)))) +#define CH(x, y, z) (((x) & (y)) ^ (~(x) & (z))) +#define MAJ(x, y, z) (((x) & (y)) ^ ((x) & (z)) ^ ((y) & (z))) +#define EP0(x) (ROTRIGHT(x, 2) ^ ROTRIGHT(x, 13) ^ ROTRIGHT(x, 22)) +#define EP1(x) (ROTRIGHT(x, 6) ^ ROTRIGHT(x, 11) ^ ROTRIGHT(x, 25)) +#define SIG0(x) (ROTRIGHT(x, 7) ^ ROTRIGHT(x, 18) ^ ((x) >> 3)) +#define SIG1(x) (ROTRIGHT(x, 17) ^ ROTRIGHT(x, 19) ^ ((x) >> 10)) + +inline std::pair +computeAE(const RingBuffer<68>& oldA, const RingBuffer<68>& oldE, uint32_t k, uint32_t w) { + uint32_t a = oldA.back(1); + uint32_t b = oldA.back(2); + uint32_t c = oldA.back(3); + uint32_t d = oldA.back(4); + uint32_t e = oldE.back(1); + uint32_t f = oldE.back(2); + uint32_t g = oldE.back(3); + uint32_t h = oldE.back(4); + uint32_t t1 = h + EP1(e) + CH(e, f, g) + k + w; + uint32_t t2 = EP0(a) + MAJ(a, b, c); + e = d + t1; + a = t1 + t2; + return std::make_pair(a, e); +} + +inline uint32_t computeW(const RingBuffer<16>& oldW) { + return SIG1(oldW.back(2)) + oldW.back(7) + SIG0(oldW.back(15)) + oldW.back(16); +} + +template void ShaECall(Context& context) { + ShaState sha; + sha.stateInAddr = context.load(MACHINE_REGS_WORD + REG_A0) / 4; + sha.stateOutAddr = context.load(MACHINE_REGS_WORD + REG_A1) / 4; + sha.dataAddr = context.load(MACHINE_REGS_WORD + REG_A2) / 4; + sha.count = context.load(MACHINE_REGS_WORD + REG_A3) & 0xffff; + sha.kAddr = context.load(MACHINE_REGS_WORD + REG_A4) / 4; + sha.round = 0; + sha.a = 0; + sha.e = 0; + sha.w = 0; + uint32_t curState = STATE_SHA_ECALL; + auto step = [&](uint32_t nextState) { + sha.nextState = nextState; + context.shaCycle(curState, sha); + curState = nextState; + }; + RingBuffer<68> oldA; + RingBuffer<68> oldE; + RingBuffer<16> oldW; + for (size_t i = 0; i < 4; i++) { + sha.round = i; + step(STATE_SHA_LOAD_STATE); + uint32_t leA = context.load(sha.stateInAddr + 3 - i); + uint32_t leE = context.load(sha.stateInAddr + 7 - i); + sha.a = htonl(leA); + sha.e = htonl(leE); + oldA.push(sha.a); + oldE.push(sha.e); + context.store(sha.stateOutAddr + 3 - i, leA); + context.store(sha.stateOutAddr + 7 - i, leE); + } + while (sha.count != 0) { + for (size_t i = 0; i < 16; i++) { + sha.round = i; + step(STATE_SHA_LOAD_DATA); + uint32_t k = context.load(sha.kAddr + i); + sha.w = htonl(context.load(sha.dataAddr)); + sha.dataAddr++; + oldW.push(sha.w); + auto ae = computeAE(oldA, oldE, k, sha.w); + sha.a = ae.first; + sha.e = ae.second; + oldA.push(sha.a); + oldE.push(sha.e); + } + for (size_t i = 0; i < 48; i++) { + sha.round = i; + step(STATE_SHA_MIX); + uint32_t k = context.load(sha.kAddr + 16 + i); + sha.w = computeW(oldW); + oldW.push(sha.w); + auto ae = computeAE(oldA, oldE, k, sha.w); + sha.a = ae.first; + sha.e = ae.second; + oldA.push(sha.a); + oldE.push(sha.e); + } + for (size_t i = 0; i < 4; i++) { + sha.round = i; + step(STATE_SHA_STORE_STATE); + sha.a = oldA.back(4) + oldA.back(68); + sha.e = oldE.back(4) + oldE.back(68); + sha.w = 0; + if (i == 3) { + sha.count--; + } + oldA.push(sha.a); + oldE.push(sha.e); + context.store(sha.stateOutAddr + 3 - i, htonl(sha.a)); + context.store(sha.stateOutAddr + 7 - i, htonl(sha.e)); + } + } + sha.round = 0; + step(STATE_DECODE); +} + +} // namespace zirgen::rv32im_v2 diff --git a/zirgen/circuit/rv32im/v2/platform/constants.h b/zirgen/circuit/rv32im/v2/platform/constants.h index 7375e957..99bab963 100644 --- a/zirgen/circuit/rv32im/v2/platform/constants.h +++ b/zirgen/circuit/rv32im/v2/platform/constants.h @@ -95,6 +95,7 @@ constexpr uint32_t HOST_ECALL_TERMINATE = 0; constexpr uint32_t HOST_ECALL_READ = 1; constexpr uint32_t HOST_ECALL_WRITE = 2; constexpr uint32_t HOST_ECALL_POSEIDON2 = 3; +constexpr uint32_t HOST_ECALL_SHA2 = 4; constexpr uint32_t PFLAG_IS_ELEM = 0x80000000; constexpr uint32_t PFLAG_CHECK_OUT = 0x40000000; @@ -140,7 +141,13 @@ constexpr uint32_t STATE_POSEIDON_STORE_STATE = 23; constexpr uint32_t STATE_POSEIDON_EXT_ROUND = 24; constexpr uint32_t STATE_POSEIDON_INT_ROUND = 25; -constexpr uint32_t STATE_DECODE = 32; +constexpr uint32_t STATE_SHA_ECALL = 32; +constexpr uint32_t STATE_SHA_LOAD_STATE = 33; +constexpr uint32_t STATE_SHA_LOAD_DATA = 34; +constexpr uint32_t STATE_SHA_MIX = 35; +constexpr uint32_t STATE_SHA_STORE_STATE = 36; + +constexpr uint32_t STATE_DECODE = 40; constexpr uint32_t SAFE_WRITE_WORD = 0x3fffc040; diff --git a/zirgen/circuit/rv32im/v2/run/run.cpp b/zirgen/circuit/rv32im/v2/run/run.cpp index 68137890..5861d7c9 100644 --- a/zirgen/circuit/rv32im/v2/run/run.cpp +++ b/zirgen/circuit/rv32im/v2/run/run.cpp @@ -227,11 +227,22 @@ ExecutionTrace runSegment(const Segment& segment, size_t segmentSize) { for (size_t j = 0; j < extraSize; j++) { trace.data.set(i, getEcall0StateCol() + j, preflightTrace.extra[extraStart + j]); } - } else { + } else if (extraSize == sizeof(P2State) / 4) { for (size_t j = 0; j < extraSize; j++) { // std::cout << " extra: " << preflightTrace.extra[extraStart + j] << "\n"; trace.data.set(i, getPoseidonStateCol() + j, preflightTrace.extra[extraStart + j]); } + } else if (extraSize == sizeof(ShaState) / 4) { + for (size_t j = 0; j < ShaState::FpCount; j++) { + trace.data.set(i, getShaStateCol() + j, preflightTrace.extra[extraStart + j]); + } + for (size_t j = 0; j < ShaState::U32Count; j++) { + uint32_t val = preflightTrace.extra[extraStart + ShaState::FpCount + j]; + std::cout << " SHA_WORD: " << val << "\n"; + for (size_t k = 0; k < 32; k++) { + trace.data.set(i, getShaStateCol() + ShaState::FpCount + 32 * j + k, (val >> k) & 1); + } + } } } diff --git a/zirgen/circuit/rv32im/v2/run/wrap_dsl.cpp b/zirgen/circuit/rv32im/v2/run/wrap_dsl.cpp index 96c5e4af..f8111b25 100644 --- a/zirgen/circuit/rv32im/v2/run/wrap_dsl.cpp +++ b/zirgen/circuit/rv32im/v2/run/wrap_dsl.cpp @@ -76,6 +76,9 @@ ExtVal inv_0(ExtVal x) { Val bitAnd(Val a, Val b) { return Val(a.asUInt32() & b.asUInt32()); } +Val mod(Val a, Val b) { + return Val(a.asUInt32() % b.asUInt32()); +} Val inRange(Val low, Val mid, Val high) { assert(low <= high); return Val(low <= mid && mid < high); @@ -341,6 +344,10 @@ size_t getPoseidonStateCol() { return impl::kLayout_Top.instResult.arm9.state.hasState._super.col; } +size_t getShaStateCol() { + return impl::kLayout_Top.instResult.arm11.state.stateInAddr._super.col; +} + void DslStep(StepHandler& stepHandler, ExecutionTrace& trace, size_t cycle) { impl::ExecContext ctx(stepHandler, trace, cycle); impl::MutableBufObj data(ctx, trace.data); diff --git a/zirgen/circuit/rv32im/v2/run/wrap_dsl.h b/zirgen/circuit/rv32im/v2/run/wrap_dsl.h index d7e28a08..5b781501 100644 --- a/zirgen/circuit/rv32im/v2/run/wrap_dsl.h +++ b/zirgen/circuit/rv32im/v2/run/wrap_dsl.h @@ -39,6 +39,7 @@ size_t getCycleCol(); size_t getTopStateCol(); size_t getEcall0StateCol(); size_t getPoseidonStateCol(); +size_t getShaStateCol(); void DslStep(StepHandler& stepHandler, ExecutionTrace& trace, size_t cycle); void DslStepAccum(StepHandler& stepHandler, ExecutionTrace& trace, size_t cycle); diff --git a/zirgen/circuit/rv32im/v2/test/BUILD.bazel b/zirgen/circuit/rv32im/v2/test/BUILD.bazel index 0f1f2134..90b7ee36 100644 --- a/zirgen/circuit/rv32im/v2/test/BUILD.bazel +++ b/zirgen/circuit/rv32im/v2/test/BUILD.bazel @@ -1,18 +1,14 @@ load("@zirgen//bazel/toolchain/rv32im-linux:defs.bzl", "risc0_cc_kernel_binary") +load(":defs.bzl", "riscv_test_suite") cc_test( name = "test_parallel", - srcs = [ - "test_parallel.cpp", - ], + srcs = ["test_parallel.cpp"], data = [ "//zirgen/circuit/rv32im/v2/emu/test:guest", "//zirgen/circuit/rv32im/v2/kernel", - "@zirgen//zirgen/circuit/rv32im/shared/test:riscv_test_bins", - ], - deps = [ - "//zirgen/circuit/rv32im/v2/run", ], + deps = ["//zirgen/circuit/rv32im/v2/run"], ) risc0_cc_kernel_binary( @@ -26,13 +22,57 @@ risc0_cc_kernel_binary( cc_test( name = "test_p2", + srcs = ["test_p2.cpp"], + data = [":test_p2_kernel"], + deps = ["//zirgen/circuit/rv32im/v2/run"], +) + +risc0_cc_kernel_binary( + name = "test_sha_kernel", + srcs = [ + "entry.s", + "test_sha_kernel.cpp", + ], + deps = ["//zirgen/circuit/rv32im/v2/platform:core"], +) + +cc_test( + name = "test_sha", + srcs = ["test_sha.cpp"], + data = [":test_sha_kernel"], + deps = ["//zirgen/circuit/rv32im/v2/run"], +) + +cc_binary( + name = "risc0-simulate", + srcs = ["risc0-simulate.cpp"], + deps = [ + "//risc0/core", + "//zirgen/circuit/rv32im/v2/run", + ], +) + +risc0_cc_kernel_binary( + name = "test_io_kernel", + srcs = [ + "entry.s", + "test_io_kernel.cpp", + ], + deps = ["//zirgen/circuit/rv32im/v2/platform:core"], +) + +cc_test( + name = "test_io", srcs = [ - "test_p2.cpp", + "test_io.cpp", ], data = [ - ":test_p2_kernel", + ":test_io_kernel", ], deps = [ + "//risc0/core", "//zirgen/circuit/rv32im/v2/run", - ], + ] ) + +riscv_test_suite() diff --git a/zirgen/circuit/rv32im/v2/test/defs.bzl b/zirgen/circuit/rv32im/v2/test/defs.bzl new file mode 100644 index 00000000..7164ebe9 --- /dev/null +++ b/zirgen/circuit/rv32im/v2/test/defs.bzl @@ -0,0 +1,63 @@ +INST_TESTS = [ + "add", + "addi", + "and", + "andi", + "auipc", + "beq", + "bge", + "bgeu", + "blt", + "bltu", + "bne", + "jal", + "jalr", + "lb", + "lbu", + "lh", + "lhu", + "lui", + "lw", + "or", + "ori", + "sb", + "sh", + "simple", + "sll", + "slli", + "slt", + "slti", + "sltiu", + "sltu", + "sra", + "srai", + "srl", + "srli", + "sub", + "sw", + "xor", + "xori", + "div", + "divu", + "mul", + "mulh", + "mulhsu", + "mulhu", + "rem", + "remu", +] + +def riscv_test_suite(): + for test in INST_TESTS: + native.py_test( + # tags = ["manual"], + name = test + "_test", + srcs = ["run_test.py"], + main = "run_test.py", + args = [test], + data = [ + "//zirgen/circuit/rv32im/shared/test:riscv_test_bins", + ":risc0-simulate", + ], + size = "large", + ) diff --git a/zirgen/circuit/rv32im/v2/test/risc0-simulate.cpp b/zirgen/circuit/rv32im/v2/test/risc0-simulate.cpp new file mode 100644 index 00000000..6992aae3 --- /dev/null +++ b/zirgen/circuit/rv32im/v2/test/risc0-simulate.cpp @@ -0,0 +1,47 @@ +// Copyright 2024 RISC Zero, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "risc0/core/log.h" +#include "zirgen/circuit/rv32im/v2/platform/constants.h" +#include "zirgen/circuit/rv32im/v2/run/run.h" + +using namespace zirgen::rv32im_v2; + +int main(int argc, char* argv[]) { + risc0::setLogLevel(2); + if (argc < 2) { + LOG(1, "usage: risc0-simulate "); + exit(1); + } + + LOG(1, "File = " << argv[1]); + try { + size_t cycles = 10000; + + TestIoHandler io; + + // Load image + auto image = MemoryImage::fromRawElf(argv[1]); + // Do executions + auto segments = execute(image, io, cycles, cycles); + // Do 'run' (preflight + expansion) + for (const auto& segment : segments) { + runSegment(segment, cycles); + } + } catch (const std::runtime_error& err) { + LOG(1, "Failed: " << err.what()); + exit(1); + } + return 0; +} diff --git a/zirgen/circuit/rv32im/v2/test/run_test.py b/zirgen/circuit/rv32im/v2/test/run_test.py new file mode 100644 index 00000000..aa23af2c --- /dev/null +++ b/zirgen/circuit/rv32im/v2/test/run_test.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python +# Copyright 2022 RISC Zero, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import subprocess +import sys + +sys.exit( + subprocess.run( + [ + "zirgen/circuit/rv32im/v2/test/risc0-simulate", + "zirgen/circuit/rv32im/shared/test/" + sys.argv[1], + ] + ).returncode +) diff --git a/zirgen/circuit/rv32im/v2/test/test_io.cpp b/zirgen/circuit/rv32im/v2/test/test_io.cpp new file mode 100644 index 00000000..8be2f803 --- /dev/null +++ b/zirgen/circuit/rv32im/v2/test/test_io.cpp @@ -0,0 +1,50 @@ +// Copyright 2025 RISC Zero, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "zirgen/circuit/rv32im/v2/platform/constants.h" +#include "zirgen/circuit/rv32im/v2/run/run.h" + +using namespace zirgen::rv32im_v2; + +const std::string kernelName = "zirgen/circuit/rv32im/v2/test/test_io_kernel"; + +// Allows reads of any size, fill with a pattern to check in kernel +struct RandomReadSizeHandler : public HostIoHandler { + uint32_t write(uint32_t fd, const uint8_t* data, uint32_t len) override { return len; } + uint32_t read(uint32_t fd, uint8_t* data, uint32_t len) override { + std::cout << "DOING READ OF SIZE " << len << "\n"; + for (size_t i = 0; i < len; i++) { + data[i] = i; + } + return len; + } +}; + +int main() { + size_t cycles = 100000; + RandomReadSizeHandler io; + + // Load image + auto image = MemoryImage::fromRawElf(kernelName); + // Do executions + auto segments = execute(image, io, cycles, cycles); + // Do 'run' (preflight + expansion) + for (const auto& segment : segments) { + std::cout << "HEY, doing a segment!\n"; + runSegment(segment, cycles + 1000); + } + std::cout << "What a fine day\n"; +} diff --git a/zirgen/circuit/rv32im/v2/test/test_io_kernel.cpp b/zirgen/circuit/rv32im/v2/test/test_io_kernel.cpp new file mode 100644 index 00000000..7099f727 --- /dev/null +++ b/zirgen/circuit/rv32im/v2/test/test_io_kernel.cpp @@ -0,0 +1,85 @@ +// Copyright 2025 RISC Zero, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "zirgen/circuit/rv32im/v2/platform/constants.h" + +using namespace zirgen::rv32im_v2; + +inline void die() { + asm("fence\n"); +} + +// Implement machine mode ECALLS + +inline void terminate(uint32_t val) { + register uintptr_t a0 asm("a0") = val; + register uintptr_t a7 asm("a7") = 0; + asm volatile("ecall\n" + : // no outputs + : "r"(a0), "r"(a7) // inputs + : // no clobbers + ); +} + +inline uint32_t host_read(uint32_t fd, uint32_t buf, uint32_t len) { + register uintptr_t a0 asm("a0") = fd; + register uintptr_t a1 asm("a1") = buf; + register uintptr_t a2 asm("a2") = len; + register uintptr_t a7 asm("a7") = 1; + asm volatile("ecall\n" + : "+r"(a0) // outputs + : "r"(a0), "r"(a1), "r"(a2), "r"(a7) // inputs + : // no clobbers + ); + return a0; +} + +inline uint32_t host_write(uint32_t fd, uint32_t buf, uint32_t len) { + register uintptr_t a0 asm("a0") = fd; + register uintptr_t a1 asm("a1") = buf; + register uintptr_t a2 asm("a2") = len; + register uintptr_t a7 asm("a7") = 2; + asm volatile("ecall\n" + : "+r"(a0) // outputs + : "r"(a0), "r"(a1), "r"(a2), "r"(a7) // inputs + : // no clobbers + ); + return a0; +} + +constexpr uint32_t sizes[11] = {0, 1, 2, 3, 4, 5, 7, 13, 19, 40, 101}; + +void test_multi_read() { + uint8_t buf[200]; + // Try all 4 alignments + for (size_t i = 0; i < 4; i++) { + // Try a variety of size + for (size_t j = 0; j < 11; j++) { + host_read(0, (uint32_t)(buf + i), sizes[j]); + for (size_t k = 0; k < sizes[j]; k++) { + if (buf[i + k] != k) { + die(); + } + } + } + } +} + +extern "C" void start() { + test_multi_read(); + terminate(0); +} diff --git a/zirgen/circuit/rv32im/v2/test/test_p2.cpp b/zirgen/circuit/rv32im/v2/test/test_p2.cpp index 34caa5f7..d2635082 100644 --- a/zirgen/circuit/rv32im/v2/test/test_p2.cpp +++ b/zirgen/circuit/rv32im/v2/test/test_p2.cpp @@ -25,24 +25,8 @@ int main() { size_t cycles = 100000; TestIoHandler io; - auto entry = 0x10000; - auto pc = entry / 4; - - auto image = MemoryImage::fromWords({ - {pc + 0, 0x1234b337}, // lui x6, 0x0001234b - {pc + 1, 0xf387e3b7}, // lui x7, 0x000f387e - {pc + 2, 0x007302b3}, // add x5, x6, x7 - {pc + 3, 0x000045b7}, // lui x11, 0x00000004 - {pc + 4, 0x00000073}, // ecall - {SUSPEND_PC_WORD, entry}, - {SUSPEND_MODE_WORD, 1}, - }); - - std::cout << image.getDigest(0x400100) << std::endl; - std::cout << image.getDigest(1) << std::endl; - // Load image - // auto image = MemoryImage::fromRawElf(kernelName); + auto image = MemoryImage::fromRawElf(kernelName); // Do executions auto segments = execute(image, io, cycles, cycles); // Do 'run' (preflight + expansion) diff --git a/zirgen/circuit/rv32im/v2/test/test_sha.cpp b/zirgen/circuit/rv32im/v2/test/test_sha.cpp new file mode 100644 index 00000000..98c71347 --- /dev/null +++ b/zirgen/circuit/rv32im/v2/test/test_sha.cpp @@ -0,0 +1,36 @@ +// Copyright 2024 RISC Zero, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "zirgen/circuit/rv32im/v2/platform/constants.h" +#include "zirgen/circuit/rv32im/v2/run/run.h" + +using namespace zirgen::rv32im_v2; + +const std::string kernelName = "zirgen/circuit/rv32im/v2/test/test_sha_kernel"; + +int main() { + size_t cycles = 100000; + TestIoHandler io; + + // Load image + auto image = MemoryImage::fromRawElf(kernelName); + // Do executions + auto segments = execute(image, io, cycles, cycles); + // Do 'run' (preflight + expansion) + for (const auto& segment : segments) { + runSegment(segment, cycles + 1000); + } +} diff --git a/zirgen/circuit/rv32im/v2/test/test_sha_kernel.cpp b/zirgen/circuit/rv32im/v2/test/test_sha_kernel.cpp new file mode 100644 index 00000000..2cab4f37 --- /dev/null +++ b/zirgen/circuit/rv32im/v2/test/test_sha_kernel.cpp @@ -0,0 +1,156 @@ +// Copyright 2024 RISC Zero, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "zirgen/circuit/rv32im/v2/platform/constants.h" + +using namespace zirgen::rv32im_v2; + +inline void die() { + asm("fence\n"); +} + +// Implement machine mode ECALLS + +inline void terminate(uint32_t val) { + register uintptr_t a0 asm("a0") = val; + register uintptr_t a7 asm("a7") = 0; + asm volatile("ecall\n" + : // no outputs + : "r"(a0), "r"(a7) // inputs + : // no clobbers + ); +} + +static constexpr uint32_t SHA_K[64] = { + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, + 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, + 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, + 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, + 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, + 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, + 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, + 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2, +}; + +inline void do_sha2(uint32_t stateIn, uint32_t stateOut, uint32_t data, uint32_t count) { + register uintptr_t a0 asm("a0") = stateIn; + register uintptr_t a1 asm("a1") = stateOut; + register uintptr_t a2 asm("a2") = data; + register uintptr_t a3 asm("a3") = count; + register uintptr_t a4 asm("a4") = (uint32_t)SHA_K; + register uintptr_t a7 asm("a7") = 4; + asm volatile("ecall\n" + : // no outputs + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(a4), "r"(a7) // inputs + : // no clobbers + ); +} + +constexpr uint32_t swap(uint32_t n) { + return (((n & 0x000000ff) << 24) | ((n & 0x0000ff00) << 8) | ((n & 0x00ff0000) >> 8) | + ((n & 0xff000000) >> 24)); +} + +static constexpr uint32_t SHA_INIT[8] = { + swap(0x6a09e667), + swap(0xbb67ae85), + swap(0x3c6ef372), + swap(0xa54ff53a), + swap(0x510e527f), + swap(0x9b05688c), + swap(0x1f83d9ab), + swap(0x5be0cd19), +}; + +static constexpr uint8_t parseHex(char x) { + if (x >= 'a' && x <= 'f') { + return 10 + x - 'a'; + } + if (x >= '0' && x <= '9') { + return x - '0'; + } + die(); + return 0; +} + +void compareHex(uint32_t* words, const char* str) { + uint8_t* asBytes = reinterpret_cast(words); + for (size_t i = 0; i < 32; i++) { + uint8_t highNibble = parseHex(*str++); + uint8_t lowNibble = parseHex(*str++); + if (asBytes[i] != highNibble * 16 + lowNibble) { + die(); + } + } +} + +void shaPad(uint8_t* out, const char* in) { + uint32_t bits = 0; + while (*in) { + *out++ = *in++; + bits += 8; + } + uint32_t outBits = bits; + *out++ = 0x80; + bits += 8; + while (bits % 512 != 0) { + *out++ = 0; + bits += 8; + } + out -= 2; + out[0] = outBits / 256; + out[1] = outBits % 256; +} + +void test_sha_zero_blocks() { + uint32_t state[8]; + + do_sha2((uint32_t)SHA_INIT, (uint32_t)state, 0, 0); + + compareHex(state, "6a09e667bb67ae853c6ef372a54ff53a510e527f9b05688c1f83d9ab5be0cd19"); +} + +void test_sha_one_block() { + uint32_t state[8]; + uint64_t data[16]; + + shaPad(reinterpret_cast(data), "abc"); + + do_sha2((uint32_t)SHA_INIT, (uint32_t)state, (uint32_t)data, 1); + + compareHex(state, "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"); +} + +void test_sha_two_blocks() { + uint32_t state[8]; + uint64_t data[32]; + + shaPad(reinterpret_cast(data), + "abcdefghbcdefghicdefghijdefghijkefghijklfghijklmghijklmnhijklmnoijklmnopjklmnopqklmnopqrl" + "mnopqrsmnopqrstnopqrstu"); + + do_sha2((uint32_t)SHA_INIT, (uint32_t)state, (uint32_t)data, 2); + + compareHex(state, "cf5b16a778af8380036ce59e7b0492370b249b11e8f07a51afac45037afee9d1"); +} + +extern "C" void start() { + test_sha_zero_blocks(); + test_sha_one_block(); + test_sha_two_blocks(); + terminate(0); +}