Skip to content

Commit

Permalink
refactor: batch circuits (#1447)
Browse files Browse the repository at this point in the history
* 1. Update StartIndex to uint64
2. Add validation checks
3. Add missing unit tests

* 1. Updated StartIndex from uint32 to uint64

* Add negative test cases for circuit validation logic
  • Loading branch information
sergeytimoshin authored Jan 3, 2025
1 parent 864a1a1 commit 32df070
Show file tree
Hide file tree
Showing 9 changed files with 383 additions and 16 deletions.
6 changes: 3 additions & 3 deletions prover/server/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ func testBatchAddressAppendWithPreviousState40_100(t *testing.T) {
}

func runBatchAddressAppendWithPreviousStateTest(t *testing.T, treeHeight uint32, batchSize uint32) {
startIndex := uint32(2)
startIndex := uint64(2)
params1, err := prover.BuildTestAddressTree(treeHeight, batchSize, nil, startIndex)
if err != nil {
t.Fatalf("Failed to build first test tree: %v", err)
Expand All @@ -487,7 +487,7 @@ func runBatchAddressAppendWithPreviousStateTest(t *testing.T, treeHeight uint32,
}
response1.Body.Close()

startIndex += batchSize
startIndex += uint64(batchSize)
params2, err := prover.BuildTestAddressTree(treeHeight, batchSize, params1.Tree, startIndex)
if err != nil {
t.Fatalf("Failed to build second test tree: %v", err)
Expand Down Expand Up @@ -521,7 +521,7 @@ func runBatchAddressAppendWithPreviousStateTest(t *testing.T, treeHeight uint32,
func testBatchAddressAppendInvalidInput40_10(t *testing.T) {
treeHeight := uint32(40)
batchSize := uint32(10)
startIndex := uint32(0)
startIndex := uint64(0)

params, err := prover.BuildTestAddressTree(treeHeight, batchSize, nil, startIndex)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion prover/server/prover/batch_address_append_circuit.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ type BatchAddressAppendParameters struct {
OldRoot *big.Int
NewRoot *big.Int
HashchainHash *big.Int
StartIndex uint32
StartIndex uint64

LowElementValues []big.Int
LowElementIndices []big.Int
Expand Down
145 changes: 143 additions & 2 deletions prover/server/prover/batch_address_append_circuit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func TestBatchAddressAppendCircuit(t *testing.T) {
name string
treeHeight uint32
batchSize uint32
startIndex uint32
startIndex uint64
shouldPass bool
}{
{"Single insert height 4", 4, 1, 2, true},
Expand Down Expand Up @@ -83,8 +83,9 @@ func TestBatchAddressAppendCircuit(t *testing.T) {
name string
treeHeight uint32
batchSize uint32
startIndex uint32
startIndex uint64
modifyParams func(*BatchAddressAppendParameters)
wantPanic bool
}{
{
name: "Invalid OldRoot",
Expand Down Expand Up @@ -122,6 +123,138 @@ func TestBatchAddressAppendCircuit(t *testing.T) {
p.LowElementValues[0].Add(&p.LowElementValues[0], big.NewInt(1))
},
},
{
name: "StartIndex too large",
treeHeight: 4,
batchSize: 1,
startIndex: 0,
modifyParams: func(p *BatchAddressAppendParameters) {
p.StartIndex = ^uint64(0)
},
},
{
name: "Mismatched array length",
treeHeight: 4,
batchSize: 2,
startIndex: 0,
modifyParams: func(p *BatchAddressAppendParameters) {
p.LowElementValues = p.LowElementValues[:len(p.LowElementValues)-1]
},
wantPanic: true,
},
{
name: "Invalid proof length",
treeHeight: 4,
batchSize: 2,
startIndex: 0,
modifyParams: func(p *BatchAddressAppendParameters) {
p.LowElementProofs[0] = p.LowElementProofs[0][:len(p.LowElementProofs[0])-1]
},
wantPanic: true,
},
{
name: "Empty arrays",
treeHeight: 4,
batchSize: 2,
startIndex: 0,
modifyParams: func(p *BatchAddressAppendParameters) {
p.LowElementValues = make([]big.Int, p.BatchSize)
p.NewElementValues = make([]big.Int, p.BatchSize)
},
},
{
name: "Max values",
treeHeight: 4,
batchSize: 1,
startIndex: 0,
modifyParams: func(p *BatchAddressAppendParameters) {
maxBigInt := new(big.Int).Sub(new(big.Int).Exp(big.NewInt(2), big.NewInt(256), nil), big.NewInt(1))
p.NewElementValues[0] = *maxBigInt
},
},
{
name: "Inconsistent start index with proofs",
treeHeight: 4,
batchSize: 1,
startIndex: 0,
modifyParams: func(p *BatchAddressAppendParameters) {
p.StartIndex = 5
},
},
{
name: "Low element below expected range",
treeHeight: 4,
batchSize: 1,
startIndex: 2,
modifyParams: func(p *BatchAddressAppendParameters) {
p.LowElementValues[0].Sub(&p.LowElementValues[0], big.NewInt(1))
},
},
{
name: "Low element above expected range",
treeHeight: 4,
batchSize: 1,
startIndex: 2,
modifyParams: func(p *BatchAddressAppendParameters) {
// Set low element value above valid range
maxVal := new(big.Int).Exp(big.NewInt(2), big.NewInt(256), nil)
p.LowElementValues[0].Add(&p.LowElementValues[0], maxVal)
},
},
{
name: "Invalid low element next indices",
treeHeight: 4,
batchSize: 1,
startIndex: 2,
modifyParams: func(p *BatchAddressAppendParameters) {
p.LowElementNextIndices[0].Add(&p.LowElementNextIndices[0], big.NewInt(5))
},
},
{
name: "Invalid low element next values",
treeHeight: 4,
batchSize: 1,
startIndex: 2,
modifyParams: func(p *BatchAddressAppendParameters) {
p.LowElementNextValues[0].Add(&p.LowElementNextValues[0], big.NewInt(1))
},
},
{
name: "Invalid low element indices",
treeHeight: 4,
batchSize: 1,
startIndex: 2,
modifyParams: func(p *BatchAddressAppendParameters) {
p.LowElementIndices[0].Add(&p.LowElementIndices[0], big.NewInt(3))
},
},
{
name: "Invalid low element proofs",
treeHeight: 4,
batchSize: 1,
startIndex: 2,
modifyParams: func(p *BatchAddressAppendParameters) {
p.LowElementProofs[0][0].Add(&p.LowElementProofs[0][0], big.NewInt(1))
},
},
{
name: "Invalid new element proofs",
treeHeight: 4,
batchSize: 1,
startIndex: 2,
modifyParams: func(p *BatchAddressAppendParameters) {
p.NewElementProofs[0][0].Add(&p.NewElementProofs[0][0], big.NewInt(1))
},
},
{
name: "Invalid new element values",
treeHeight: 4,
batchSize: 1,
startIndex: 2,
modifyParams: func(p *BatchAddressAppendParameters) {
p.NewElementValues[0].Add(&p.NewElementValues[0], big.NewInt(1))
},
},
}

for _, tc := range testCases {
Expand All @@ -135,6 +268,14 @@ func TestBatchAddressAppendCircuit(t *testing.T) {

tc.modifyParams(params)

if tc.wantPanic {
assert.Panics(func() {
witness, _ := params.CreateWitness()
test.IsSolved(&circuit, witness, ecc.BN254.ScalarField())
})
return
}

witness, err := params.CreateWitness()
if err != nil {
return
Expand Down
2 changes: 1 addition & 1 deletion prover/server/prover/batch_append_with_proofs_circuit.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ type BatchAppendWithProofsParameters struct {
LeavesHashchainHash *big.Int
Leaves []*big.Int
MerkleProofs [][]big.Int
StartIndex uint32
StartIndex uint64
Height uint32
BatchSize uint32
Tree *merkle_tree.PoseidonTree
Expand Down
154 changes: 154 additions & 0 deletions prover/server/prover/batch_append_with_proofs_circuit_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package prover

import (
"math/big"
"testing"

"github.com/consensys/gnark-crypto/ecc"
Expand Down Expand Up @@ -111,4 +112,157 @@ func TestBatchAppendWithProofsCircuit(t *testing.T) {
err := test.IsSolved(&circuit, &witness, ecc.BN254.ScalarField())
assert.NoError(err)
})

t.Run("Invalid public input hash", func(t *testing.T) {
treeDepth := 10
batchSize := 2
params := BuildTestBatchAppendWithProofsTree(treeDepth, batchSize, nil, 0, false)
params.PublicInputHash = big.NewInt(999)

witness := createTestWitness(params)
circuit := createTestCircuit(treeDepth, batchSize)

err := test.IsSolved(&circuit, &witness, ecc.BN254.ScalarField())
assert.Error(err)
})

t.Run("Invalid old root", func(t *testing.T) {
treeDepth := 10
batchSize := 2
params := BuildTestBatchAppendWithProofsTree(treeDepth, batchSize, nil, 0, false)
params.OldRoot = big.NewInt(999)

witness := createTestWitness(params)
circuit := createTestCircuit(treeDepth, batchSize)

err := test.IsSolved(&circuit, &witness, ecc.BN254.ScalarField())
assert.Error(err)
})

t.Run("Invalid new root", func(t *testing.T) {
treeDepth := 10
batchSize := 2
params := BuildTestBatchAppendWithProofsTree(treeDepth, batchSize, nil, 0, false)
params.NewRoot = big.NewInt(999)

witness := createTestWitness(params)
circuit := createTestCircuit(treeDepth, batchSize)

err := test.IsSolved(&circuit, &witness, ecc.BN254.ScalarField())
assert.Error(err)
})

t.Run("Invalid leaves hashchain", func(t *testing.T) {
treeDepth := 10
batchSize := 2
params := BuildTestBatchAppendWithProofsTree(treeDepth, batchSize, nil, 0, false)
params.LeavesHashchainHash = big.NewInt(999)

witness := createTestWitness(params)
circuit := createTestCircuit(treeDepth, batchSize)

err := test.IsSolved(&circuit, &witness, ecc.BN254.ScalarField())
assert.Error(err)
})

t.Run("Invalid merkle proof", func(t *testing.T) {
treeDepth := 10
batchSize := 2
params := BuildTestBatchAppendWithProofsTree(treeDepth, batchSize, nil, 0, false)
params.MerkleProofs[0][0] = *big.NewInt(999)

witness := createTestWitness(params)
circuit := createTestCircuit(treeDepth, batchSize)

err := test.IsSolved(&circuit, &witness, ecc.BN254.ScalarField())
assert.Error(err)
})

t.Run("Invalid start index", func(t *testing.T) {
treeDepth := 10
batchSize := 2
params := BuildTestBatchAppendWithProofsTree(treeDepth, batchSize, nil, 0, false)
params.StartIndex = uint64(1 << treeDepth)

witness := createTestWitness(params)
circuit := createTestCircuit(treeDepth, batchSize)

err := test.IsSolved(&circuit, &witness, ecc.BN254.ScalarField())
assert.Error(err)
})

t.Run("Invalid old leaves", func(t *testing.T) {
assert := test.NewAssert(t)
treeDepth := 10
batchSize := 2
params := BuildTestBatchAppendWithProofsTree(treeDepth, batchSize, nil, 0, false)

params.OldLeaves[0] = big.NewInt(999)

witness := createTestWitness(params)
circuit := createTestCircuit(treeDepth, batchSize)

err := test.IsSolved(&circuit, &witness, ecc.BN254.ScalarField())
assert.Error(err)
})

t.Run("Invalid leaves", func(t *testing.T) {
assert := test.NewAssert(t)
treeDepth := 10
batchSize := 2
params := BuildTestBatchAppendWithProofsTree(treeDepth, batchSize, nil, 0, false)

params.Leaves[0] = big.NewInt(999)

witness := createTestWitness(params)
circuit := createTestCircuit(treeDepth, batchSize)

err := test.IsSolved(&circuit, &witness, ecc.BN254.ScalarField())
assert.Error(err)
})
}

func createTestCircuit(treeDepth, batchSize int) BatchAppendWithProofsCircuit {
circuit := BatchAppendWithProofsCircuit{
PublicInputHash: frontend.Variable(0),
OldRoot: frontend.Variable(0),
NewRoot: frontend.Variable(0),
LeavesHashchainHash: frontend.Variable(0),
OldLeaves: make([]frontend.Variable, batchSize),
Leaves: make([]frontend.Variable, batchSize),
StartIndex: frontend.Variable(0),
MerkleProofs: make([][]frontend.Variable, batchSize),
Height: uint32(treeDepth),
BatchSize: uint32(batchSize),
}

for i := range circuit.MerkleProofs {
circuit.MerkleProofs[i] = make([]frontend.Variable, treeDepth)
}
return circuit
}

func createTestWitness(params *BatchAppendWithProofsParameters) BatchAppendWithProofsCircuit {
witness := BatchAppendWithProofsCircuit{
PublicInputHash: frontend.Variable(params.PublicInputHash),
OldRoot: frontend.Variable(params.OldRoot),
NewRoot: frontend.Variable(params.NewRoot),
LeavesHashchainHash: frontend.Variable(params.LeavesHashchainHash),
OldLeaves: make([]frontend.Variable, int(params.BatchSize)),
Leaves: make([]frontend.Variable, int(params.BatchSize)),
MerkleProofs: make([][]frontend.Variable, int(params.BatchSize)),
StartIndex: frontend.Variable(params.StartIndex),
Height: params.Height,
BatchSize: params.BatchSize,
}

for i := 0; i < int(params.BatchSize); i++ {
witness.Leaves[i] = frontend.Variable(params.Leaves[i])
witness.OldLeaves[i] = frontend.Variable(params.OldLeaves[i])
witness.MerkleProofs[i] = make([]frontend.Variable, params.Height)
for j := 0; j < int(params.Height); j++ {
witness.MerkleProofs[i][j] = frontend.Variable(params.MerkleProofs[i][j])
}
}
return witness
}
Loading

0 comments on commit 32df070

Please sign in to comment.