Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ZIR-305: Audit fixes #140

Merged
merged 2 commits into from
Dec 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion zirgen/circuit/rv32im/v2/dsl/decode.zir
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ component Decoder(inst: ValU32) {
_f3_01 := NondetTwitReg((inst.low & 0x3000) / 0x1000);
_rd_34 := NondetTwitReg((inst.low & 0x0C00) / 0x0400);
_rd_12 := NondetTwitReg((inst.low & 0x0300) / 0x0100);
_rd_0 := NondetTwitReg((inst.low & 0x0080) / 0x0080);
_rd_0 := NondetBitReg((inst.low & 0x0080) / 0x0080);

// The opcode is special and is unconstrained.
// This implies the for the decoding to be fully correct, some later
Expand Down
11 changes: 10 additions & 1 deletion zirgen/circuit/rv32im/v2/dsl/inst_div.zir
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,22 @@ component DoDiv(numer: ValU32, denom: ValU32, signed: Val, ones_comp: Val) {
rem := ValU32(rem_low, rem_high);
// Either all signed, or nothing signed
settings := MultiplySettings(signed, signed, signed);
// Do the acuumulate
// Do the accumulate
mul := MultiplyAccumulate(quot, denom, rem, settings);
// Check the main result (numer = quot * denom + rem
AssertEqU32(mul.outLow, numer);
// The top bits should all be 0 or all be 1
topBitType := NondetBitReg(1 - Isz(mul.outHigh.low));
AssertEqU32(mul.outHigh, ValU32(0xffff * topBitType, 0xffff * topBitType));
// Check if denom is zero
isZero := IsZero(denom.low + denom.high);
// If non-zero, make sure 0 <= rem < denom
if (isZero) {
AssertEqU32(rem, numer);
} else {
cmp := CmpLessThanUnsigned(rem, denom);
cmp.is_less_than = 1;
};
DivideReturn(quot, rem)
}

Expand Down
3 changes: 2 additions & 1 deletion zirgen/circuit/rv32im/v2/dsl/inst_ecall.zir
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,9 @@ component ECallTerminate(cycle: Reg, input: InstInput) {

component DecomposeLow2(len: Val) {
// We split len into a multiple of 4, and the low 2 bits as one hot
public high := NondetReg((len & 0xfffc) / 4);
public high := NondetU16Reg((len & 0xfffc) / 4);
public low2 := NondetReg(len & 0x3);
len = 4*high + low2;
public low2Hot := OneHot<4>(low2);
public highZero := IsZero(high);
public isZero := Reg(highZero * low2Hot[0]);
Expand Down
36 changes: 28 additions & 8 deletions zirgen/circuit/rv32im/v2/dsl/inst_p2.zir
Original file line number Diff line number Diff line change
Expand Up @@ -276,12 +276,32 @@ component PoseidonCheckOut(cycle: Reg, prev: PoseidonState) {
PoseidonState(GetDef(prev), nextState, 0, 0, 0, prev.mode, prev.inner, MakeExt(0))
}

component FieldToWord(val: Val) {
// Decompose a field element into two u16s
public low := NondetU16Reg(val & 0xffff);
public high := U16Reg((val - low) / 65536);
// Check decomposition is unique
// If low == 0, high must be < 30720, otherwise high must be <= 30719
// Guess if low is zero
lowIsZero := NondetBitReg(Isz(low));
// Now check results: Technically, prover could set low-is-zero to false even if
// low was zero, but this only results in a stricter check of high, so it's pointless
if (lowIsZero) {
low = 0;
U16Reg(30720 - high);
} else {
U16Reg(30719 - high);
};
// Return as u32
public ret := ValU32(low, high);
}

component PoseidonStoreOut(cycle: Reg, prev: PoseidonState) {
for i : 0..8 {
val := prev.inner[i];
low := NondetU16Reg(val & 0xffff);
high := U16Reg((val - low) / 65536);
MemoryWrite(cycle, prev.bufOutAddr + i, ValU32(low, high));
ftw := FieldToWord(prev.inner[i]);
mw := MemoryWrite(cycle, prev.bufOutAddr + i, ftw.ret);
AliasLayout!(mw.io.newTxn.dataLow, ftw.low.arg.val);
AliasLayout!(mw.io.newTxn.dataHigh, ftw.high.arg.val);
};
isNormal := IsZero(prev.loadTxType - TxKindRead());
outState := isNormal * StateDecode() + (1 - isNormal) * StatePoseidonPaging();
Expand All @@ -298,10 +318,10 @@ component PoseidonDoOut(cycle: Reg, prev: PoseidonState) {

component PoseidonStoreState(cycle: Reg, prev: PoseidonState) {
for i : 0..8 {
val := prev.inner[16 + i];
low := NondetU16Reg(val & 0xffff);
high := U16Reg((val - low) / 65536);
MemoryWrite(cycle, prev.stateAddr+ i, ValU32(low, high));
ftw := FieldToWord(prev.inner[16 + i]);
mw := MemoryWrite(cycle, prev.stateAddr+ i, ftw.ret);
AliasLayout!(mw.io.newTxn.dataLow, ftw.low.arg.val);
AliasLayout!(mw.io.newTxn.dataHigh, ftw.high.arg.val);
};
PoseidonState(GetDef(prev), StateDecode(), 0, 0, 0, prev.mode, prev.inner, MakeExt(0))
}
Expand Down
6 changes: 3 additions & 3 deletions zirgen/circuit/rv32im/v2/dsl/lookups.zir
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ component NondetU8Reg(val: Val) {
component U8Reg(val: Val) {
ret := NondetU8Reg(val);
ret = val;
val
ret
}

argument ArgU16(count: Val, val: Val) {
Expand All @@ -32,7 +32,7 @@ argument ArgU16(count: Val, val: Val) {

// Set a register nodeterministically, and then verify it is a U16
component NondetU16Reg(val: Val) {
arg := ArgU16(1, val);
public arg := ArgU16(1, val);
arg.count = 1;
arg.val
}
Expand All @@ -41,7 +41,7 @@ component NondetU16Reg(val: Val) {
component U16Reg(val: Val) {
ret := NondetU16Reg(val);
ret = val;
val
ret
}

// TODO: Tests
Expand Down
2 changes: 1 addition & 1 deletion zirgen/circuit/rv32im/v2/dsl/mem.zir
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ component MemoryRead(cycle: Reg, addr: Val) {

// A normal memory write
component MemoryWrite(cycle: Reg, addr: Val, data: ValU32) {
io := MemoryIO(2*cycle + 1, addr);
public io := MemoryIO(2*cycle + 1, addr);
IsForward(io);
io.newTxn.dataLow = data.low;
io.newTxn.dataHigh = data.high;
Expand Down
Loading