Skip to content

Commit

Permalink
Merge branch 'main' into dependabot/cargo/bytemuck-1.21.0
Browse files Browse the repository at this point in the history
  • Loading branch information
mars-risc0 authored Jan 7, 2025
2 parents 95938ba + 17f3d89 commit 61885f2
Show file tree
Hide file tree
Showing 31 changed files with 1,382 additions and 87 deletions.
45 changes: 45 additions & 0 deletions zirgen/circuit/rv32im/v2/dsl/arr.zir
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// RUN: zirgen --test %s

// Vector / List functions

// Shifts + Rotates
component RotateLeft<SIZE: Val>(in: Array<Val, SIZE>, n: Val) {
for i : 0..SIZE {
if (InRange(0, i - n, SIZE)) { in[i - n] } else { in[SIZE + i - n] }
}
}

component RotateRight<SIZE: Val>(in: Array<Val, SIZE>, n: Val) {
for i : 0..SIZE {
if (InRange(0, i + n, SIZE)) { in[i + n] } else { in[i + n - SIZE] }
}
}

component ShiftLeft<SIZE: Val>(in: Array<Val, SIZE>, n: Val) {
for i : 0..SIZE {
if (InRange(0, i - n, SIZE)) { in[i - n] } else { 0 }
}
}

component ShiftRight<SIZE: Val>(in: Array<Val, SIZE>, n: Val) {
for i : 0..SIZE {
if (InRange(0, i + n, SIZE)) { in[i + n] } else { 0 }
}
}

component EqArr<SIZE: Val>(a: Array<Val, SIZE>, b: Array<Val, SIZE>) {
for i : 0..SIZE {
a[i] = b[i];
}
}

// Tests....

test ShiftAndRotate {
// Remember: array entry 0 is the low bit, so there seem backwards
EqArr<8>(ShiftRight<8>([3, 1, 5, 0, 2, 0, 0, 0], 2), [5, 0, 2, 0, 0, 0, 0, 0]);
EqArr<8>(ShiftLeft<8>([1, 4, 2, 0, 6, 0, 0, 0], 2), [0, 0, 1, 4, 2, 0, 6, 0]);
EqArr<8>(RotateRight<8>([7, 6, 1, 0, 2, 0, 0, 0], 2), [1, 0, 2, 0, 0, 0, 7, 6]);
EqArr<8>(RotateLeft<8>([4, 5, 1, 0, 1, 0, 0, 3], 2), [0, 3, 4, 5, 1, 0, 1, 0]);
}

10 changes: 7 additions & 3 deletions zirgen/circuit/rv32im/v2/dsl/bits.zir
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,17 @@ function AssertTwit(val: Val) {
val * (1 - val) * (2 - val) * (3 - val) = 0;
}

// Simple bit ops
component BitAnd(a: Val, b: Val) {
Reg(a * b)
a * b
}

component BitOr(a: Val, b: Val) {
Reg(1 - (1 - a) * (1 - b))
a + b - a * b
}

component BitXor(a: Val, b: Val) {
a + b - 2 * a * b
}

// Set a register nodeterministically, and then verify it is a twit
Expand Down Expand Up @@ -81,4 +86,3 @@ test TwitInRange{
test_fails TwitOutOfRange {
AssertTwit(4);
}

9 changes: 8 additions & 1 deletion zirgen/circuit/rv32im/v2/dsl/consts.zir
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,18 @@ component StatePoseidonStoreState() { 23 }
component StatePoseidonExtRound() { 24 }
component StatePoseidonIntRounds() { 25 }

component StateDecode() { 32 }
component StateShaEcall() { 32 }
component StateShaLoadState() { 33 }
component StateShaLoadData() { 34 }
component StateShaMix() { 35 }
component StateShaStoreState() { 36 }

component StateDecode() { 40 }

component RegA0() { 10 }
component RegA1() { 11 }
component RegA2() { 12 }
component RegA3() { 13 }
component RegA4() { 14 }

component RegA7() { 17 }
23 changes: 20 additions & 3 deletions zirgen/circuit/rv32im/v2/dsl/inst_div.zir
Original file line number Diff line number Diff line change
Expand Up @@ -59,19 +59,36 @@ component DoDiv(numer: ValU32, denom: ValU32, signed: Val, ones_comp: Val) {
settings := MultiplySettings(signed, signed, signed);
// Do the accumulate
mul := MultiplyAccumulate(quot, denom, rem, settings);
// Check the main result (numer = quot * denom + rem
// Check the main result (numer = quot * denom + rem)
AssertEqU32(mul.outLow, numer);
// The top bits should all be 0 or all be 1
topBitType := NondetBitReg(1 - Isz(mul.outHigh.low));
AssertEqU32(mul.outHigh, ValU32(0xffff * topBitType, 0xffff * topBitType));
// Check if denom is zero
isZero := IsZero(denom.low + denom.high);
// Get top bit of numerator
topNum := NondetBitReg((numer.high & 0x8000) / 0x8000);
// Verify we got it right
U16Reg((numer.high - 0x8000 * topNum) * 2);
numNeg := topNum * signed;
// Get the absolute value of the denominator
denomNeg := mul.bNeg;
denomAbs := NormalizeU32(DenormedValU32(
denomNeg * (0x10000 - denom.low) + (1 - denomNeg) * denom.low,
denomNeg * (0xffff - denom.high) + (1 - denomNeg) * denom.high
));
// Flip the sign of the remainder if the numerator is negative
remNormal := NormalizeU32(DenormedValU32(
numNeg * (0x10000 - rem.low) + (1 - numNeg) * rem.low,
numNeg * (0xffff - rem.high) + (1 - numNeg) * rem.high
));
// Decide if we need to swap order of
// If non-zero, make sure 0 <= rem < denom
if (isZero) {
AssertEqU32(rem, numer);
} else {
cmp := CmpLessThanUnsigned(rem, denom);
cmp.is_less_than = 1;
lt := CmpLessThanUnsigned(remNormal, denomAbs);
lt.is_less_than = 1;
};
DivideReturn(quot, rem)
}
Expand Down
81 changes: 63 additions & 18 deletions zirgen/circuit/rv32im/v2/dsl/inst_ecall.zir
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import inst;
import consts;


// Prepare to read a certain length, maybe return a smaller one
extern HostReadPrepare(fd: Val, len: Val): Val;

Expand All @@ -30,12 +29,13 @@ component MachineECall(cycle: Reg, input: InstInput, pc_addr: Val) {
input.mode = 1;
dispatch_idx := MemoryRead(cycle, MachineRegBase() + RegA7());
dispatch_idx.high = 0;
dispatch := OneHot<4>(dispatch_idx.low);
dispatch := OneHot<5>(dispatch_idx.low);
state := dispatch -> (
StateTerminate(),
StateHostReadSetup(),
StateHostWrite(),
StatePoseidonEcall()
StatePoseidonEcall(),
StateShaEcall()
);
ECallOutput(state, 0, 0, 0)
}
Expand Down Expand Up @@ -86,8 +86,8 @@ component ECallHostReadSetup(cycle: Reg, input: InstInput) {
lenDecomp := DecomposeLow2(newLen);
// Check if length is exactly 1, 2, or 3
len123 := Reg(lenDecomp.highZero * lenDecomp.low2Nonzero);
// Check if things are 'uneven'
uneven := Reg(len123 * ptrDecomp.low2Nonzero);
// Check if things are 'uneven' (this is an 'or')
uneven := Reg(len123 + ptrDecomp.low2Nonzero - len123 * ptrDecomp.low2Nonzero);
// Now pick the next cycle
nextCycle :=
// If length == 0, go back to decoding
Expand Down Expand Up @@ -117,36 +117,80 @@ component ECallHostWrite(cycle: Reg, input: InstInput) {
ECallOutput(StateDecode(), 0, 0, 0)
}

component ECallHostReadBytes(cycle: Reg, input: InstInput) {
// TODO
component ECallHostReadBytes(cycle: Reg, input: InstInput, ptrWord: Val, ptrLow2: Val, len: Val) {
input.state = StateHostReadBytes();
0 = 1;
ECallOutput(16, 0, 0, 0)
// Decompose next len
lenDecomp := DecomposeLow2(len - 1);
// Check if length is exactly 1, 2, or 3
len123 := Reg(lenDecomp.highZero * lenDecomp.low2Nonzero);
// Check is next pointer is even (this can only happen if Low2 == 3)
nextPtrEven := IsZero(ptrLow2 - 3);
nextPtrUneven := 1 - nextPtrEven;
nextPtrWord := nextPtrEven * (ptrWord + 1) + nextPtrUneven * ptrWord;
nextPtrLow2 := nextPtrUneven * (ptrLow2 + 1);
// Check if things are 'uneven' (this is an 'or')
uneven := Reg(len123 + nextPtrUneven - len123 * nextPtrUneven);
// Check is length is exactly zero
lenZero := IsZero(len - 1);
// Split low bits into parts
low0 := NondetBitReg(ptrLow2 & 1);
low1 := BitReg((ptrLow2 - low0) / 2);
// Load the original word
origWord := MemoryRead(cycle, ptrWord);
// Write the answer
io := MemoryWriteUnconstrained(cycle, ptrWord).io;
// Make the non-specified half matches
if (low1) {
origWord.low = io.newTxn.dataLow;
} else {
origWord.high = io.newTxn.dataHigh;
};
// Get the half that changed
oldHalf := low1 * origWord.high + (1 - low1) * origWord.low;
newHalf := low1 * io.newTxn.dataHigh + (1 - low1) * io.newTxn.dataLow;
// Split both into bytes
oldBytes := SplitWord(oldHalf);
newBytes := SplitWord(newHalf);
// Make sure the non-specified bytes matchs
if (low0) {
oldBytes.byte0 = newBytes.byte0;
} else {
oldBytes.byte1 = newBytes.byte1;
};
nextCycle :=
// If length == 0, go back to decoding
lenZero * StateDecode() +
// If length != 0 and uneven, do bytes
(1 - lenZero) * uneven * StateHostReadBytes() +
// If length != 0 and even, more words
(1 - lenZero) * (1 - uneven) * StateHostReadWords();
ECallOutput(nextCycle, nextPtrWord, nextPtrLow2, len - 1)
}

component ECallHostReadWords(cycle: Reg, input: InstInput, ptrWord: Val, len: Val) {
input.state = StateHostReadWords();
lenDecomp := DecomposeLow2(len);
wordsDecomp := DecomposeLow2(lenDecomp.high);
doWord := [
wordsDecomp.low2Hot[1] * wordsDecomp.highZero,
wordsDecomp.low2Hot[2] * wordsDecomp.highZero,
wordsDecomp.low2Hot[3] * wordsDecomp.highZero,
(wordsDecomp.low2Hot[1] + wordsDecomp.low2Hot[2] + wordsDecomp.low2Hot[3]) * wordsDecomp.highZero + (1 - wordsDecomp.highZero),
(wordsDecomp.low2Hot[2] + wordsDecomp.low2Hot[3])* wordsDecomp.highZero + (1 - wordsDecomp.highZero),
(wordsDecomp.low2Hot[3]) * wordsDecomp.highZero + (1 - wordsDecomp.highZero),
(1 - wordsDecomp.highZero)
];
count := reduce doWord init 0 with Add;
for i : 0..4 {
addr := Reg(doWord[i] * (ptrWord + i) + (1 - doWord[i]) * SafeWriteWord());
MemoryWriteUnconstrained(cycle, addr);
};
lenZero := IsZero(len - 4 * count);
newLenHighZero := IsZero(lenDecomp.high - count);
lenZero := Reg(newLenHighZero * (1 - lenDecomp.low2Nonzero));
nextCycle :=
// If length == 0, go back to decoding
lenZero * StateDecode() +
// If length != 0 and uneven, do bytes
(1 - lenZero) * (lenDecomp.low2Nonzero) * StateHostReadBytes() +
// If lengtj != 0 and even, more words
(1 - lenZero) * (1 - lenDecomp.low2Nonzero) * StateHostReadWords();
(1 - lenZero) * newLenHighZero * StateHostReadBytes() +
// If length != 0 and even, more words
(1 - lenZero) * (1 - newLenHighZero) * StateHostReadWords();
ECallOutput(nextCycle, ptrWord + count, 0, len - count * 4)
}

Expand All @@ -162,7 +206,7 @@ component ECall0(cycle: Reg, inst_input: InstInput) {
ECallTerminate(cycle, inst_input),
ECallHostReadSetup(cycle, inst_input),
ECallHostWrite(cycle, inst_input),
ECallHostReadBytes(cycle, inst_input),
ECallHostReadBytes(cycle, inst_input, s0@1, s1@1, s2@1),
ECallHostReadWords(cycle, inst_input, s0@1, s2@1),
IllegalECall(),
IllegalECall()
Expand All @@ -172,6 +216,7 @@ component ECall0(cycle: Reg, inst_input: InstInput) {
s2 := Reg(output.s2);
isDecode := IsZero(output.state - StateDecode());
isP2Entry := IsZero(output.state - StatePoseidonEcall());
addPC := NormalizeU32(AddU32(inst_input.pc_u32, ValU32((isDecode + isP2Entry) * 4, 0)));
isShaEcall := IsZero(output.state - StateShaEcall());
addPC := NormalizeU32(AddU32(inst_input.pc_u32, ValU32((isDecode + isP2Entry + isShaEcall) * 4, 0)));
InstOutput(addPC, output.state, 1)
}
2 changes: 2 additions & 0 deletions zirgen/circuit/rv32im/v2/dsl/inst_p2.zir
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,7 @@ component PoseidonPaging(cycle: Reg, mode: Val, prev: PoseidonState) {

component Poseidon0(cycle:Reg, inst_input: InstInput) {
DoCycleTable(cycle);
inst_input.state = StatePoseidonEntry() + inst_input.minor;
state : PoseidonState;
state := inst_input.minor_onehot -> (
PoseidonEntry(cycle, inst_input.pc_u32, inst_input.mode),
Expand All @@ -480,6 +481,7 @@ component Poseidon0(cycle:Reg, inst_input: InstInput) {

component Poseidon1(cycle:Reg, inst_input: InstInput) {
DoCycleTable(cycle);
inst_input.state = StatePoseidonExtRound() + inst_input.minor;
state : PoseidonState;
state := inst_input.minor_onehot -> (
PoseidonExtRound(state@1),
Expand Down
Loading

0 comments on commit 61885f2

Please sign in to comment.