From 6c60c3a191f48cc0fc33f986aa7ef15f8f60e42f Mon Sep 17 00:00:00 2001 From: ananas-block Date: Wed, 15 Jan 2025 18:06:23 +0000 Subject: [PATCH 1/6] fix: wipe_previous_batch_bloom_filter --- forester-utils/src/instructions.rs | 27 +- program-libs/batched-merkle-tree/src/batch.rs | 36 +- .../batched-merkle-tree/src/batch_metadata.rs | 98 +++- program-libs/batched-merkle-tree/src/event.rs | 2 + .../src/initialize_state_tree.rs | 3 +- .../batched-merkle-tree/src/merkle_tree.rs | 453 +++++++++++------- program-libs/batched-merkle-tree/src/queue.rs | 62 +-- .../batched-merkle-tree/tests/merkle_tree.rs | 165 ++++--- program-libs/verifier/src/lib.rs | 6 +- .../tests/batched_merkle_tree_test.rs | 32 +- prover/client/src/mock_batched_forester.rs | 2 + .../program-test/src/test_batch_forester.rs | 19 +- 12 files changed, 512 insertions(+), 393 deletions(-) diff --git a/forester-utils/src/instructions.rs b/forester-utils/src/instructions.rs index 4571be3ef6..5ad8fe9477 100644 --- a/forester-utils/src/instructions.rs +++ b/forester-utils/src/instructions.rs @@ -1,8 +1,8 @@ use light_batched_merkle_tree::{ constants::{DEFAULT_BATCH_ADDRESS_TREE_HEIGHT, DEFAULT_BATCH_STATE_TREE_HEIGHT}, merkle_tree::{ - AppendBatchProofInputsIx, BatchProofInputsIx, BatchedMerkleTreeAccount, - InstructionDataBatchAppendInputs, InstructionDataBatchNullifyInputs, + BatchedMerkleTreeAccount, InstructionDataBatchAppendInputs, + InstructionDataBatchNullifyInputs, }, queue::BatchedQueueAccount, }; @@ -58,12 +58,11 @@ where })? .unwrap(); - let (old_root_index, leaves_hashchain, start_index, current_root, batch_size, full_batch_index) = { + let (leaves_hashchain, start_index, current_root, batch_size, full_batch_index) = { let merkle_tree = BatchedMerkleTreeAccount::address_from_bytes(merkle_tree_account.data.as_mut_slice()) .unwrap(); - let old_root_index = merkle_tree.root_history.last_index(); let full_batch_index = merkle_tree.queue_metadata.next_full_batch_index; let batch = &merkle_tree.batches[full_batch_index as usize]; let zkp_batch_index = batch.get_num_inserted_zkps(); @@ -74,7 +73,6 @@ where let batch_size = batch.zkp_batch_size as usize; ( - old_root_index, leaves_hashchain, start_index, current_root, @@ -175,7 +173,7 @@ where })?; let client = Client::new(); - let circuit_inputs_new_root = bigint_to_be_bytes_array::<32>(&inputs.new_root).unwrap(); + let new_root = bigint_to_be_bytes_array::<32>(&inputs.new_root).unwrap(); let inputs = to_json(&inputs); let response_result = client @@ -192,10 +190,7 @@ where let (proof_a, proof_b, proof_c) = proof_from_json_struct(proof_json); let (proof_a, proof_b, proof_c) = compress_proof(&proof_a, &proof_b, &proof_c); let instruction_data = InstructionDataBatchNullifyInputs { - public_inputs: BatchProofInputsIx { - new_root: circuit_inputs_new_root, - old_root_index: old_root_index as u16, - }, + new_root, compressed_proof: CompressedProof { a: proof_a, b: proof_b, @@ -330,7 +325,7 @@ pub async fn create_append_batch_ix_data>( }; Ok(InstructionDataBatchAppendInputs { - public_inputs: AppendBatchProofInputsIx { new_root }, + new_root, compressed_proof: proof, }) } @@ -340,7 +335,7 @@ pub async fn create_nullify_batch_ix_data>( indexer: &mut I, merkle_tree_pubkey: Pubkey, ) -> Result { - let (zkp_batch_size, old_root, old_root_index, leaves_hashchain) = { + let (zkp_batch_size, old_root, leaves_hashchain) = { let mut account = rpc.get_account(merkle_tree_pubkey).await.unwrap().unwrap(); let merkle_tree = 
BatchedMerkleTreeAccount::state_from_bytes(account.data.as_mut_slice()).unwrap(); @@ -349,9 +344,8 @@ pub async fn create_nullify_batch_ix_data>( let batch = &merkle_tree.batches[batch_idx]; let zkp_idx = batch.get_num_inserted_zkps(); let hashchain = merkle_tree.hashchain_store[batch_idx][zkp_idx as usize]; - let root_idx = merkle_tree.root_history.last_index(); let root = *merkle_tree.root_history.last().unwrap(); - (zkp_size, root, root_idx, hashchain) + (zkp_size, root, hashchain) }; let leaf_indices_tx_hashes = @@ -434,10 +428,7 @@ pub async fn create_nullify_batch_ix_data>( }; Ok(InstructionDataBatchNullifyInputs { - public_inputs: BatchProofInputsIx { - new_root, - old_root_index: old_root_index as u16, - }, + new_root, compressed_proof: proof, }) } diff --git a/program-libs/batched-merkle-tree/src/batch.rs b/program-libs/batched-merkle-tree/src/batch.rs index e5dbb787e3..b3044f0bed 100644 --- a/program-libs/batched-merkle-tree/src/batch.rs +++ b/program-libs/batched-merkle-tree/src/batch.rs @@ -10,7 +10,7 @@ use crate::errors::BatchedMerkleTreeError; #[repr(u64)] pub enum BatchState { /// Batch can be filled with values. - CanBeFilled, + Fill, /// Batch has been inserted into the tree. Inserted, /// Batch is full, and insertion is in progress. @@ -20,7 +20,7 @@ pub enum BatchState { impl From for BatchState { fn from(value: u64) -> Self { match value { - 0 => BatchState::CanBeFilled, + 0 => BatchState::Fill, 1 => BatchState::Inserted, 2 => BatchState::Full, _ => panic!("Invalid BatchState value"), @@ -73,7 +73,7 @@ impl Batch { bloom_filter_capacity, batch_size, num_inserted: 0, - state: BatchState::CanBeFilled.into(), + state: BatchState::Fill.into(), zkp_batch_size, current_zkp_batch_index: 0, num_inserted_zkps: 0, @@ -102,9 +102,9 @@ impl Batch { } /// fill -> full -> inserted -> fill - pub fn advance_state_to_can_be_filled(&mut self) -> Result<(), BatchedMerkleTreeError> { + pub fn advance_state_to_fill(&mut self) -> Result<(), BatchedMerkleTreeError> { if self.get_state() == BatchState::Inserted { - self.state = BatchState::CanBeFilled.into(); + self.state = BatchState::Fill.into(); } else { msg!( "Batch is in incorrect state {} expected Inserted 3", @@ -131,7 +131,7 @@ impl Batch { /// fill -> full -> inserted -> fill pub fn advance_state_to_full(&mut self) -> Result<(), BatchedMerkleTreeError> { - if self.get_state() == BatchState::CanBeFilled { + if self.get_state() == BatchState::Fill { self.state = BatchState::Full.into(); } else { msg!( @@ -174,7 +174,7 @@ impl Batch { value: &[u8; 32], value_store: &mut ZeroCopyVecU64<[u8; 32]>, ) -> Result<(), BatchedMerkleTreeError> { - if self.get_state() != BatchState::CanBeFilled { + if self.get_state() != BatchState::Fill { return Err(BatchedMerkleTreeError::BatchNotReady); } value_store.push(*value)?; @@ -253,23 +253,31 @@ impl Batch { self.batch_size / self.zkp_batch_size } + /// Marks the batch as inserted in the merkle tree. + /// 1. Checks that the batch is ready. + /// 2. increments the number of inserted zkps. + /// 3. If all zkps are inserted, sets the state to inserted. + /// 4. Returns the updated state of the batch. pub fn mark_as_inserted_in_merkle_tree( &mut self, sequence_number: u64, root_index: u32, root_history_length: u32, - ) -> Result<(), BatchedMerkleTreeError> { - // Check that batch is ready. + ) -> Result { + // 1. Check that batch is ready. self.get_first_ready_zkp_batch()?; let num_zkp_batches = self.get_num_zkp_batches(); + // 2. increments the number of inserted zkps. 
self.num_inserted_zkps += 1; - // Batch has been successfully inserted into the tree. + println!("num_inserted_zkps: {}", self.num_inserted_zkps); + // 3. If all zkps are inserted, sets the state to inserted. if self.num_inserted_zkps == num_zkp_batches { - self.current_zkp_batch_index = 0; - self.state = BatchState::Inserted.into(); + println!("Setting state to inserted"); self.num_inserted_zkps = 0; + self.current_zkp_batch_index = 0; + self.advance_state_to_inserted()?; // Saving sequence number and root index for the batch. // When the batch is cleared check that sequence number is greater or equal than self.sequence_number // if not advance current root index to root index @@ -277,7 +285,7 @@ impl Batch { self.root_index = root_index; } - Ok(()) + Ok(self.get_state()) } pub fn get_hashchain_store_len(&self) -> usize { @@ -513,7 +521,7 @@ mod tests { let mut batch = get_test_batch(); assert_eq!(batch.get_num_zkp_batches(), 5); assert_eq!(batch.get_hashchain_store_len(), 5); - assert_eq!(batch.get_state(), BatchState::CanBeFilled); + assert_eq!(batch.get_state(), BatchState::Fill); assert_eq!(batch.get_num_inserted(), 0); assert_eq!(batch.get_current_zkp_batch_index(), 0); assert_eq!(batch.get_num_inserted_zkps(), 0); diff --git a/program-libs/batched-merkle-tree/src/batch_metadata.rs b/program-libs/batched-merkle-tree/src/batch_metadata.rs index ef5a62b1ff..ec2f9b0948 100644 --- a/program-libs/batched-merkle-tree/src/batch_metadata.rs +++ b/program-libs/batched-merkle-tree/src/batch_metadata.rs @@ -1,6 +1,7 @@ +use light_merkle_tree_metadata::{errors::MerkleTreeMetadataError, queue::QueueType}; use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout}; -use crate::{BorshDeserialize, BorshSerialize}; +use crate::{batch::BatchState, errors::BatchedMerkleTreeError, BorshDeserialize, BorshSerialize}; #[repr(C)] #[derive( @@ -17,12 +18,19 @@ use crate::{BorshDeserialize, BorshSerialize}; Immutable, )] pub struct BatchMetadata { + /// Number of batches. pub num_batches: u64, + /// Number of elements in a batch. pub batch_size: u64, + /// Number of elements in a ZKP batch. + /// A batch has one or more ZKP batches. pub zkp_batch_size: u64, + /// Bloom filter capacity. + pub bloom_filter_capacity: u64, + /// Batch elements are currently inserted in. pub currently_processing_batch_index: u64, + /// Next batch to be inserted into the tree. 
pub next_full_batch_index: u64, - pub bloom_filter_capacity: u64, } impl BatchMetadata { @@ -30,15 +38,22 @@ impl BatchMetadata { self.batch_size / self.zkp_batch_size } - pub fn new_output_queue(batch_size: u64, zkp_batch_size: u64, num_batches: u64) -> Self { - BatchMetadata { + pub fn new_output_queue( + batch_size: u64, + zkp_batch_size: u64, + num_batches: u64, + ) -> Result { + if batch_size % zkp_batch_size != 0 { + return Err(BatchedMerkleTreeError::BatchSizeNotDivisibleByZkpBatchSize); + } + Ok(BatchMetadata { num_batches, zkp_batch_size, batch_size, currently_processing_batch_index: 0, next_full_batch_index: 0, bloom_filter_capacity: 0, - } + }) } pub fn new_input_queue( @@ -46,14 +61,83 @@ impl BatchMetadata { bloom_filter_capacity: u64, zkp_batch_size: u64, num_batches: u64, - ) -> Self { - BatchMetadata { + ) -> Result { + if batch_size % zkp_batch_size != 0 { + return Err(BatchedMerkleTreeError::BatchSizeNotDivisibleByZkpBatchSize); + } + Ok(BatchMetadata { num_batches, zkp_batch_size, batch_size, currently_processing_batch_index: 0, next_full_batch_index: 0, bloom_filter_capacity, + }) + } + + /// Increment the next full batch index if current state is inserted. + pub fn increment_next_full_batch_index_if_inserted(&mut self, state: BatchState) { + if state == BatchState::Inserted { + self.next_full_batch_index += 1; + self.next_full_batch_index %= self.num_batches; } } + + pub fn init( + &mut self, + num_batches: u64, + batch_size: u64, + zkp_batch_size: u64, + ) -> Result<(), BatchedMerkleTreeError> { + self.num_batches = num_batches; + self.batch_size = batch_size; + // Check that batch size is divisible by zkp_batch_size. + if batch_size % zkp_batch_size != 0 { + return Err(BatchedMerkleTreeError::BatchSizeNotDivisibleByZkpBatchSize); + } + self.zkp_batch_size = zkp_batch_size; + Ok(()) + } + + pub fn get_size_parameters( + &self, + queue_type: u64, + ) -> Result<(usize, usize, usize), MerkleTreeMetadataError> { + let num_batches = self.num_batches as usize; + // Input queues don't store values. + let num_value_stores = if queue_type == QueueType::BatchedOutput as u64 { + num_batches + } else if queue_type == QueueType::BatchedInput as u64 { + 0 + } else { + return Err(MerkleTreeMetadataError::InvalidQueueType); + }; + // Output queues don't use bloom filters. 
+ let num_stores = if queue_type == QueueType::BatchedInput as u64 { + num_batches + } else if queue_type == QueueType::BatchedOutput as u64 && self.bloom_filter_capacity == 0 { + 0 + } else { + return Err(MerkleTreeMetadataError::InvalidQueueType); + }; + Ok((num_value_stores, num_stores, num_batches)) + } +} + +#[test] +fn test_increment_next_full_batch_index_if_inserted() { + // create a new metadata struct + let mut metadata = BatchMetadata::new_input_queue(10, 10, 10, 2).unwrap(); + assert_eq!(metadata.next_full_batch_index, 0); + // increment next full batch index + metadata.increment_next_full_batch_index_if_inserted(BatchState::Inserted); + assert_eq!(metadata.next_full_batch_index, 1); + // increment next full batch index + metadata.increment_next_full_batch_index_if_inserted(BatchState::Inserted); + assert_eq!(metadata.next_full_batch_index, 0); + // try incrementing next full batch index with state not inserted + metadata.increment_next_full_batch_index_if_inserted(BatchState::Fill); + assert_eq!(metadata.next_full_batch_index, 0); + metadata.increment_next_full_batch_index_if_inserted(BatchState::Full); + assert_eq!(metadata.next_full_batch_index, 0); } diff --git a/program-libs/batched-merkle-tree/src/event.rs b/program-libs/batched-merkle-tree/src/event.rs index ee5b19ac40..8d7dcdad27 100644 --- a/program-libs/batched-merkle-tree/src/event.rs +++ b/program-libs/batched-merkle-tree/src/event.rs @@ -25,3 +25,5 @@ pub struct BatchNullifyEvent { pub sequence_number: u64, pub batch_size: u64, } + +pub type BatchAddressAppendEvent = BatchNullifyEvent; diff --git a/program-libs/batched-merkle-tree/src/initialize_state_tree.rs b/program-libs/batched-merkle-tree/src/initialize_state_tree.rs index 7c69e603b7..fd16160269 100644 --- a/program-libs/batched-merkle-tree/src/initialize_state_tree.rs +++ b/program-libs/batched-merkle-tree/src/initialize_state_tree.rs @@ -477,7 +477,8 @@ pub fn create_output_queue_account(params: CreateOutputQueueParams) -> BatchedQu params.batch_size, params.zkp_batch_size, params.num_batches, - ); + ) + .unwrap(); BatchedQueueMetadata { metadata, batch_metadata, diff --git a/program-libs/batched-merkle-tree/src/merkle_tree.rs b/program-libs/batched-merkle-tree/src/merkle_tree.rs index a94152905e..0576f1ee9a 100644 --- a/program-libs/batched-merkle-tree/src/merkle_tree.rs +++ b/program-libs/batched-merkle-tree/src/merkle_tree.rs @@ -38,7 +38,7 @@ use crate::{ BATCHED_STATE_TREE_TYPE, DEFAULT_BATCH_STATE_TREE_HEIGHT, TEST_DEFAULT_BATCH_SIZE, }, errors::BatchedMerkleTreeError, - event::{BatchAppendEvent, BatchNullifyEvent}, + event::{BatchAddressAppendEvent, BatchAppendEvent, BatchNullifyEvent}, initialize_address_tree::InitAddressTreeAccountsInstructionData, initialize_state_tree::InitStateTreeAccountsInstructionData, queue::BatchedQueueAccount, @@ -229,13 +229,13 @@ impl BatchedMerkleTreeMetadata { bloom_filter_capacity, zkp_batch_size, num_batches, - ), + ) + .unwrap(), capacity: 2u64.pow(height), } } } -#[repr(C)] #[derive(Debug, PartialEq)] pub struct BatchedMerkleTreeAccount<'a> { metadata: Ref<&'a mut [u8], BatchedMerkleTreeMetadata>, @@ -260,38 +260,34 @@ impl DerefMut for BatchedMerkleTreeAccount<'_> { } } -/// Get batch from account. -/// Hash all public inputs into one poseidon hash. /// Public inputs: -/// 1. old root (get from account by index) -/// 2. new root (send to chain and ) -/// 3. start index (get from batch) -/// 4. end index (get from batch start index plus batch size) +/// 1. old root (last root in root history) +/// 2. 
new root (send to chain) +/// 3. leaf hash chain (in hashchain store) #[repr(C)] #[derive(Debug, PartialEq, Clone, Copy, BorshDeserialize, BorshSerialize)] pub struct InstructionDataBatchNullifyInputs { - pub public_inputs: BatchProofInputsIx, + pub new_root: [u8; 32], pub compressed_proof: CompressedProof, } -#[repr(C)] -#[derive(Debug, PartialEq, Clone, Copy, BorshDeserialize, BorshSerialize)] -pub struct BatchProofInputsIx { - pub new_root: [u8; 32], - pub old_root_index: u16, -} +/// Public inputs: +/// 1. old root (last root in root history) +/// 2. new root (send to chain) +/// 3. leaf hash chain (in hashchain store) +/// 4. next index (get from metadata) +pub type InstructionDataAddressAppendInputs = InstructionDataBatchNullifyInputs; +/// Public inputs: +/// 1. old root (last root in root history) +/// 2. new root (send to chain) +/// 3. leaf hash chain (in hashchain store) +/// 4. start index (get from batch) #[repr(C)] #[derive(Debug, PartialEq, Clone, Copy, BorshDeserialize, BorshSerialize)] pub struct InstructionDataBatchAppendInputs { - pub public_inputs: AppendBatchProofInputsIx, - pub compressed_proof: CompressedProof, -} - -#[repr(C)] -#[derive(Debug, PartialEq, Clone, Copy, BorshDeserialize, BorshSerialize)] -pub struct AppendBatchProofInputsIx { pub new_root: [u8; 32], + pub compressed_proof: CompressedProof, } impl<'a> BatchedMerkleTreeAccount<'a> { @@ -431,8 +427,9 @@ impl<'a> BatchedMerkleTreeAccount<'a> { if tree_type == TreeType::BatchedState { root_history.push(light_hasher::Poseidon::zero_bytes()[height as usize]); } else if tree_type == TreeType::BatchedAddress { - // Initialized indexed Merkle tree root + // Initialized indexed Merkle tree root. root_history.push(ADDRESS_TREE_INIT_ROOT_40); + // The initialized indexed Merkle tree contains two elements. account_metadata.next_index = 2; } let (batches, value_vecs, bloom_filter_stores, hashchain_store) = init_queue( @@ -453,210 +450,270 @@ impl<'a> BatchedMerkleTreeAccount<'a> { }) } + /// Update the tree from the output queue account. + /// 1. Checks that the tree and queue are associated. + /// 2. Updates the tree with the output queue account. + /// 3. Returns the batch append event. pub fn update_tree_from_output_queue_account_info( &mut self, queue_account_info: &AccountInfo<'_>, instruction_data: InstructionDataBatchAppendInputs, id: [u8; 32], ) -> Result { + if self.tree_type != TreeType::BatchedState as u64 { + return Err(MerkleTreeMetadataError::InvalidTreeType.into()); + } if self.metadata.metadata.associated_queue != (*queue_account_info.key).into() { return Err(MerkleTreeMetadataError::MerkleTreeAndQueueNotAssociated.into()); } let queue_account = &mut BatchedQueueAccount::output_from_account_info(queue_account_info)?; - self.update_output_queue_account(queue_account, instruction_data, id) + self.update_tree_from_output_queue_account(queue_account, instruction_data, id) } - // Note: when proving inclusion by index in - // value array we need to insert the value into a bloom_filter once it is - // inserted into the tree. Check this with get_num_inserted_zkps - pub fn update_output_queue_account( + /// Update the tree from the output queue account. + /// 1. Create public inputs hash. + /// 2. Verify update proof and update tree account. + /// 2.1. Verify proof. + /// 2.2. Increment sequence number. + /// 2.3. Increment next index. + /// 2.4. Append new root to root history. + /// 3. Mark batch as inserted in the merkle tree. + /// 3.1. Checks that the batch is ready. + /// 3.2. 
Increment the number of inserted zkps. + /// 3.3. If all zkps are inserted, set the state to inserted. + /// 4. Increment next full batch index if inserted. + /// 5. Return the batch append event. + /// Note: when proving inclusion by index in + /// value array we need to insert the value into a bloom_filter once it is + /// inserted into the tree. Check this with get_num_inserted_zkps + #[cfg(not(target_os = "solana"))] + pub fn update_tree_from_output_queue_account( &mut self, queue_account: &mut BatchedQueueAccount, instruction_data: InstructionDataBatchAppendInputs, id: [u8; 32], ) -> Result { - let batch_index = queue_account.batch_metadata.next_full_batch_index; + let full_batch_index = queue_account.batch_metadata.next_full_batch_index as usize; + let new_root = instruction_data.new_root; let circuit_batch_size = queue_account.batch_metadata.zkp_batch_size; - let batches = &mut queue_account.batches; - let full_batch = batches - .get_mut(batch_index as usize) - .ok_or(BatchedMerkleTreeError::InvalidBatchIndex)?; - - let new_root = instruction_data.public_inputs.new_root; + let start_index = self.next_index; + let full_batch = &mut queue_account.batches[full_batch_index]; let num_zkps = full_batch.get_first_ready_zkp_batch()?; - let leaves_hashchain = queue_account - .hashchain_store - .get(batch_index as usize) - .ok_or(BatchedMerkleTreeError::InvalidBatchIndex)? - .get(num_zkps as usize) - .ok_or(BatchedMerkleTreeError::InvalidIndex)?; - let old_root = self - .root_history - .last() - .ok_or(BatchedMerkleTreeError::InvalidIndex)?; - let start_index = self.next_index; - let mut start_index_bytes = [0u8; 32]; - start_index_bytes[24..].copy_from_slice(&start_index.to_be_bytes()); - let public_input_hash = create_hash_chain_from_array([ - *old_root, - new_root, - *leaves_hashchain, - start_index_bytes, - ])?; + // 1. Create public inputs hash. + let public_input_hash = { + let leaves_hashchain = + queue_account.hashchain_store[full_batch_index][num_zkps as usize]; + let old_root = self + .root_history + .last() + .ok_or(BatchedMerkleTreeError::InvalidIndex)?; + let mut start_index_bytes = [0u8; 32]; + start_index_bytes[24..].copy_from_slice(&start_index.to_be_bytes()); + create_hash_chain_from_array([ + *old_root, + new_root, + leaves_hashchain, + start_index_bytes, + ])? + }; - self.update::<5>( - circuit_batch_size as usize, + // 2. Verify update proof and update tree account. + self.verify_update::<5>( + circuit_batch_size, instruction_data.compressed_proof, public_input_hash, + new_root, )?; - self.metadata.next_index += circuit_batch_size; - let root_history_capacity = self.metadata.root_history_capacity; - let sequence_number = self.metadata.sequence_number; - self.root_history.push(new_root); let root_index = self.root_history.last_index() as u32; - full_batch.mark_as_inserted_in_merkle_tree( - sequence_number, - root_index, - root_history_capacity, - )?; - if full_batch.get_state() == BatchState::Inserted { - queue_account.batch_metadata.next_full_batch_index += 1; - queue_account.batch_metadata.next_full_batch_index %= - queue_account.batch_metadata.num_batches; + + // Update queue metadata. + { + // 3. Mark batch as inserted in the merkle tree. + let full_batch_state = full_batch.mark_as_inserted_in_merkle_tree( + self.metadata.sequence_number, + root_index, + self.metadata.root_history_capacity, + )?; + // 4. Increment next full batch index if inserted. + queue_account + .batch_metadata + .increment_next_full_batch_index_if_inserted(full_batch_state); } + // 5. 
Return the batch append event. Ok(BatchAppendEvent { id, - batch_index, - batch_size: circuit_batch_size, + batch_index: full_batch_index as u64, zkp_batch_index: num_zkps, old_next_index: start_index, new_next_index: start_index + circuit_batch_size, + batch_size: circuit_batch_size, new_root, root_index, sequence_number: self.sequence_number, }) } + /// Update the tree from the input queue account. pub fn update_tree_from_input_queue( &mut self, instruction_data: InstructionDataBatchNullifyInputs, id: [u8; 32], ) -> Result { - self.private_update_input_queue::<3>(instruction_data, id) + if self.tree_type != TreeType::BatchedState as u64 { + return Err(MerkleTreeMetadataError::InvalidTreeType.into()); + } + self.update_input_queue::<3>(instruction_data, id) } + /// Update the tree from the address queue account. pub fn update_tree_from_address_queue( &mut self, - instruction_data: InstructionDataBatchNullifyInputs, + instruction_data: InstructionDataAddressAppendInputs, id: [u8; 32], - ) -> Result { - self.private_update_input_queue::<4>(instruction_data, id) + ) -> Result { + if self.tree_type != TreeType::BatchedAddress as u64 { + return Err(MerkleTreeMetadataError::InvalidTreeType.into()); + } + self.update_input_queue::<4>(instruction_data, id) } - fn private_update_input_queue( + /// Update the tree from the input/address queue account. + /// 1. Create public inputs hash. + /// 2. Verify update proof and update tree account. + /// 2.1. Verify proof. + /// 2.2. Increment sequence number. + /// 2.3. If address tree increment next index. + /// 2.4. Append new root to root history. + /// 3. Mark batch as inserted in the merkle tree. + /// 3.1. Checks that the batch is ready. + /// 3.2. Increment the number of inserted zkps. + /// 3.3. If all zkps are inserted, set the state to inserted. + /// 4. Wipe previous batch bloom filter if current batch is 50% inserted. + /// 5. Increment next full batch index if inserted. + /// 6. Return the batch nullify event. + fn update_input_queue( &mut self, instruction_data: InstructionDataBatchNullifyInputs, id: [u8; 32], ) -> Result { - let batch_index = self.queue_metadata.next_full_batch_index; - - let full_batch = self - .batches - .get(batch_index as usize) - .ok_or(BatchedMerkleTreeError::InvalidBatchIndex)?; - - let num_zkps = full_batch.get_first_ready_zkp_batch()?; + let full_batch_index = self.queue_metadata.next_full_batch_index as usize; + let num_zkps = self.batches[full_batch_index].get_first_ready_zkp_batch()?; + let new_root = instruction_data.new_root; + let circuit_batch_size = self.queue_metadata.zkp_batch_size; - let leaves_hashchain = self - .hashchain_store - .get(batch_index as usize) - .ok_or(BatchedMerkleTreeError::InvalidBatchIndex)? - .get(num_zkps as usize) - .ok_or(BatchedMerkleTreeError::InvalidIndex)?; - let old_root = self - .root_history - .get(instruction_data.public_inputs.old_root_index as usize) - .ok_or(BatchedMerkleTreeError::InvalidIndex)?; - let new_root = instruction_data.public_inputs.new_root; - - let public_input_hash = if QUEUE_TYPE == QueueType::BatchedInput as u64 { - create_hash_chain_from_array([*old_root, new_root, *leaves_hashchain])? - } else if QUEUE_TYPE == QueueType::BatchedAddress as u64 { - let mut next_index_bytes = [0u8; 32]; - next_index_bytes[24..].copy_from_slice(self.next_index.to_be_bytes().as_slice()); - create_hash_chain_from_array([ - *old_root, - new_root, - *leaves_hashchain, - next_index_bytes, - ])? - } else { - return Err(MerkleTreeMetadataError::InvalidQueueType.into()); + // 1. 
Create public inputs hash. + let public_input_hash = { + let leaves_hashchain = self.hashchain_store[full_batch_index][num_zkps as usize]; + let old_root = self + .root_history + .last() + .ok_or(BatchedMerkleTreeError::InvalidIndex)?; + + if QUEUE_TYPE == QueueType::BatchedInput as u64 { + create_hash_chain_from_array([*old_root, new_root, leaves_hashchain])? + } else if QUEUE_TYPE == QueueType::BatchedAddress as u64 { + let mut next_index_bytes = [0u8; 32]; + next_index_bytes[24..].copy_from_slice(self.next_index.to_be_bytes().as_slice()); + create_hash_chain_from_array([ + *old_root, + new_root, + leaves_hashchain, + next_index_bytes, + ])? + } else { + return Err(MerkleTreeMetadataError::InvalidQueueType.into()); + } }; - let circuit_batch_size = self.queue_metadata.zkp_batch_size; - self.update::( - circuit_batch_size as usize, + + // 2. Verify update proof and update tree account. + self.verify_update::( + circuit_batch_size, instruction_data.compressed_proof, public_input_hash, + new_root, )?; - self.root_history.push(new_root); - let root_history_capacity = self.root_history_capacity; - let sequence_number = self.sequence_number; - let full_batch = self - .batches - .get_mut(batch_index as usize) - .ok_or(BatchedMerkleTreeError::InvalidBatchIndex)?; - full_batch.mark_as_inserted_in_merkle_tree( - sequence_number, - self.root_history.last_index() as u32, - root_history_capacity, - )?; - if full_batch.get_state() == BatchState::Inserted { - self.metadata.queue_metadata.next_full_batch_index += 1; - self.metadata.queue_metadata.next_full_batch_index %= - self.metadata.queue_metadata.num_batches; - } - if QUEUE_TYPE == QueueType::BatchedAddress as u64 { - self.metadata.next_index += circuit_batch_size; - } + let root_index = self.root_history.last_index() as u32; - self.wipe_previous_batch_bloom_filter()?; + // Update queue metadata. + { + let root_history_capacity = self.root_history_capacity; + let sequence_number = self.sequence_number; + // 3. Mark batch as inserted in the merkle tree. + let full_batch_state = self.batches[full_batch_index].mark_as_inserted_in_merkle_tree( + sequence_number, + root_index, + root_history_capacity, + )?; + + // 4. Wipe previous batch bloom filter + // if current batch is 50% inserted. + // Needs to be executed prior to + // incrementing next full batch index, + // but post mark_as_inserted_in_merkle_tree. + self.wipe_previous_batch_bloom_filter()?; + + // 5. Increment next full batch index if inserted. + self.metadata + .queue_metadata + .increment_next_full_batch_index_if_inserted(full_batch_state); + } + // 6. Return the batch nullify/address append event. Ok(BatchNullifyEvent { id, - batch_index, + batch_index: full_batch_index as u64, batch_size: circuit_batch_size, zkp_batch_index: num_zkps, new_root, - root_index: self.root_history.last_index() as u32, + root_index, sequence_number: self.sequence_number, }) } - fn update( + /// Verify update proof and update the tree. + /// 1. Verify update proof. + /// 2. Increment next index (unless queue type is BatchedInput). + /// 3. Increment sequence number. + /// 4. Append new root to root history. + fn verify_update( &mut self, - batch_size: usize, + batch_size: u64, proof: CompressedProof, public_input_hash: [u8; 32], + new_root: [u8; 32], ) -> Result<(), BatchedMerkleTreeError> { + // 1. Verify update proof. if QUEUE_TYPE == QueueType::BatchedOutput as u64 { verify_batch_append_with_proofs(batch_size, public_input_hash, &proof)?; + // 2. Increment next index. 
+            self.metadata.next_index += batch_size;
         } else if QUEUE_TYPE == QueueType::BatchedInput as u64 {
             verify_batch_update(batch_size, public_input_hash, &proof)?;
+            // 2. Skip incrementing next index.
+            // The input queue update does not append new values,
+            // hence there is no need to increment next_index.
         } else if QUEUE_TYPE == QueueType::BatchedAddress as u64 {
             verify_batch_address_update(batch_size, public_input_hash, &proof)?;
+            // 2. Increment next index.
+            self.metadata.next_index += batch_size;
         } else {
             return Err(MerkleTreeMetadataError::InvalidQueueType.into());
         }
+        // 3. Increment sequence number.
         self.metadata.sequence_number += 1;
+        // 4. Append new root to root history.
+        // root_history is a cyclic vec;
+        // it overwrites the oldest root
+        // once it is full.
+        self.root_history.push(new_root);
         Ok(())
     }

     /// State nullification:
-    /// - value is committed to bloom_filter for non-inclusion proof
+    /// - value is inserted into a bloom_filter to prove non-inclusion in later txs.
     /// - nullifier is Hash(value, tx_hash), committed to leaves hashchain
     /// - tx_hash is hash of all inputs and outputs
     /// -> we can access the history of how commitments are spent in zkps for example fraud proofs
@@ -668,7 +725,7 @@ impl<'a> BatchedMerkleTreeAccount<'a> {
     ) -> Result<(), BatchedMerkleTreeError> {
         // Note, no need to check whether the tree is full
         // since nullifier insertions update existing values
-        // in the tree and not append a values.
+        // in the tree and do not append new values.
         if self.tree_type != TreeType::BatchedState as u64 {
             return Err(MerkleTreeMetadataError::InvalidTreeType.into());
         }
@@ -684,7 +741,7 @@ impl<'a> BatchedMerkleTreeAccount<'a> {
         if self.tree_type != TreeType::BatchedAddress as u64 {
             return Err(MerkleTreeMetadataError::InvalidTreeType.into());
         }
-        // Check if the tree is full.
+        // Check that the tree is not full.
         self.check_tree_is_full()?;

         self.insert_into_current_batch(address, address)
@@ -755,28 +812,49 @@ impl<'a> BatchedMerkleTreeAccount<'a> {
         Ok(())
     }

+    /// Zero out roots corresponding to sequence numbers < sequence_number.
+    /// 1. Check whether overlapping roots exist.
+    /// 2. If yes:
+    /// 2.1 Get the first safe root index.
+    /// 2.2 Zero out roots from the oldest root to the first safe root.
     fn zero_out_roots(&mut self, sequence_number: u64, root_index: Option<u32>) {
-        if sequence_number > self.sequence_number {
-            // advance root history array current index from latest root
-            // to root_index and overwrite all roots with zeros
+        // 1. Check whether overlapping roots exist.
+        let overlapping_roots_exist = sequence_number > self.sequence_number;
+        if overlapping_roots_exist {
            if let Some(root_index) = root_index {
                 let root_index = root_index as usize;
-                let start = self.root_history.last_index();
-                let end = self.root_history.len() + root_index;
-                for index in start + 1..end {
+
+                let oldest_root_index = self.root_history.first_index();
+                // 2.1. Get the index of the first root inserted after the input queue batch was inserted.
+                let first_safe_root_index = self.root_history.len() + root_index;
+                // 2.2. Zero out roots from the oldest to the first safe root index.
+                for index in oldest_root_index..first_safe_root_index {
                     let index = index % self.root_history.len();
+                    // TODO: test if needed
                     if index == root_index {
                         break;
                     }
-                    let root = self.root_history.get_mut(index).unwrap();
-                    *root = [0u8; 32];
+                    self.root_history[index] = [0u8; 32];
                 }
+            } else {
+                unreachable!("root_index must be Some(root_index) if overlapping roots exist");
             }
         }
     }
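The wrap-around arithmetic in zero_out_roots is easier to see on a plain array. A minimal standalone sketch, assuming a slice-backed ring buffer in place of the zero-copy root_history (zero_out_ring and its parameter names are illustrative; the len() + root_index offset and the modulo fold mirror the loop above):

fn zero_out_ring(roots: &mut [[u8; 32]], oldest_root_index: usize, root_index: usize) {
    let len = roots.len();
    // `len + root_index` is always past any start index;
    // the modulo below folds it back into the ring.
    let first_safe_root_index = len + root_index;
    for index in oldest_root_index..first_safe_root_index {
        let index = index % len;
        // Stop at the first root that is safe to keep.
        if index == root_index {
            break;
        }
        roots[index] = [0u8; 32];
    }
}

With len = 4, oldest_root_index = 2 and the first safe root at index 1, the loop zeroes indices 2, 3 and 0 and leaves index 1 intact, so only roots created after the batch insertion remain usable for inclusion proofs.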
-    /// Wipe bloom filter after a batch has been inserted and 50% of the
-    /// subsequent batch been processed.
+    /// Wipe bloom filter of previous batch if 50% of the
+    /// current batch has been processed.
+    ///
+    /// Idea:
+    /// 1. Wiping the bloom filter of the previous batch is expensive
+    ///    -> the forester should do it.
+    /// 2. We don't want to wipe the bloom filter when inserting
+    ///    the last zkp of a batch, since this might result in failing user txs.
+    /// 3. Wait until the next batch is 50% full as a grace period for clients
+    ///    to switch from proof by index to proof by zkp
+    ///    for values inserted in the previous batch.
+    ///
+    /// Steps:
     /// 1. Previous batch must be inserted and bloom filter must not be wiped.
     /// 2. Current batch must be 50% full
     /// 3. if yes
@@ -784,32 +862,65 @@
     /// 3.2 mark bloom filter as wiped
     /// 3.3 zero out roots if needed
     pub fn wipe_previous_batch_bloom_filter(&mut self) -> Result<(), BatchedMerkleTreeError> {
-        let current_batch = self.queue_metadata.currently_processing_batch_index;
+        let current_batch = self.queue_metadata.next_full_batch_index as usize;
         let batch_size = self.queue_metadata.batch_size;
-        let previous_full_batch_index =
-            self.queue_metadata.next_full_batch_index.saturating_sub(1) as usize;
-        let num_inserted_elements = self
-            .batches
-            .get(current_batch as usize)
-            .ok_or(BatchedMerkleTreeError::InvalidBatchIndex)?
+        let previous_full_batch_index = current_batch.saturating_sub(1);
+        let previous_full_batch_index = if previous_full_batch_index == current_batch {
+            self.queue_metadata.num_batches as usize - 1
+        } else {
+            previous_full_batch_index
+        };
+
+        let current_batch_is_half_full = {
+            let num_inserted_elements = self
+                .batches
+                .get(current_batch)
+                .ok_or(BatchedMerkleTreeError::InvalidBatchIndex)?
+ .get_num_inserted_elements(); + // Keep for finegrained unit test + println!("current_batch: {}", current_batch); + println!("previous_full_batch_index: {}", previous_full_batch_index); + println!("num_inserted_elements: {}", num_inserted_elements); + println!("batch_size: {}", batch_size); + println!("batch_size / 2: {}", batch_size / 2); + println!( + "num_inserted_elements >= batch_size / 2: {}", + num_inserted_elements >= batch_size / 2 + ); + num_inserted_elements >= batch_size / 2 + }; + let previous_full_batch = self .batches .get_mut(previous_full_batch_index) .ok_or(BatchedMerkleTreeError::InvalidBatchIndex)?; - if previous_full_batch.get_state() == BatchState::Inserted - && batch_size / 2 > num_inserted_elements - && !previous_full_batch.bloom_filter_is_wiped() - { - let bloom_filter = self - .bloom_filter_stores - .get_mut(previous_full_batch_index) - .ok_or(BatchedMerkleTreeError::InvalidBatchIndex)?; - bloom_filter.as_mut_slice().iter_mut().for_each(|x| *x = 0); + + let previous_batch_is_ready = previous_full_batch.get_state() == BatchState::Inserted + && !previous_full_batch.bloom_filter_is_wiped(); + + if previous_batch_is_ready && current_batch_is_half_full { + // Keep for finegrained unit test + println!("Wiping bloom filter of previous batch"); + println!("current_batch: {}", current_batch); + println!("previous_full_batch_index: {}", previous_full_batch_index); + // 3.1 Zero out bloom filter. + { + let bloom_filter = self + .bloom_filter_stores + .get_mut(previous_full_batch_index) + .ok_or(BatchedMerkleTreeError::InvalidBatchIndex)?; + bloom_filter.as_mut_slice().iter_mut().for_each(|x| *x = 0); + } + // 3.2 Mark bloom filter wiped. previous_full_batch.set_bloom_filter_is_wiped(); - let seq = previous_full_batch.sequence_number; - let root_index = previous_full_batch.root_index; - self.zero_out_roots(seq, Some(root_index)); + // 3.3 Zero out roots if a root exists in root history + // which allows to prove inclusion of a value + // that was inserted into the bloom filter just wiped. + { + let seq = previous_full_batch.sequence_number; + let root_index = previous_full_batch.root_index; + self.zero_out_roots(seq, Some(root_index)); + } } Ok(()) diff --git a/program-libs/batched-merkle-tree/src/queue.rs b/program-libs/batched-merkle-tree/src/queue.rs index ee9543341a..b9aa36881e 100644 --- a/program-libs/batched-merkle-tree/src/queue.rs +++ b/program-libs/batched-merkle-tree/src/queue.rs @@ -76,48 +76,6 @@ impl BatchedQueueMetadata { } } -impl BatchMetadata { - pub fn init( - &mut self, - num_batches: u64, - batch_size: u64, - zkp_batch_size: u64, - ) -> Result<(), BatchedMerkleTreeError> { - self.num_batches = num_batches; - self.batch_size = batch_size; - // Check that batch size is divisible by zkp_batch_size. - if batch_size % zkp_batch_size != 0 { - return Err(BatchedMerkleTreeError::BatchSizeNotDivisibleByZkpBatchSize); - } - self.zkp_batch_size = zkp_batch_size; - Ok(()) - } - - pub fn get_size_parameters( - &self, - queue_type: u64, - ) -> Result<(usize, usize, usize), MerkleTreeMetadataError> { - let num_batches = self.num_batches as usize; - // Input queues don't store values - let num_value_stores = if queue_type == QueueType::BatchedOutput as u64 { - num_batches - } else if queue_type == QueueType::BatchedInput as u64 { - 0 - } else { - return Err(MerkleTreeMetadataError::InvalidQueueType); - }; - // Output queues don't use bloom filters. 
- let num_stores = if queue_type == QueueType::BatchedInput as u64 { - num_batches - } else if queue_type == QueueType::BatchedOutput as u64 && self.bloom_filter_capacity == 0 { - 0 - } else { - return Err(MerkleTreeMetadataError::InvalidQueueType); - }; - Ok((num_value_stores, num_stores, num_batches)) - } -} - pub fn queue_account_size( batch_metadata: &BatchMetadata, queue_type: u64, @@ -148,9 +106,8 @@ pub fn queue_account_size( + hashchain_store_size; Ok(size) } -/// Batched output queue -#[repr(C)] -#[derive(Debug)] + +#[derive(Debug, PartialEq)] pub struct BatchedQueueAccount<'a> { metadata: Ref<&'a mut [u8], BatchedQueueMetadata>, pub batches: ZeroCopySliceMutU64<'a, Batch>, @@ -199,24 +156,17 @@ impl<'a> BatchedQueueAccount<'a> { let account_data: &'a mut [u8] = unsafe { std::slice::from_raw_parts_mut(account_data.as_mut_ptr(), account_data.len()) }; - Self::internal_from_bytes::(account_data) + Self::from_bytes::(account_data) } #[cfg(not(target_os = "solana"))] pub fn output_from_bytes( account_data: &'a mut [u8], ) -> Result, BatchedMerkleTreeError> { - Self::internal_from_bytes::(account_data) - } - - #[cfg(not(target_os = "solana"))] - pub fn from_bytes( - account_data: &'a mut [u8], - ) -> Result, BatchedMerkleTreeError> { - Self::internal_from_bytes::(account_data) + Self::from_bytes::(account_data) } - fn internal_from_bytes( + fn from_bytes( account_data: &'a mut [u8], ) -> Result, BatchedMerkleTreeError> { let (discriminator, account_data) = account_data.split_at_mut(DISCRIMINATOR_LEN); @@ -445,7 +395,7 @@ pub fn insert_into_current_batch( .ok_or(BatchedMerkleTreeError::InvalidBatchIndex)?; let mut wipe = false; if current_batch.get_state() == BatchState::Inserted { - current_batch.advance_state_to_can_be_filled()?; + current_batch.advance_state_to_fill()?; if let Some(current_index) = current_index { current_batch.start_index = current_index; } diff --git a/program-libs/batched-merkle-tree/tests/merkle_tree.rs b/program-libs/batched-merkle-tree/tests/merkle_tree.rs index f71449561e..b079675d27 100644 --- a/program-libs/batched-merkle-tree/tests/merkle_tree.rs +++ b/program-libs/batched-merkle-tree/tests/merkle_tree.rs @@ -18,9 +18,8 @@ use light_batched_merkle_tree::{ }, merkle_tree::{ assert_batch_append_event_event, assert_nullify_event, - get_merkle_tree_account_size_default, AppendBatchProofInputsIx, BatchProofInputsIx, - BatchedMerkleTreeAccount, BatchedMerkleTreeMetadata, InstructionDataBatchAppendInputs, - InstructionDataBatchNullifyInputs, + get_merkle_tree_account_size_default, BatchedMerkleTreeAccount, BatchedMerkleTreeMetadata, + InstructionDataBatchAppendInputs, InstructionDataBatchNullifyInputs, }, queue::{ get_output_queue_account_size_default, get_output_queue_account_size_from_params, @@ -155,7 +154,7 @@ pub fn assert_input_queue_insert( println!("assert input queue batch update: clearing batch"); pre_hashchains[inserted_batch_index].clear(); expected_batch.sequence_number = 0; - expected_batch.advance_state_to_can_be_filled().unwrap(); + expected_batch.advance_state_to_fill().unwrap(); expected_batch.set_bloom_filter_is_not_wiped(); } println!( @@ -276,7 +275,7 @@ pub fn assert_output_queue_insert( let pre_value_store = pre_value_store.get_mut(inserted_batch_index).unwrap(); let pre_hashchain = pre_hashchains.get_mut(inserted_batch_index).unwrap(); if expected_batch.get_state() == BatchState::Inserted { - expected_batch.advance_state_to_can_be_filled().unwrap(); + expected_batch.advance_state_to_fill().unwrap(); pre_value_store.clear(); 
pre_hashchain.clear(); expected_batch.start_index = pre_account.next_index; @@ -665,7 +664,6 @@ async fn test_simulate_transactions() { BatchedMerkleTreeAccount::state_from_bytes(&mut pre_mt_account_data).unwrap(); println!("batches {:?}", account.batches); - let old_root_index = account.root_history.last_index(); let next_full_batch = account.get_metadata().queue_metadata.next_full_batch_index; let batch = account.batches.get(next_full_batch as usize).unwrap(); println!( @@ -693,10 +691,7 @@ async fn test_simulate_transactions() { .await .unwrap(); let instruction_data = InstructionDataBatchNullifyInputs { - public_inputs: BatchProofInputsIx { - new_root, - old_root_index: old_root_index as u16, - }, + new_root, compressed_proof: CompressedProof { a: proof.a, b: proof.b, @@ -769,7 +764,7 @@ async fn test_simulate_transactions() { .unwrap(); let instruction_data = InstructionDataBatchAppendInputs { - public_inputs: AppendBatchProofInputsIx { new_root }, + new_root, compressed_proof: CompressedProof { a: proof.a, b: proof.b, @@ -782,7 +777,7 @@ async fn test_simulate_transactions() { let queue_account = &mut BatchedQueueAccount::output_from_bytes(&mut pre_output_queue_state).unwrap(); - let output_res = account.update_output_queue_account( + let output_res = account.update_tree_from_output_queue_account( queue_account, instruction_data, mt_pubkey.to_bytes(), @@ -1080,7 +1075,7 @@ async fn test_e2e() { } let instruction_data = InstructionDataBatchAppendInputs { - public_inputs: AppendBatchProofInputsIx { new_root }, + new_root, compressed_proof: CompressedProof { a: proof.a, b: proof.b, @@ -1093,7 +1088,7 @@ async fn test_e2e() { let queue_account = &mut BatchedQueueAccount::output_from_bytes(&mut pre_output_queue_state).unwrap(); - let output_res = account.update_output_queue_account( + let output_res = account.update_tree_from_output_queue_account( queue_account, instruction_data, mt_pubkey.to_bytes(), @@ -1158,7 +1153,6 @@ pub async fn perform_input_update( let (input_res, root) = { let mut account = BatchedMerkleTreeAccount::state_from_bytes(mt_account_data).unwrap(); - let old_root_index = account.root_history.last_index(); let next_full_batch = account.get_metadata().queue_metadata.next_full_batch_index; let batch = account.batches.get(next_full_batch as usize).unwrap(); let leaves_hashchain = account @@ -1175,10 +1169,7 @@ pub async fn perform_input_update( .await .unwrap(); let instruction_data = InstructionDataBatchNullifyInputs { - public_inputs: BatchProofInputsIx { - new_root, - old_root_index: old_root_index as u16, - }, + new_root, compressed_proof: CompressedProof { a: proof.a, b: proof.b, @@ -1221,7 +1212,6 @@ pub async fn perform_address_update( let (input_res, root, pre_next_full_batch) = { let mut account = BatchedMerkleTreeAccount::address_from_bytes(mt_account_data).unwrap(); - let old_root_index = account.root_history.last_index(); let next_full_batch = account.get_metadata().queue_metadata.next_full_batch_index; let next_index = account.get_metadata().next_index; println!("next index {:?}", next_index); @@ -1246,10 +1236,7 @@ pub async fn perform_address_update( .await .unwrap(); let instruction_data = InstructionDataBatchNullifyInputs { - public_inputs: BatchProofInputsIx { - new_root, - old_root_index: old_root_index as u16, - }, + new_root, compressed_proof: CompressedProof { a: proof.a, b: proof.b, @@ -1276,9 +1263,8 @@ pub async fn perform_address_update( let account = BatchedMerkleTreeAccount::address_from_bytes(mt_account_data).unwrap(); { - let next_full_batch 
= account.get_metadata().queue_metadata.next_full_batch_index; - let batch = account.batches.get(next_full_batch as usize).unwrap(); - if pre_next_full_batch != next_full_batch { + let batch = account.batches.get(pre_next_full_batch as usize).unwrap(); + if batch.get_state() == BatchState::Inserted { mock_indexer.finalize_batch_address_update(batch.batch_size as usize); } } @@ -1297,62 +1283,66 @@ fn assert_merkle_tree_update( let mut expected_account = *old_account.get_metadata(); expected_account.sequence_number += 1; let actual_account = *account.get_metadata(); - - let ( - batches, - previous_batchs, - _previous_processing, - expected_queue_account, - mut next_full_batch_index, - ) = if let Some(queue_account) = queue_account.as_ref() { - let expected_queue_account = *old_queue_account.as_ref().unwrap().get_metadata(); - - let previous_processing = if queue_account - .get_metadata() - .batch_metadata - .currently_processing_batch_index - == 0 - { - queue_account.get_metadata().batch_metadata.num_batches - 1 - } else { - queue_account - .get_metadata() - .batch_metadata - .currently_processing_batch_index - - 1 - }; - expected_account.next_index += queue_account.batches.get(0).unwrap().zkp_batch_size; - let next_full_batch_index = expected_queue_account.batch_metadata.next_full_batch_index; - ( - queue_account.batches.to_vec(), - old_queue_account.as_ref().unwrap().batches.to_vec(), - previous_processing, - Some(expected_queue_account), - next_full_batch_index, - ) + // We only have two batches. + let previous_full_batch_index = if expected_account.queue_metadata.next_full_batch_index == 0 { + 1 } else { - // We only have two batches. - let previous_processing = if expected_account - .queue_metadata - .currently_processing_batch_index - == 0 - { - 1 + 0 + }; + + let (batches, mut previous_batches, expected_queue_account, mut next_full_batch_index) = + if let Some(queue_account) = queue_account.as_ref() { + let expected_queue_account = *old_queue_account.as_ref().unwrap().get_metadata(); + expected_account.next_index += queue_account.batches.get(0).unwrap().zkp_batch_size; + let next_full_batch_index = expected_queue_account.batch_metadata.next_full_batch_index; + ( + queue_account.batches.to_vec(), + old_queue_account.as_ref().unwrap().batches.to_vec(), + Some(expected_queue_account), + next_full_batch_index, + ) } else { - 0 + let mut batches = old_account.batches.to_vec(); + println!("previous_full_batch_index: {:?}", previous_full_batch_index); + let previous_batch = batches.get_mut(previous_full_batch_index as usize).unwrap(); + println!("previous_batch state: {:?}", previous_batch.get_state()); + println!( + "previous_batch wiped?: {:?}", + previous_batch.bloom_filter_is_wiped() + ); + let previous_batch_is_ready = previous_batch.get_state() == BatchState::Inserted + && !previous_batch.bloom_filter_is_wiped(); + let batch = batches + .get_mut(old_account.queue_metadata.next_full_batch_index as usize) + .unwrap(); + + println!("previous_batch_is_ready: {:?}", previous_batch_is_ready); + println!( + "batch.bloom_filter_is_wiped(): {:?}", + batch.bloom_filter_is_wiped() + ); + println!( + "batch.get_num_inserted_elements(): {:?}", + batch.get_num_inserted_elements() + batch.zkp_batch_size + ); + println!("batch.batch_size: {:?}", batch.batch_size); + println!(" batch.get_num_inserted_elements() >= batch.batch_size / 2 && previous_batch_is_ready: {:?}", batch.get_num_inserted_elements()+ batch.zkp_batch_size >= batch.batch_size / 2); + let wiped_batch = 
batch.get_num_inserted_elements() + batch.zkp_batch_size + >= batch.batch_size / 2 + && previous_batch_is_ready; + let previous_batch = batches.get_mut(previous_full_batch_index as usize).unwrap(); + + if wiped_batch { + previous_batch.set_bloom_filter_is_wiped(); + println!("set bloom filter is wiped"); + } + (account.batches.to_vec(), batches, None, 0) }; - ( - account.batches.to_vec(), - old_account.batches.to_vec(), - previous_processing, - None, - 0, - ) - }; let mut checked_one = false; + for (i, batch) in batches.iter().enumerate() { - let previous_batch = previous_batchs.get(i).unwrap(); + let previous_batch = previous_batches.get_mut(i).unwrap(); let expected_sequence_number = account.root_history.capacity() as u64 + account.get_metadata().sequence_number; @@ -1360,6 +1350,7 @@ fn assert_merkle_tree_update( && batch.get_state() == BatchState::Inserted; let updated_batch = previous_batch.get_first_ready_zkp_batch().is_ok() && !checked_one; + // Assert fully inserted batch if batch_fully_inserted { if queue_account.is_some() { @@ -1571,7 +1562,7 @@ async fn test_fill_queues_completely() { } let instruction_data = InstructionDataBatchAppendInputs { - public_inputs: AppendBatchProofInputsIx { new_root }, + new_root, compressed_proof: CompressedProof { a: proof.a, b: proof.b, @@ -1582,7 +1573,7 @@ async fn test_fill_queues_completely() { println!("Output update -----------------------------"); let queue_account = &mut BatchedQueueAccount::output_from_bytes(&mut pre_output_queue_state).unwrap(); - let output_res = account.update_output_queue_account( + let output_res = account.update_tree_from_output_queue_account( queue_account, instruction_data, mt_pubkey.to_bytes(), @@ -1706,11 +1697,16 @@ async fn test_fill_queues_completely() { for i in 0..num_updates { println!("input update ----------------------------- {}", i); perform_input_update(&mut mt_account_data, &mut mock_indexer, false, mt_pubkey).await; - if i == 5 { + if i >= 7 { let merkle_tree_account = &mut BatchedMerkleTreeAccount::state_from_bytes(&mut mt_account_data).unwrap(); let batch = merkle_tree_account.batches.get(0).unwrap(); assert!(batch.bloom_filter_is_wiped()); + } else { + let merkle_tree_account = + &mut BatchedMerkleTreeAccount::state_from_bytes(&mut mt_account_data).unwrap(); + let batch = merkle_tree_account.batches.get(0).unwrap(); + assert!(!batch.bloom_filter_is_wiped()); } println!( "performed input queue batched update {} created root {:?}", @@ -1757,7 +1753,7 @@ async fn test_fill_queues_completely() { .unwrap(); { let post_batch = *merkle_tree_account.batches.get(0).unwrap(); - assert_eq!(post_batch.get_state(), BatchState::CanBeFilled); + assert_eq!(post_batch.get_state(), BatchState::Fill); assert_eq!(post_batch.get_num_inserted(), 1); let bloom_filter_store = merkle_tree_account.bloom_filter_stores.get_mut(0).unwrap(); @@ -1923,8 +1919,7 @@ async fn test_fill_address_tree_completely() { } // Root of the final batch of first input queue batch let mut first_input_batch_update_root_value = [0u8; 32]; - let num_updates = params.input_queue_batch_size / params.input_queue_zkp_batch_size - * params.input_queue_num_batches; + let num_updates = 10; for i in 0..num_updates { println!("address update ----------------------------- {}", i); perform_address_update(&mut mt_account_data, &mut mock_indexer, false, mt_pubkey).await; @@ -1937,7 +1932,7 @@ async fn test_fill_address_tree_completely() { let batch_one = merkle_tree_account.batches.get(1).unwrap(); assert!(!batch_one.bloom_filter_is_wiped()); - if i >= 4 { 
+ if i >= 7 { assert!(batch.bloom_filter_is_wiped()); } else { assert!(!batch.bloom_filter_is_wiped()); @@ -1972,7 +1967,7 @@ async fn test_fill_address_tree_completely() { // .get(0) // .unwrap() // .clone(); - // assert_eq!(post_batch.get_state(), BatchState::CanBeFilled); + // assert_eq!(post_batch.get_state(), BatchState::Fill); // assert_eq!(post_batch.get_num_inserted(), 1); // let mut bloom_filter_store = merkle_tree_account // .bloom_filter_stores diff --git a/program-libs/verifier/src/lib.rs b/program-libs/verifier/src/lib.rs index c90568bc21..1d0f1bd42a 100644 --- a/program-libs/verifier/src/lib.rs +++ b/program-libs/verifier/src/lib.rs @@ -310,7 +310,7 @@ pub fn verify( #[inline(never)] pub fn verify_batch_append_with_proofs( - batch_size: usize, + batch_size: u64, public_input_hash: [u8; 32], compressed_proof: &CompressedProof, ) -> Result<(), VerifierError> { @@ -341,7 +341,7 @@ pub fn verify_batch_append_with_proofs( #[inline(never)] pub fn verify_batch_update( - batch_size: usize, + batch_size: u64, public_input_hash: [u8; 32], compressed_proof: &CompressedProof, ) -> Result<(), VerifierError> { @@ -377,7 +377,7 @@ pub fn verify_batch_update( #[inline(never)] pub fn verify_batch_address_update( - batch_size: usize, + batch_size: u64, public_input_hash: [u8; 32], compressed_proof: &CompressedProof, ) -> Result<(), VerifierError> { diff --git a/program-tests/account-compression-test/tests/batched_merkle_tree_test.rs b/program-tests/account-compression-test/tests/batched_merkle_tree_test.rs index f2664be411..66de512773 100644 --- a/program-tests/account-compression-test/tests/batched_merkle_tree_test.rs +++ b/program-tests/account-compression-test/tests/batched_merkle_tree_test.rs @@ -13,9 +13,8 @@ use light_batched_merkle_tree::{ create_output_queue_account, CreateOutputQueueParams, InitStateTreeAccountsInstructionData, }, merkle_tree::{ - get_merkle_tree_account_size, AppendBatchProofInputsIx, BatchProofInputsIx, - BatchedMerkleTreeAccount, BatchedMerkleTreeMetadata, CreateTreeParams, - InstructionDataBatchAppendInputs, InstructionDataBatchNullifyInputs, + get_merkle_tree_account_size, BatchedMerkleTreeAccount, BatchedMerkleTreeMetadata, + CreateTreeParams, InstructionDataBatchAppendInputs, InstructionDataBatchNullifyInputs, }, queue::{ assert_queue_zero_copy_inited, get_output_queue_account_size, BatchedQueueAccount, @@ -764,7 +763,7 @@ pub async fn create_append_batch_ix_data( .unwrap(); InstructionDataBatchAppendInputs { - public_inputs: AppendBatchProofInputsIx { new_root }, + new_root, compressed_proof: CompressedProof { a: proof.a, b: proof.b, @@ -781,7 +780,6 @@ pub async fn create_nullify_batch_ix_data( BatchedMerkleTreeAccount::state_from_bytes(account_data).unwrap(); println!("batches {:?}", zero_copy_account.batches); - let old_root_index = zero_copy_account.root_history.last_index(); let next_full_batch = zero_copy_account .get_metadata() .queue_metadata @@ -820,10 +818,7 @@ pub async fn create_nullify_batch_ix_data( .await .unwrap(); let instruction_data = InstructionDataBatchNullifyInputs { - public_inputs: BatchProofInputsIx { - new_root, - old_root_index: old_root_index as u16, - }, + new_root, compressed_proof: CompressedProof { a: proof.a, b: proof.b, @@ -1609,12 +1604,10 @@ async fn test_batch_address_merkle_trees() { } // 5. Failing: invalid proof // 6. Failing: invalid new root - // 7. Failing: invalid root index - // 8. Failing: update twice with the same instruction (proof and public inputs) + // 7. 
Failing: update twice with the same instruction (proof and public inputs) for (mode, ix_index) in vec![ UpdateBatchAddressTreeTestMode::InvalidProof, UpdateBatchAddressTreeTestMode::InvalidNewRoot, - UpdateBatchAddressTreeTestMode::InvalidRootIndex, UpdateBatchAddressTreeTestMode::UpdateTwice, ] .iter() @@ -1859,7 +1852,6 @@ pub enum UpdateBatchAddressTreeTestMode { Functional, InvalidProof, InvalidNewRoot, - InvalidRootIndex, UpdateTwice, } @@ -1888,11 +1880,7 @@ pub async fn update_batch_address_tree( BatchedMerkleTreeAccount::address_from_bytes(&mut merkle_tree_account_data).unwrap(); let start_index = zero_copy_account.get_metadata().next_index; - let mut old_root_index = zero_copy_account.root_history.last_index(); - let current_root = zero_copy_account - .root_history - .get(old_root_index as usize) - .unwrap(); + let current_root = zero_copy_account.root_history.last().unwrap(); let next_full_batch = zero_copy_account .get_metadata() .queue_metadata @@ -1923,9 +1911,6 @@ pub async fn update_batch_address_tree( ) .await .unwrap(); - if mode == UpdateBatchAddressTreeTestMode::InvalidRootIndex { - old_root_index -= 1; - } if mode == UpdateBatchAddressTreeTestMode::InvalidNewRoot { new_root[0] = new_root[0].wrapping_add(1); } @@ -1933,10 +1918,7 @@ pub async fn update_batch_address_tree( proof.a = proof.c; } let instruction_data = InstructionDataBatchNullifyInputs { - public_inputs: BatchProofInputsIx { - new_root, - old_root_index: old_root_index as u16, - }, + new_root, compressed_proof: CompressedProof { a: proof.a, b: proof.b, diff --git a/prover/client/src/mock_batched_forester.rs b/prover/client/src/mock_batched_forester.rs index a5bfcecbcc..318276019b 100644 --- a/prover/client/src/mock_batched_forester.rs +++ b/prover/client/src/mock_batched_forester.rs @@ -348,9 +348,11 @@ impl MockBatchedAddressForester { pub fn finalize_batch_address_update(&mut self, batch_size: usize) { println!("finalize batch address update"); let new_element_values = self.queue_leaves[..batch_size].to_vec(); + println!("removing leaves from queue {}", batch_size); for _ in 0..batch_size { self.queue_leaves.remove(0); } + println!("new queue length {}", self.queue_leaves.len()); for new_element_value in &new_element_values { self.merkle_tree .append( diff --git a/sdk-libs/program-test/src/test_batch_forester.rs b/sdk-libs/program-test/src/test_batch_forester.rs index 27e93f410c..5ef2469721 100644 --- a/sdk-libs/program-test/src/test_batch_forester.rs +++ b/sdk-libs/program-test/src/test_batch_forester.rs @@ -10,9 +10,8 @@ use light_batched_merkle_tree::{ create_output_queue_account, CreateOutputQueueParams, InitStateTreeAccountsInstructionData, }, merkle_tree::{ - get_merkle_tree_account_size, AppendBatchProofInputsIx, BatchProofInputsIx, - BatchedMerkleTreeAccount, BatchedMerkleTreeMetadata, CreateTreeParams, - InstructionDataBatchAppendInputs, InstructionDataBatchNullifyInputs, + get_merkle_tree_account_size, BatchedMerkleTreeAccount, BatchedMerkleTreeMetadata, + CreateTreeParams, InstructionDataBatchAppendInputs, InstructionDataBatchNullifyInputs, }, queue::{ assert_queue_zero_copy_inited, get_output_queue_account_size, BatchedQueueAccount, @@ -210,7 +209,7 @@ pub async fn create_append_batch_ix_data( }; InstructionDataBatchAppendInputs { - public_inputs: AppendBatchProofInputsIx { new_root }, + new_root, compressed_proof: CompressedProof { a: proof.a, b: proof.b, @@ -369,10 +368,7 @@ pub async fn get_batched_nullify_ix_data( }; Ok(InstructionDataBatchNullifyInputs { - public_inputs: BatchProofInputsIx 
{ - new_root, - old_root_index: old_root_index as u16, - }, + new_root, compressed_proof: CompressedProof { a: proof.a, b: proof.b, @@ -804,7 +800,6 @@ pub async fn create_batch_update_address_tree_instruction_data_with_proof< let merkle_tree = BatchedMerkleTreeAccount::address_from_bytes(merkle_tree_account.data.as_mut_slice()) .unwrap(); - let old_root_index = merkle_tree.root_history.last_index(); let full_batch_index = merkle_tree.queue_metadata.next_full_batch_index; let batch = &merkle_tree.batches[full_batch_index as usize]; let zkp_batch_index = batch.get_num_inserted_zkps(); @@ -892,10 +887,8 @@ pub async fn create_batch_update_address_tree_instruction_data_with_proof< let (proof_a, proof_b, proof_c) = proof_from_json_struct(proof_json); let (proof_a, proof_b, proof_c) = compress_proof(&proof_a, &proof_b, &proof_c); let instruction_data = InstructionDataBatchNullifyInputs { - public_inputs: BatchProofInputsIx { - new_root: circuit_inputs_new_root, - old_root_index: old_root_index as u16, - }, + new_root: circuit_inputs_new_root, + compressed_proof: CompressedProof { a: proof_a, b: proof_b, From 31f66750afdde709ff13a9fdd8c2498bb83a434c Mon Sep 17 00:00:00 2001 From: ananas-block Date: Thu, 16 Jan 2025 02:41:40 +0000 Subject: [PATCH 2/6] fix: lint --- program-libs/batched-merkle-tree/src/merkle_tree.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/program-libs/batched-merkle-tree/src/merkle_tree.rs b/program-libs/batched-merkle-tree/src/merkle_tree.rs index 0576f1ee9a..34a76871d2 100644 --- a/program-libs/batched-merkle-tree/src/merkle_tree.rs +++ b/program-libs/batched-merkle-tree/src/merkle_tree.rs @@ -483,10 +483,10 @@ impl<'a> BatchedMerkleTreeAccount<'a> { /// 3.3. If all zkps are inserted, set the state to inserted. /// 4. Increment next full batch index if inserted. /// 5. Return the batch append event. + /// /// Note: when proving inclusion by index in - /// value array we need to insert the value into a bloom_filter once it is - /// inserted into the tree. Check this with get_num_inserted_zkps - #[cfg(not(target_os = "solana"))] + /// value array we need to insert the value into a bloom_filter once it is + /// inserted into the tree. 
Check this with get_num_inserted_zkps
    pub fn update_tree_from_output_queue_account(
        &mut self,
        queue_account: &mut BatchedQueueAccount,

From bdf58dc2dbc7b92fedfe56733ece35bdde1cb261 Mon Sep 17 00:00:00 2001
From: ananas-block
Date: Thu, 16 Jan 2025 20:06:10 +0000
Subject: [PATCH 3/6] chore: refactor and document queue insertion

---
 program-libs/batched-merkle-tree/src/batch.rs | 318 ++++++++++++------
 .../batched-merkle-tree/src/batch_metadata.rs | 8 +
 .../batched-merkle-tree/src/merkle_tree.rs | 117 ++++---
 program-libs/batched-merkle-tree/src/queue.rs | 208 ++++++------
 .../src/rollover_address_tree.rs | 2 +-
 .../batched-merkle-tree/tests/merkle_tree.rs | 79 ++---
 .../tests/batched_merkle_tree_test.rs | 2 +-
 program-tests/system-test/tests/test.rs | 2 +-
 .../system/src/invoke/verify_state_proof.rs | 2 +-
 .../src/invoke_cpi/process_cpi_context.rs | 2 +-
 programs/system/src/sdk/mod.rs | 2 +-
 sdk-libs/sdk/src/verify.rs | 2 +-
 12 files changed, 428 insertions(+), 316 deletions(-)

diff --git a/program-libs/batched-merkle-tree/src/batch.rs b/program-libs/batched-merkle-tree/src/batch.rs
index b3044f0bed..7f610257c2 100644
--- a/program-libs/batched-merkle-tree/src/batch.rs
+++ b/program-libs/batched-merkle-tree/src/batch.rs
@@ -1,6 +1,6 @@
 use light_bloom_filter::BloomFilter;
 use light_hasher::{Hasher, Poseidon};
-use light_zero_copy::vec::ZeroCopyVecU64;
+use light_zero_copy::{slice_mut::ZeroCopySliceMutU64, vec::ZeroCopyVecU64};
 use solana_program::msg;
 use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout};
@@ -47,16 +47,18 @@ pub struct Batch {
     /// Theoretical capacity of the bloom_filter. We want to make it much larger
     /// than batch_size to avoid false positives.
     pub bloom_filter_capacity: u64,
+    /// Number of elements in a batch.
     pub batch_size: u64,
+    /// Number of elements in a zkp batch.
+    /// A batch consists of one or more zkp batches.
     pub zkp_batch_size: u64,
     /// Sequence number when it is safe to clear the batch without advancing to
     /// the saved root index.
     pub sequence_number: u64,
+    /// Start leaf index of the first leaf in the batch.
     pub start_index: u64,
     pub root_index: u32,
-    /// Placeholder for forester to signal that the bloom filter is wiped
-    /// already.
-    bloom_filter_is_wiped: u8,
+    bloom_filter_is_zeroed: u8,
     _padding: [u8; 3],
 }
@@ -80,7 +82,7 @@ impl Batch {
             sequence_number: 0,
             root_index: 0,
             start_index,
-            bloom_filter_is_wiped: 0,
+            bloom_filter_is_zeroed: 0,
             _padding: [0u8; 3],
         }
     }
@@ -89,16 +91,16 @@ impl Batch {
         self.state.into()
     }
-    pub fn bloom_filter_is_wiped(&self) -> bool {
-        self.bloom_filter_is_wiped == 1
+    pub fn bloom_filter_is_zeroed(&self) -> bool {
+        self.bloom_filter_is_zeroed == 1
     }
-    pub fn set_bloom_filter_is_wiped(&mut self) {
-        self.bloom_filter_is_wiped = 1;
+    pub fn set_bloom_filter_to_zeroed(&mut self) {
+        self.bloom_filter_is_zeroed = 1;
     }
-    pub fn set_bloom_filter_is_not_wiped(&mut self) {
-        self.bloom_filter_is_wiped = 0;
+    pub fn set_bloom_filter_to_not_zeroed(&mut self) {
+        self.bloom_filter_is_zeroed = 0;
     }
     /// fill -> full -> inserted -> fill
@@ -146,13 +148,26 @@ impl Batch {
     pub fn get_first_ready_zkp_batch(&self) -> Result<u64, BatchedMerkleTreeError> {
         if self.get_state() == BatchState::Inserted {
             Err(BatchedMerkleTreeError::BatchAlreadyInserted)
-        } else if self.current_zkp_batch_index > self.num_inserted_zkps {
+        } else if self.batch_is_ready_to_insert() {
             Ok(self.num_inserted_zkps)
         } else {
             Err(BatchedMerkleTreeError::BatchNotReady)
         }
     }
+    pub fn batch_is_ready_to_insert(&self) -> bool {
+        self.current_zkp_batch_index > self.num_inserted_zkps
+    }
+
+    /// Returns the number of zkp batch updates
+    /// that are ready to be inserted into the tree.
+    pub fn num_ready_zkp_updates(&self) -> u64 {
+        self.current_zkp_batch_index
+            .saturating_sub(self.num_inserted_zkps)
+    }
+
+    /// Returns the number of inserted elements
+    /// in the current zkp batch.
     pub fn get_num_inserted(&self) -> u64 {
         self.num_inserted
     }
@@ -165,45 +180,76 @@ impl Batch {
         self.num_inserted_zkps
     }
+    /// Returns the number of inserted elements in the batch.
     pub fn get_num_inserted_elements(&self) -> u64 {
         self.num_inserted_zkps * self.zkp_batch_size + self.num_inserted
     }
+    /// Returns the number of zkp batches in the batch.
+    pub fn get_num_zkp_batches(&self) -> u64 {
+        self.batch_size / self.zkp_batch_size
+    }
+
+    pub fn get_hashchain_store_len(&self) -> usize {
+        self.batch_size as usize / self.zkp_batch_size as usize
+    }
+
+    /// Returns the index of a value by leaf index in the value store,
+    /// provided it could exist in the batch.
+    pub fn get_value_index_in_batch(&self, leaf_index: u64) -> Result<u64, BatchedMerkleTreeError> {
+        self.leaf_index_could_exist_in_batch(leaf_index)?;
+        leaf_index
+            .checked_sub(self.start_index)
+            .ok_or(BatchedMerkleTreeError::LeafIndexNotInBatch)
+    }
+
     pub fn store_value(
         &mut self,
         value: &[u8; 32],
         value_store: &mut ZeroCopyVecU64<[u8; 32]>,
     ) -> Result<(), BatchedMerkleTreeError> {
-        if self.get_state() != BatchState::Fill {
-            return Err(BatchedMerkleTreeError::BatchNotReady);
-        }
         value_store.push(*value)?;
         Ok(())
     }
+    /// Stores the value in a value store,
+    /// and adds the value to the current hash chain.
     pub fn store_and_hash_value(
         &mut self,
         value: &[u8; 32],
         value_store: &mut ZeroCopyVecU64<[u8; 32]>,
         hashchain_store: &mut ZeroCopyVecU64<[u8; 32]>,
     ) -> Result<(), BatchedMerkleTreeError> {
-        self.store_value(value, value_store)?;
-        self.add_to_hash_chain(value, hashchain_store)
+        self.add_to_hash_chain(value, hashchain_store)?;
+        self.store_value(value, value_store)
     }
-    /// Inserts into the bloom filter and hashes the value.
-    /// (used by input/nullifier queue)
+    /// Inserts into the bloom filter and
+    /// adds the value to the current hash chain.
+ /// (used by nullifier & address queues) pub fn insert( &mut self, bloom_filter_value: &[u8; 32], hashchain_value: &[u8; 32], - store: &mut [u8], + bloom_filter_stores: &mut [ZeroCopySliceMutU64], hashchain_store: &mut ZeroCopyVecU64<[u8; 32]>, + bloom_filter_index: usize, ) -> Result<(), BatchedMerkleTreeError> { - let mut bloom_filter = - BloomFilter::new(self.num_iters as usize, self.bloom_filter_capacity, store)?; - bloom_filter.insert(bloom_filter_value)?; - self.add_to_hash_chain(hashchain_value, hashchain_store) + self.add_to_hash_chain(hashchain_value, hashchain_store)?; + + for (i, bloom_filter) in bloom_filter_stores.iter_mut().enumerate() { + if i == bloom_filter_index { + let mut bloom_filter = BloomFilter::new( + self.num_iters as usize, + self.bloom_filter_capacity, + bloom_filter.as_mut_slice(), + )?; + bloom_filter.insert(bloom_filter_value)?; + } else { + self.check_non_inclusion(bloom_filter_value, bloom_filter.as_mut_slice())?; + } + } + Ok(()) } pub fn add_to_hash_chain( @@ -211,29 +257,36 @@ impl Batch { value: &[u8; 32], hashchain_store: &mut ZeroCopyVecU64<[u8; 32]>, ) -> Result<(), BatchedMerkleTreeError> { - if self.num_inserted == self.zkp_batch_size || self.num_inserted == 0 { + if self.get_state() != BatchState::Fill { + return Err(BatchedMerkleTreeError::BatchNotReady); + } + let start_new_hash_chain = self.num_inserted == 0; + if start_new_hash_chain { hashchain_store.push(*value)?; - self.num_inserted = 0; } else if let Some(last_hashchain) = hashchain_store.last() { let hashchain = Poseidon::hashv(&[last_hashchain, value.as_slice()])?; *hashchain_store.last_mut().unwrap() = hashchain; + } else { + // This state should never be reached. + return Err(BatchedMerkleTreeError::BatchNotReady); } - self.num_inserted += 1; - if self.num_inserted == self.zkp_batch_size { - self.current_zkp_batch_index += 1; - } - if self.get_num_zkp_batches() == self.current_zkp_batch_index { - self.advance_state_to_full()?; + let zkp_batch_is_full = self.num_inserted == self.zkp_batch_size; + if zkp_batch_is_full { + self.current_zkp_batch_index += 1; self.num_inserted = 0; + + let batch_is_full = self.current_zkp_batch_index == self.get_num_zkp_batches(); + if batch_is_full { + self.advance_state_to_full()?; + } } Ok(()) } - /// Inserts into the bloom filter and hashes the value. - /// (used by nullifier queue) + /// Checks that value is not in the bloom filter. pub fn check_non_inclusion( &self, value: &[u8; 32], @@ -249,10 +302,6 @@ impl Batch { Ok(()) } - pub fn get_num_zkp_batches(&self) -> u64 { - self.batch_size / self.zkp_batch_size - } - /// Marks the batch as inserted in the merkle tree. /// 1. Checks that the batch is ready. /// 2. increments the number of inserted zkps. @@ -273,7 +322,8 @@ impl Batch { self.num_inserted_zkps += 1; println!("num_inserted_zkps: {}", self.num_inserted_zkps); // 3. If all zkps are inserted, sets the state to inserted. - if self.num_inserted_zkps == num_zkp_batches { + let batch_is_completly_inserted = self.num_inserted_zkps == num_zkp_batches; + if batch_is_completly_inserted { println!("Setting state to inserted"); self.num_inserted_zkps = 0; self.current_zkp_batch_index = 0; @@ -288,11 +338,11 @@ impl Batch { Ok(self.get_state()) } - pub fn get_hashchain_store_len(&self) -> usize { - self.batch_size as usize / self.zkp_batch_size as usize - } - - pub fn value_is_inserted_in_batch( + /// Returns true if value of leaf index could exist in batch. + /// `True` doesn't mean that the value exists in the batch, + /// just that it is plausible. 
The value might already be spent + /// or never inserted in case an invalid index was provided. + pub fn leaf_index_could_exist_in_batch( &self, leaf_index: u64, ) -> Result { @@ -301,12 +351,6 @@ impl Batch { let min_batch_leaf_index = self.start_index; Ok(leaf_index < max_batch_leaf_index && leaf_index >= min_batch_leaf_index) } - - pub fn get_value_index_in_batch(&self, leaf_index: u64) -> Result { - leaf_index - .checked_sub(self.start_index) - .ok_or(BatchedMerkleTreeError::LeafIndexNotInBatch) - } } #[cfg(test)] @@ -382,6 +426,7 @@ mod tests { ref_batch.num_inserted += 1; if ref_batch.num_inserted == ref_batch.zkp_batch_size { ref_batch.current_zkp_batch_index += 1; + ref_batch.num_inserted = 0; } if ref_batch.current_zkp_batch_index == ref_batch.get_num_zkp_batches() { ref_batch.state = BatchState::Full.into(); @@ -405,51 +450,97 @@ mod tests { fn test_insert() { // Behavior Input queue let mut batch = get_test_batch(); - let mut store = vec![0u8; 20_000]; - + let mut stores = vec![vec![0u8; 20_008]; 2]; + let mut bloom_filter_stores = stores + .iter_mut() + .map(|store| ZeroCopySliceMutU64::new(20_000, store).unwrap()) + .collect::>(); let mut hashchain_store_bytes = vec![ 0u8; ZeroCopyVecU64::<[u8; 32]>::required_size_for_capacity( batch.get_hashchain_store_len() as u64 ) ]; - let mut hashchain_store = ZeroCopyVecU64::<[u8; 32]>::new( + ZeroCopyVecU64::<[u8; 32]>::new( batch.get_hashchain_store_len() as u64, hashchain_store_bytes.as_mut_slice(), ) .unwrap(); let mut ref_batch = get_test_batch(); - for i in 0..batch.batch_size { - ref_batch.num_inserted %= ref_batch.zkp_batch_size; - let mut value = [0u8; 32]; - value[24..].copy_from_slice(&i.to_be_bytes()); - let ref_hash_chain = if i % batch.zkp_batch_size == 0 { - value - } else { - Poseidon::hashv(&[hashchain_store.last().unwrap(), &value]).unwrap() - }; - assert!(batch - .insert(&value, &value, &mut store, &mut hashchain_store) - .is_ok()); - let mut bloom_filter = BloomFilter { - num_iters: batch.num_iters as usize, - capacity: batch.bloom_filter_capacity, - store: &mut store, - }; - assert!(bloom_filter.contains(&value)); - batch.check_non_inclusion(&value, &mut store).unwrap_err(); - - ref_batch.num_inserted += 1; - assert_eq!(*hashchain_store.last().unwrap(), ref_hash_chain); - if ref_batch.num_inserted == ref_batch.zkp_batch_size { - ref_batch.current_zkp_batch_index += 1; - } - if i == batch.batch_size - 1 { - ref_batch.state = BatchState::Full.into(); - ref_batch.num_inserted = 0; + for processing_index in 0..=1 { + for i in 0..(batch.batch_size / 2) { + let i = i + (batch.batch_size / 2) * (processing_index as u64); + + ref_batch.num_inserted %= ref_batch.zkp_batch_size; + let mut hashchain_store = + ZeroCopyVecU64::<[u8; 32]>::from_bytes(hashchain_store_bytes.as_mut_slice()) + .unwrap(); + + let mut value = [0u8; 32]; + value[24..].copy_from_slice(&i.to_be_bytes()); + let ref_hash_chain = if i % batch.zkp_batch_size == 0 { + value + } else { + Poseidon::hashv(&[hashchain_store.last().unwrap(), &value]).unwrap() + }; + let result = batch.insert( + &value, + &value, + bloom_filter_stores.as_mut_slice(), + &mut hashchain_store, + processing_index, + ); + // First insert should succeed + assert!(result.is_ok(), "Failed result: {:?}", result); + assert_eq!(*hashchain_store.last().unwrap(), ref_hash_chain); + + { + let mut cloned_hashchain_store = hashchain_store_bytes.clone(); + let mut hashchain_store = ZeroCopyVecU64::<[u8; 32]>::from_bytes( + cloned_hashchain_store.as_mut_slice(), + ) + .unwrap(); + let mut batch = 
batch; + // Reinsert should fail + assert!(batch + .insert( + &value, + &value, + bloom_filter_stores.as_mut_slice(), + &mut hashchain_store, + processing_index + ) + .is_err()); + } + let mut bloom_filter = BloomFilter { + num_iters: batch.num_iters as usize, + capacity: batch.bloom_filter_capacity, + store: bloom_filter_stores[processing_index].as_mut_slice(), + }; + assert!(bloom_filter.contains(&value)); + let other_index = if processing_index == 0 { 1 } else { 0 }; + batch + .check_non_inclusion(&value, bloom_filter_stores[other_index].as_mut_slice()) + .unwrap(); + batch + .check_non_inclusion( + &value, + bloom_filter_stores[processing_index].as_mut_slice(), + ) + .unwrap_err(); + + ref_batch.num_inserted += 1; + if ref_batch.num_inserted == ref_batch.zkp_batch_size { + ref_batch.current_zkp_batch_index += 1; + ref_batch.num_inserted = 0; + } + if i == batch.batch_size - 1 { + ref_batch.state = BatchState::Full.into(); + ref_batch.num_inserted = 0; + } + assert_eq!(batch, ref_batch); } - assert_eq!(batch, ref_batch); } test_mark_as_inserted(batch); } @@ -491,29 +582,50 @@ mod tests { #[test] fn test_check_non_inclusion() { - let mut batch = get_test_batch(); - - let value = [1u8; 32]; - let mut store = vec![0u8; 20_000]; - let mut hashchain_store_bytes = vec![ + for processing_index in 0..=1 { + let mut batch = get_test_batch(); + + let value = [1u8; 32]; + let mut stores = vec![vec![0u8; 20_008]; 2]; + let mut bloom_filter_stores = stores + .iter_mut() + .map(|store| ZeroCopySliceMutU64::new(20_000, store).unwrap()) + .collect::>(); + let mut hashchain_store_bytes = vec![ 0u8; ZeroCopyVecU64::<[u8; 32]>::required_size_for_capacity( batch.get_hashchain_store_len() as u64 ) ]; - let mut hashchain_store = ZeroCopyVecU64::<[u8; 32]>::new( - batch.get_hashchain_store_len() as u64, - hashchain_store_bytes.as_mut_slice(), - ) - .unwrap(); - - assert!(batch.check_non_inclusion(&value, &mut store).is_ok()); - let ref_batch = get_test_batch(); - assert_eq!(batch, ref_batch); - batch - .insert(&value, &value, &mut store, &mut hashchain_store) + let mut hashchain_store = ZeroCopyVecU64::<[u8; 32]>::new( + batch.get_hashchain_store_len() as u64, + hashchain_store_bytes.as_mut_slice(), + ) .unwrap(); - assert!(batch.check_non_inclusion(&value, &mut store).is_err()); + + assert!(batch + .check_non_inclusion(&value, bloom_filter_stores[processing_index].as_mut_slice()) + .is_ok()); + let ref_batch = get_test_batch(); + assert_eq!(batch, ref_batch); + batch + .insert( + &value, + &value, + bloom_filter_stores.as_mut_slice(), + &mut hashchain_store, + processing_index, + ) + .unwrap(); + assert!(batch + .check_non_inclusion(&value, bloom_filter_stores[processing_index].as_mut_slice()) + .is_err()); + + let other_index = if processing_index == 0 { 1 } else { 0 }; + assert!(batch + .check_non_inclusion(&value, bloom_filter_stores[other_index].as_mut_slice()) + .is_ok()); + } } #[test] @@ -547,19 +659,19 @@ mod tests { batch.start_index + batch.get_num_zkp_batches() * batch.zkp_batch_size - 1; // 1. Failing test lowest value in eligble range - 1 assert!(!batch - .value_is_inserted_in_batch(lowest_eligible_value - 1) + .leaf_index_could_exist_in_batch(lowest_eligible_value - 1) .unwrap()); // 2. Functional test lowest value in eligble range assert!(batch - .value_is_inserted_in_batch(lowest_eligible_value) + .leaf_index_could_exist_in_batch(lowest_eligible_value) .unwrap()); // 3. 
Functional test highest value in eligble range assert!(batch - .value_is_inserted_in_batch(highest_eligible_value) + .leaf_index_could_exist_in_batch(highest_eligible_value) .unwrap()); // 4. Failing test eligble range + 1 assert!(!batch - .value_is_inserted_in_batch(highest_eligible_value + 1) + .leaf_index_could_exist_in_batch(highest_eligible_value + 1) .unwrap()); } diff --git a/program-libs/batched-merkle-tree/src/batch_metadata.rs b/program-libs/batched-merkle-tree/src/batch_metadata.rs index ec2f9b0948..19042d5eb9 100644 --- a/program-libs/batched-merkle-tree/src/batch_metadata.rs +++ b/program-libs/batched-merkle-tree/src/batch_metadata.rs @@ -83,6 +83,14 @@ impl BatchMetadata { } } + /// Increment the currently_processing_batch_index if current state is full. + pub fn increment_currently_processing_batch_index_if_full(&mut self, state: BatchState) { + if state == BatchState::Full { + self.currently_processing_batch_index += 1; + self.currently_processing_batch_index %= self.num_batches; + } + } + pub fn init( &mut self, num_batches: u64, diff --git a/program-libs/batched-merkle-tree/src/merkle_tree.rs b/program-libs/batched-merkle-tree/src/merkle_tree.rs index 34a76871d2..ad27fbdb08 100644 --- a/program-libs/batched-merkle-tree/src/merkle_tree.rs +++ b/program-libs/batched-merkle-tree/src/merkle_tree.rs @@ -28,7 +28,7 @@ use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, Ref}; use super::{ batch::Batch, - queue::{init_queue, input_queue_bytes, insert_into_current_batch, queue_account_size}, + queue::{init_queue, input_queue_from_bytes, insert_into_current_batch, queue_account_size}, }; use crate::{ batch::BatchState, @@ -363,7 +363,7 @@ impl<'a> BatchedMerkleTreeAccount<'a> { } let (root_history, account_data) = ZeroCopyCyclicVecU64::from_bytes_at(account_data)?; - let (batches, value_vecs, bloom_filter_stores, hashchain_store) = input_queue_bytes( + let (batches, value_vecs, bloom_filter_stores, hashchain_store) = input_queue_from_bytes( &metadata.queue_metadata, account_data, QueueType::BatchedInput as u64, @@ -477,10 +477,10 @@ impl<'a> BatchedMerkleTreeAccount<'a> { /// 2.2. Increment sequence number. /// 2.3. Increment next index. /// 2.4. Append new root to root history. - /// 3. Mark batch as inserted in the merkle tree. + /// 3. Mark zkp batch as inserted in the merkle tree. /// 3.1. Checks that the batch is ready. /// 3.2. Increment the number of inserted zkps. - /// 3.3. If all zkps are inserted, set the state to inserted. + /// 3.3. If all zkps are inserted, set batch state to inserted. /// 4. Increment next full batch index if inserted. /// 5. Return the batch append event. /// @@ -528,9 +528,9 @@ impl<'a> BatchedMerkleTreeAccount<'a> { let root_index = self.root_history.last_index() as u32; - // Update queue metadata. + // Update metadata and batch. { - // 3. Mark batch as inserted in the merkle tree. + // 3. Mark zkp batch as inserted in the merkle tree. let full_batch_state = full_batch.mark_as_inserted_in_merkle_tree( self.metadata.sequence_number, root_index, @@ -590,7 +590,7 @@ impl<'a> BatchedMerkleTreeAccount<'a> { /// 3.1. Checks that the batch is ready. /// 3.2. Increment the number of inserted zkps. /// 3.3. If all zkps are inserted, set the state to inserted. - /// 4. Wipe previous batch bloom filter if current batch is 50% inserted. + /// 4. Zero out previous batch bloom filter if current batch is 50% inserted. /// 5. Increment next full batch index if inserted. /// 6. Return the batch nullify event. 
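// For intuition, the following is a minimal, self-contained sketch of the
// batch lifecycle the steps above walk through (fill -> full -> inserted).
// `MockBatch` and its fields are simplified stand-ins invented for
// illustration; they are not this crate's actual `Batch` type.

#[derive(Debug, PartialEq, Clone, Copy)]
enum MockBatchState {
    Fill,
    Full,
    Inserted,
}

struct MockBatch {
    batch_size: u64,
    zkp_batch_size: u64,
    num_inserted: u64,
    current_zkp_batch_index: u64,
    num_inserted_zkps: u64,
    state: MockBatchState,
}

impl MockBatch {
    /// Fill one element; close a zkp batch every `zkp_batch_size` elements,
    /// and mark the whole batch Full once every zkp batch is closed.
    fn add(&mut self) {
        assert_eq!(self.state, MockBatchState::Fill);
        self.num_inserted += 1;
        if self.num_inserted == self.zkp_batch_size {
            self.current_zkp_batch_index += 1;
            self.num_inserted = 0;
            if self.current_zkp_batch_index == self.batch_size / self.zkp_batch_size {
                self.state = MockBatchState::Full;
            }
        }
    }

    /// One zkp proof inserts one zkp batch into the tree; after the last
    /// one the whole batch flips to Inserted and its counters reset.
    fn mark_zkp_batch_inserted(&mut self) {
        assert!(self.current_zkp_batch_index > self.num_inserted_zkps);
        self.num_inserted_zkps += 1;
        if self.num_inserted_zkps == self.batch_size / self.zkp_batch_size {
            self.num_inserted_zkps = 0;
            self.current_zkp_batch_index = 0;
            self.state = MockBatchState::Inserted;
        }
    }
}

fn main() {
    let mut batch = MockBatch {
        batch_size: 4,
        zkp_batch_size: 2,
        num_inserted: 0,
        current_zkp_batch_index: 0,
        num_inserted_zkps: 0,
        state: MockBatchState::Fill,
    };
    for _ in 0..4 {
        batch.add();
    }
    assert_eq!(batch.state, MockBatchState::Full);
    batch.mark_zkp_batch_inserted();
    batch.mark_zkp_batch_inserted();
    assert_eq!(batch.state, MockBatchState::Inserted);
}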
fn update_input_queue( @@ -648,12 +648,12 @@ impl<'a> BatchedMerkleTreeAccount<'a> { root_history_capacity, )?; - // 4. Wipe previous batch bloom filter + // 4. Zero out previous batch bloom filter // if current batch is 50% inserted. // Needs to be executed prior to // incrementing next full batch index, // but post mark_as_inserted_in_merkle_tree. - self.wipe_previous_batch_bloom_filter()?; + self.zero_out_previous_batch_bloom_filter()?; // 5. Increment next full batch index if inserted. self.metadata @@ -712,11 +712,17 @@ impl<'a> BatchedMerkleTreeAccount<'a> { Ok(()) } - /// State nullification: - /// - value is inserted into to a bloom_filter to prove non-inclusion in later txs. - /// - nullifier is Hash(value, tx_hash), committed to leaves hashchain - /// - tx_hash is hash of all inputs and outputs - /// -> we can access the history of how commitments are spent in zkps for example fraud proofs + /// Insert nullifier into current batch. + /// 1. Check that the tree is a state tree. + /// 2. Create nullifier Hash(value,leaf_index, tx_hash). + /// 3. Insert nullifier into current batch. + /// 3.1. Insert compressed_account_hash into bloom filter. + /// (bloom filter enables non-inclusion proofs in later txs) + /// 3.2. Add nullifier to leaves hash chain. + /// (Nullification means, the compressed_account_hash in the tree, + /// is overwritten with a nullifier hash) + /// 3.3. Check that compressed_account_hash + /// does not exist in any other bloom filter. pub fn insert_nullifier_into_current_batch( &mut self, compressed_account_hash: &[u8; 32], @@ -726,11 +732,19 @@ impl<'a> BatchedMerkleTreeAccount<'a> { // Note, no need to check whether the tree is full // since nullifier insertions update existing values // in the tree and do not append new values. + + // 1. Check that the tree is a state tree. if self.tree_type != TreeType::BatchedState as u64 { return Err(MerkleTreeMetadataError::InvalidTreeType.into()); } - let leaf_index_bytes = leaf_index.to_be_bytes(); - let nullifier = Poseidon::hashv(&[compressed_account_hash, &leaf_index_bytes, tx_hash])?; + + // 2. Create nullifier Hash(value,leaf_index, tx_hash). + let nullifier = { + let leaf_index_bytes = leaf_index.to_be_bytes(); + // Inclusion of the tx_hash enables zk proofs of how a value was spent. + Poseidon::hashv(&[compressed_account_hash, &leaf_index_bytes, tx_hash])? + }; + // 3. Insert nullifier into current batch. self.insert_into_current_batch(compressed_account_hash, &nullifier) } @@ -747,6 +761,10 @@ impl<'a> BatchedMerkleTreeAccount<'a> { self.insert_into_current_batch(address, address) } + /// Insert value into the current batch. + /// 1. Insert value + /// 2. Zero out roots if bloom filter + /// was zeroed out in (insert_into_current_batch). fn insert_into_current_batch( &mut self, bloom_filter_value: &[u8; 32], @@ -759,8 +777,8 @@ impl<'a> BatchedMerkleTreeAccount<'a> { &mut self.value_vecs, &mut self.bloom_filter_stores, &mut self.hashchain_store, - bloom_filter_value, - Some(leaves_hash_value), + leaves_hash_value, + Some(bloom_filter_value), None, )?; @@ -805,8 +823,11 @@ impl<'a> BatchedMerkleTreeAccount<'a> { if let Some(sequence_number) = sequence_number { // If the sequence number is greater than current sequence number // there is still at least one root which can be used to prove - // inclusion of a value which was in the batch that was just wiped. - self.zero_out_roots(sequence_number, root_index); + // inclusion of a value which was in the batch that was just zeroed out. 
+ self.zero_out_roots( + sequence_number, + root_index.ok_or(BatchedMerkleTreeError::InvalidIndex)?, + ); } Ok(()) @@ -817,51 +838,47 @@ impl<'a> BatchedMerkleTreeAccount<'a> { /// 2. If yes: /// 2.1 Get, first safe root index. /// 2.2 Zero out roots from the oldest root to first safe root. - fn zero_out_roots(&mut self, sequence_number: u64, root_index: Option) { + fn zero_out_roots(&mut self, sequence_number: u64, root_index: u32) { // 1. Check whether overlapping roots exist. let overlapping_roots_exits = sequence_number > self.sequence_number; if overlapping_roots_exits { - if let Some(root_index) = root_index { - let root_index = root_index as usize; - - let oldest_root_index = self.root_history.first_index(); - // 2.1. Get, index of first root inserted after input queue batch was inserted. - let first_safe_root_index = self.root_history.len() + root_index; - // 2.2. Zero out roots oldest to first safe root index. - for index in oldest_root_index..first_safe_root_index { - let index = index % self.root_history.len(); - // TODO: test if needed - if index == root_index { - break; - } - self.root_history[index] = [0u8; 32]; + let root_index = root_index as usize; + + let oldest_root_index = self.root_history.first_index(); + // 2.1. Get, index of first root inserted after input queue batch was inserted. + let first_safe_root_index = self.root_history.len() + root_index; + // 2.2. Zero out roots oldest to first safe root index. + for index in oldest_root_index..first_safe_root_index { + let index = index % self.root_history.len(); + // TODO: test if needed + if index == root_index { + break; } - } else { - unreachable!("root_index must be Some(root_index) if overlapping roots exist"); + self.root_history[index] = [0u8; 32]; } } } - /// Wipe bloom filter of previous batch if 50% of the + /// Zero out bloom filter of previous batch if 50% of the /// current batch has been processed. /// /// Idea: - /// 1. Wiping the bloom filter of the previous batch is expensive + /// 1. Zeroing out the bloom filter of the previous batch is expensive /// -> the forester should do it. - /// 2. We don't want to wipe the bloom filter when inserting + /// 2. We don't want to zero out the bloom filter when inserting /// the last zkp of a batch for this might result in failing user tx. /// 3. Wait until next batch is 50% full as grace period for clients /// to switch from proof by index to proof by zkp /// for values inserted in the previous batch. /// /// Steps: - /// 1. Previous batch must be inserted and bloom filter must not be wiped. + /// 1. Previous batch must be inserted and bloom filter must not be zeroed out. /// 2. Current batch must be 50% full /// 3. 
if yes /// 3.1 zero out bloom filter - /// 3.2 mark bloom filter as wiped + /// 3.2 mark bloom filter as zeroed /// 3.3 zero out roots if needed - pub fn wipe_previous_batch_bloom_filter(&mut self) -> Result<(), BatchedMerkleTreeError> { + pub fn zero_out_previous_batch_bloom_filter(&mut self) -> Result<(), BatchedMerkleTreeError> { let current_batch = self.queue_metadata.next_full_batch_index as usize; let batch_size = self.queue_metadata.batch_size; let previous_full_batch_index = current_batch.saturating_sub(1); @@ -896,7 +913,7 @@ impl<'a> BatchedMerkleTreeAccount<'a> { .ok_or(BatchedMerkleTreeError::InvalidBatchIndex)?; let previous_batch_is_ready = previous_full_batch.get_state() == BatchState::Inserted - && !previous_full_batch.bloom_filter_is_wiped(); + && !previous_full_batch.bloom_filter_is_zeroed(); if previous_batch_is_ready && current_batch_is_half_full { // Keep for finegrained unit test @@ -911,15 +928,15 @@ impl<'a> BatchedMerkleTreeAccount<'a> { .ok_or(BatchedMerkleTreeError::InvalidBatchIndex)?; bloom_filter.as_mut_slice().iter_mut().for_each(|x| *x = 0); } - // 3.2 Mark bloom filter wiped. - previous_full_batch.set_bloom_filter_is_wiped(); + // 3.2 Mark bloom filter zeroed. + previous_full_batch.set_bloom_filter_to_zeroed(); // 3.3 Zero out roots if a root exists in root history // which allows to prove inclusion of a value - // that was inserted into the bloom filter just wiped. + // that was inserted into the bloom filter just zeroed out. { let seq = previous_full_batch.sequence_number; let root_index = previous_full_batch.root_index; - self.zero_out_roots(seq, Some(root_index)); + self.zero_out_roots(seq, root_index); } } @@ -935,7 +952,7 @@ impl<'a> BatchedMerkleTreeAccount<'a> { // TODO: add unit test /// Checks non-inclusion in all bloom filters - /// which are not wiped. + /// which are not zeroed. pub fn check_input_queue_non_inclusion( &mut self, value: &[u8; 32], @@ -944,7 +961,7 @@ impl<'a> BatchedMerkleTreeAccount<'a> { for i in 0..num_bloom_filters { let bloom_filter_store = self.bloom_filter_stores[i].as_mut_slice(); let batch = &self.batches[i]; - if !batch.bloom_filter_is_wiped() { + if !batch.bloom_filter_is_zeroed() { batch.check_non_inclusion(value, bloom_filter_store)?; } } diff --git a/program-libs/batched-merkle-tree/src/queue.rs b/program-libs/batched-merkle-tree/src/queue.rs index b9aa36881e..6db2405563 100644 --- a/program-libs/batched-merkle-tree/src/queue.rs +++ b/program-libs/batched-merkle-tree/src/queue.rs @@ -257,7 +257,7 @@ impl<'a> BatchedQueueAccount<'a> { pub fn insert_into_current_batch( &mut self, - value: &[u8; 32], + hash_chain_value: &[u8; 32], ) -> Result<(), BatchedMerkleTreeError> { let current_index = self.next_index; @@ -268,7 +268,7 @@ impl<'a> BatchedQueueAccount<'a> { &mut self.value_vecs, self.bloom_filter_stores.as_mut_slice(), &mut self.hashchain_store, - value, + hash_chain_value, None, Some(current_index), )?; @@ -280,16 +280,16 @@ impl<'a> BatchedQueueAccount<'a> { pub fn prove_inclusion_by_index( &mut self, leaf_index: u64, - value: &[u8; 32], + hash_chain_value: &[u8; 32], ) -> Result { for (batch_index, batch) in self.batches.iter().enumerate() { - if batch.value_is_inserted_in_batch(leaf_index)? { + if batch.leaf_index_could_exist_in_batch(leaf_index)? 
{ let index = batch.get_value_index_in_batch(leaf_index)?; let element = self.value_vecs[batch_index] .get_mut(index as usize) .ok_or(BatchedMerkleTreeError::InclusionProofByIndexFailed)?; - if element == value { + if element == hash_chain_value { return Ok(true); } else { return Err(BatchedMerkleTreeError::InclusionProofByIndexFailed); @@ -304,7 +304,7 @@ impl<'a> BatchedQueueAccount<'a> { leaf_index: u64, ) -> Result<(), BatchedMerkleTreeError> { for batch in self.batches.iter() { - let res = batch.value_is_inserted_in_batch(leaf_index)?; + let res = batch.leaf_index_could_exist_in_batch(leaf_index)?; if res { return Ok(()); } @@ -313,21 +313,21 @@ impl<'a> BatchedQueueAccount<'a> { } // TODO: add unit tests - /// Zero out a leaf by index if it exists in the queues value vec. If + /// Zero out a leaf by index if it exists in the queues hash_chain_value vec. If /// checked fail if leaf is not found. pub fn prove_inclusion_by_index_and_zero_out_leaf( &mut self, leaf_index: u64, - value: &[u8; 32], + hash_chain_value: &[u8; 32], ) -> Result<(), BatchedMerkleTreeError> { for (batch_index, batch) in self.batches.iter().enumerate() { - if batch.value_is_inserted_in_batch(leaf_index)? { + if batch.leaf_index_could_exist_in_batch(leaf_index)? { let index = batch.get_value_index_in_batch(leaf_index)?; let element = self.value_vecs[batch_index] .get_mut(index as usize) .ok_or(BatchedMerkleTreeError::InclusionProofByIndexFailed)?; - if element == value { + if element == hash_chain_value { *element = [0u8; 32]; return Ok(()); } else { @@ -338,18 +338,18 @@ impl<'a> BatchedQueueAccount<'a> { Ok(()) } - pub fn get_batch_num_inserted_in_current_batch(&self) -> u64 { + pub fn get_num_inserted_in_current_batch(&self) -> u64 { let next_full_batch = self.batch_metadata.currently_processing_batch_index; let batch = self.batches.get(next_full_batch as usize).unwrap(); batch.get_num_inserted() + batch.get_current_zkp_batch_index() * batch.zkp_batch_size } - pub fn is_associated(&self, account: &Pubkey) -> bool { - self.metadata.metadata.associated_merkle_tree == *account + pub fn is_associated(&self, pubkey: &Pubkey) -> bool { + self.metadata.metadata.associated_merkle_tree == *pubkey } - pub fn check_is_associated(&self, account: &Pubkey) -> Result<(), BatchedMerkleTreeError> { - if !self.is_associated(account) { + pub fn check_is_associated(&self, pubkey: &Pubkey) -> Result<(), BatchedMerkleTreeError> { + if !self.is_associated(pubkey) { return Err(MerkleTreeMetadataError::MerkleTreeAndQueueNotAssociated.into()); } Ok(()) @@ -367,63 +367,61 @@ impl<'a> BatchedQueueAccount<'a> { } } +/// Insert a value into the current batch. +/// - Input&address queues: Insert into bloom filter & hash chain. +/// - Output queue: Insert into value vec & hash chain. +/// +/// Steps: +/// 1. Check if the current batch is ready. +/// 1.1. If the current batch is inserted, clear the batch. +/// 2. Insert value into the current batch. +/// 3. If batch is full, increment currently_processing_batch_index. 
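// A minimal sketch of the three steps documented above, with the queue's
// batch rotation modeled as a ring. `MockQueue` and its fields are
// simplified stand-ins invented for illustration; they are not the actual
// BatchedQueueAccount / BatchMetadata types.

#[derive(Debug, Clone, Copy, PartialEq)]
enum MockBatchState {
    Fill,
    Full,
    Inserted,
}

struct MockQueue {
    num_batches: u64,
    currently_processing_batch_index: u64,
    batch_states: Vec<MockBatchState>,
}

impl MockQueue {
    fn insert_into_current_batch(&mut self) -> Result<(), &'static str> {
        let i = self.currently_processing_batch_index as usize;
        match self.batch_states[i] {
            // 1. Most common case first: the current batch is being filled.
            MockBatchState::Fill => {}
            // 1.1. A fully inserted batch is recycled: clear it, then fill.
            MockBatchState::Inserted => self.batch_states[i] = MockBatchState::Fill,
            // A full batch still waits for the forester's zkp updates,
            // so no value can be accepted right now.
            MockBatchState::Full => return Err("batch not ready"),
        }
        // 2. The actual value insertion (bloom filter / value vec plus
        //    hash chain) happens here; assume it just filled the batch.
        self.batch_states[i] = MockBatchState::Full;
        // 3. The batch is full: advance to the next batch, wrapping around.
        self.currently_processing_batch_index =
            (self.currently_processing_batch_index + 1) % self.num_batches;
        Ok(())
    }
}

fn main() {
    let mut queue = MockQueue {
        num_batches: 2,
        currently_processing_batch_index: 0,
        batch_states: vec![MockBatchState::Fill, MockBatchState::Inserted],
    };
    queue.insert_into_current_batch().unwrap();
    assert_eq!(queue.currently_processing_batch_index, 1);
    // The second batch was already inserted, so it is cleared and reused.
    queue.insert_into_current_batch().unwrap();
    assert_eq!(queue.currently_processing_batch_index, 0);
    // Both batches are now Full and wait for the forester.
    assert!(queue.insert_into_current_batch().is_err());
}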
#[allow(clippy::too_many_arguments)] #[allow(clippy::type_complexity)] pub fn insert_into_current_batch( queue_type: u64, - account: &mut BatchMetadata, + batch_metadata: &mut BatchMetadata, batches: &mut ZeroCopySliceMutU64, value_vecs: &mut [ZeroCopyVecU64<[u8; 32]>], bloom_filter_stores: &mut [ZeroCopySliceMutU64], hashchain_store: &mut [ZeroCopyVecU64<[u8; 32]>], - value: &[u8; 32], - leaves_hash_value: Option<&[u8; 32]>, + hash_chain_value: &[u8; 32], + bloom_filter_value: Option<&[u8; 32]>, current_index: Option, ) -> Result<(Option, Option), BatchedMerkleTreeError> { - let len = batches.len(); let mut root_index = None; let mut sequence_number = None; - let currently_processing_batch_index = account.currently_processing_batch_index as usize; - // Insert value into current batch. + let batch_index = batch_metadata.currently_processing_batch_index as usize; + let mut value_store = value_vecs.get_mut(batch_index); + let mut hashchain_store = hashchain_store.get_mut(batch_index); + let current_batch = batches + .get_mut(batch_index) + .ok_or(BatchedMerkleTreeError::InvalidBatchIndex)?; + // 1. Check that the current batch is ready. + // 1.1. If the current batch is inserted, clear the batch. { - let mut bloom_filter_stores = bloom_filter_stores.get_mut(currently_processing_batch_index); - let mut value_store = value_vecs.get_mut(currently_processing_batch_index); - let mut hashchain_store = hashchain_store.get_mut(currently_processing_batch_index); - - let current_batch = batches - .get_mut(currently_processing_batch_index) - .ok_or(BatchedMerkleTreeError::InvalidBatchIndex)?; - let mut wipe = false; - if current_batch.get_state() == BatchState::Inserted { + let clear_batch = current_batch.get_state() == BatchState::Inserted; + if current_batch.get_state() == BatchState::Fill { + // Do nothing, checking most often case first. + } else if clear_batch { current_batch.advance_state_to_fill()?; - if let Some(current_index) = current_index { - current_batch.start_index = current_index; - } - wipe = true; - } - - // We expect to insert into the current batch. - if current_batch.get_state() == BatchState::Full { - for batch in batches.iter_mut() { - msg!("batch {:?}", batch); - } - return Err(BatchedMerkleTreeError::BatchNotReady); - } - if wipe { - msg!("wipe"); - if let Some(blomfilter_stores) = bloom_filter_stores.as_mut() { - if !current_batch.bloom_filter_is_wiped() { + msg!("clear_batch"); + + if let Some(blomfilter_stores) = bloom_filter_stores.get_mut(batch_index) { + // Bloom filters should by default be zeroed by foresters + // because zeroing bytes is CU intensive. + // This is a safeguard to ensure queue lifeness + // in case foresters are behind. + if !current_batch.bloom_filter_is_zeroed() { (*blomfilter_stores).iter_mut().for_each(|x| *x = 0); // Saving sequence number and root index for the batch. 
- // When the batch is cleared check that sequence number is greater or equal than self.sequence_number + // When the batch is cleared check that sequence number + // is greater or equal than self.sequence_number // if not advance current root index to root index - if current_batch.sequence_number != 0 { - root_index = Some(current_batch.root_index); - sequence_number = Some(current_batch.sequence_number); - } - } else { - current_batch.set_bloom_filter_is_not_wiped(); + root_index = Some(current_batch.root_index); + sequence_number = Some(current_batch.sequence_number); } + current_batch.set_bloom_filter_to_not_zeroed(); current_batch.sequence_number = 0; } if let Some(value_store) = value_store.as_mut() { @@ -432,41 +430,39 @@ pub fn insert_into_current_batch( if let Some(hashchain_store) = hashchain_store.as_mut() { (*hashchain_store).clear(); } + if let Some(current_index) = current_index { + current_batch.start_index = current_index; + } + } else { + // We expect to insert into the current batch. + for batch in batches.iter_mut() { + msg!("batch {:?}", batch); + } + return Err(BatchedMerkleTreeError::BatchNotReady); } - - let queue_type = QueueType::from(queue_type); - match queue_type { - QueueType::BatchedInput | QueueType::BatchedAddress => current_batch.insert( - value, - leaves_hash_value.unwrap(), - bloom_filter_stores.unwrap().as_mut_slice(), - hashchain_store.as_mut().unwrap(), - ), - QueueType::BatchedOutput => current_batch.store_and_hash_value( - value, - value_store.unwrap(), - hashchain_store.unwrap(), - ), - _ => Err(MerkleTreeMetadataError::InvalidQueueType.into()), - }?; } - // If queue has bloom_filters check non-inclusion of value in bloom_filters of - // other batches. (Current batch is already checked by insertion.) - if !bloom_filter_stores.is_empty() { - for index in currently_processing_batch_index + 1..(len + currently_processing_batch_index) - { - let index = index % len; - let bloom_filter_stores = bloom_filter_stores.get_mut(index).unwrap().as_mut_slice(); - let current_batch = batches.get_mut(index).unwrap(); - current_batch.check_non_inclusion(value, bloom_filter_stores)?; - } - } + // 2. Insert value into the current batch. + let queue_type = QueueType::from(queue_type); + match queue_type { + QueueType::BatchedInput | QueueType::BatchedAddress => current_batch.insert( + bloom_filter_value.unwrap(), + hash_chain_value, + bloom_filter_stores, + hashchain_store.as_mut().unwrap(), + batch_index, + ), + QueueType::BatchedOutput => current_batch.store_and_hash_value( + hash_chain_value, + value_store.unwrap(), + hashchain_store.unwrap(), + ), + _ => Err(MerkleTreeMetadataError::InvalidQueueType.into()), + }?; + + // 3. If batch is full, increment currently_processing_batch_index. 
+ batch_metadata.increment_currently_processing_batch_index_if_full(current_batch.get_state()); - if batches[account.currently_processing_batch_index as usize].get_state() == BatchState::Full { - account.currently_processing_batch_index += 1; - account.currently_processing_batch_index %= len as u64; - } Ok((root_index, sequence_number)) } @@ -496,8 +492,8 @@ pub fn output_queue_from_bytes( } #[allow(clippy::type_complexity)] -pub fn input_queue_bytes<'a>( - account: &BatchMetadata, +pub fn input_queue_from_bytes<'a>( + batch_metadata: &BatchMetadata, account_data: &'a mut [u8], queue_type: u64, ) -> Result< @@ -510,7 +506,7 @@ pub fn input_queue_bytes<'a>( BatchedMerkleTreeError, > { let (num_value_stores, num_stores, hashchain_store_capacity) = - account.get_size_parameters(queue_type)?; + batch_metadata.get_size_parameters(queue_type)?; let (batches, account_data) = ZeroCopySliceMutU64::from_bytes_at(account_data)?; let (value_vecs, account_data) = @@ -526,7 +522,7 @@ pub fn input_queue_bytes<'a>( #[allow(clippy::type_complexity)] pub fn init_queue<'a>( - account: &BatchMetadata, + batch_metadata: &BatchMetadata, queue_type: u64, account_data: &'a mut [u8], num_iters: u64, @@ -542,32 +538,32 @@ pub fn init_queue<'a>( BatchedMerkleTreeError, > { let (num_value_stores, num_stores, num_hashchain_stores) = - account.get_size_parameters(queue_type)?; + batch_metadata.get_size_parameters(queue_type)?; let (mut batches, account_data) = - ZeroCopySliceMutU64::new_at(account.num_batches, account_data)?; + ZeroCopySliceMutU64::new_at(batch_metadata.num_batches, account_data)?; - for i in 0..account.num_batches { + for i in 0..batch_metadata.num_batches { batches[i as usize] = Batch::new( num_iters, bloom_filter_capacity, - account.batch_size, - account.zkp_batch_size, - account.batch_size * i + batch_start_index, + batch_metadata.batch_size, + batch_metadata.zkp_batch_size, + batch_metadata.batch_size * i + batch_start_index, ); } let (value_vecs, account_data) = - ZeroCopyVecU64::new_at_multiple(num_value_stores, account.batch_size, account_data)?; + ZeroCopyVecU64::new_at_multiple(num_value_stores, batch_metadata.batch_size, account_data)?; let (bloom_filter_stores, account_data) = ZeroCopySliceMutU64::new_at_multiple( num_stores, - account.bloom_filter_capacity / 8, + batch_metadata.bloom_filter_capacity / 8, account_data, )?; let (hashchain_store, _) = ZeroCopyVecU64::new_at_multiple( num_hashchain_stores, - account.get_num_zkp_batches(), + batch_metadata.get_num_zkp_batches(), account_data, )?; @@ -575,7 +571,7 @@ pub fn init_queue<'a>( } pub fn get_output_queue_account_size_default() -> usize { - let account = BatchedQueueMetadata { + let batch_metadata = BatchedQueueMetadata { metadata: QueueMetadata::default(), next_index: 0, batch_metadata: BatchMetadata { @@ -586,13 +582,17 @@ pub fn get_output_queue_account_size_default() -> usize { }, ..Default::default() }; - queue_account_size(&account.batch_metadata, QueueType::BatchedOutput as u64).unwrap() + queue_account_size( + &batch_metadata.batch_metadata, + QueueType::BatchedOutput as u64, + ) + .unwrap() } pub fn get_output_queue_account_size_from_params( ix_data: InitStateTreeAccountsInstructionData, ) -> usize { - let account = BatchedQueueMetadata { + let metadata = BatchedQueueMetadata { metadata: QueueMetadata::default(), next_index: 0, batch_metadata: BatchMetadata { @@ -603,7 +603,7 @@ pub fn get_output_queue_account_size_from_params( }, ..Default::default() }; - queue_account_size(&account.batch_metadata, QueueType::BatchedOutput as 
u64).unwrap() + queue_account_size(&metadata.batch_metadata, QueueType::BatchedOutput as u64).unwrap() } pub fn get_output_queue_account_size( @@ -611,7 +611,7 @@ pub fn get_output_queue_account_size( zkp_batch_size: u64, num_batches: u64, ) -> usize { - let account = BatchedQueueMetadata { + let metadata = BatchedQueueMetadata { metadata: QueueMetadata::default(), next_index: 0, batch_metadata: BatchMetadata { @@ -622,7 +622,7 @@ pub fn get_output_queue_account_size( }, ..Default::default() }; - queue_account_size(&account.batch_metadata, QueueType::BatchedOutput as u64).unwrap() + queue_account_size(&metadata.batch_metadata, QueueType::BatchedOutput as u64).unwrap() } #[allow(clippy::too_many_arguments)] diff --git a/program-libs/batched-merkle-tree/src/rollover_address_tree.rs b/program-libs/batched-merkle-tree/src/rollover_address_tree.rs index 8d4b6d0203..80eabac1a8 100644 --- a/program-libs/batched-merkle-tree/src/rollover_address_tree.rs +++ b/program-libs/batched-merkle-tree/src/rollover_address_tree.rs @@ -25,7 +25,7 @@ pub fn rollover_batched_address_tree<'a>( new_mt_pubkey: Pubkey, network_fee: Option, ) -> Result, BatchedMerkleTreeError> { - // Check that old merkle tree is ready for rollover. + // 1. Check that old merkle tree is ready for rollover. batched_tree_is_ready_for_rollover(old_merkle_tree, &network_fee)?; // Rollover the old merkle tree. diff --git a/program-libs/batched-merkle-tree/tests/merkle_tree.rs b/program-libs/batched-merkle-tree/tests/merkle_tree.rs index b079675d27..6cf6fe257e 100644 --- a/program-libs/batched-merkle-tree/tests/merkle_tree.rs +++ b/program-libs/batched-merkle-tree/tests/merkle_tree.rs @@ -89,7 +89,7 @@ pub fn assert_input_queue_insert( input_is_in_tree: Vec, array_indices: Vec, ) -> Result<(), BatchedMerkleTreeError> { - let mut should_be_wiped = false; + let mut should_be_zeroed = false; for (i, insert_value) in bloom_filter_insert_values.iter().enumerate() { if !input_is_in_tree[i] { let value_vec_index = array_indices[i]; @@ -142,20 +142,20 @@ pub fn assert_input_queue_insert( expected_batch.batch_size / 2 ); - if !should_be_wiped && expected_batch.get_state() == BatchState::Inserted { - should_be_wiped = + if !should_be_zeroed && expected_batch.get_state() == BatchState::Inserted { + should_be_zeroed = expected_batch.get_num_inserted_elements() == expected_batch.batch_size / 2; } println!( - "assert input queue batch update: should_be_wiped: {}", - should_be_wiped + "assert input queue batch update: should_be_zeroed: {}", + should_be_zeroed ); if expected_batch.get_state() == BatchState::Inserted { println!("assert input queue batch update: clearing batch"); pre_hashchains[inserted_batch_index].clear(); expected_batch.sequence_number = 0; expected_batch.advance_state_to_fill().unwrap(); - expected_batch.set_bloom_filter_is_not_wiped(); + expected_batch.set_bloom_filter_to_not_zeroed(); } println!( "assert input queue batch update: inserted_batch_index: {}", @@ -223,8 +223,8 @@ pub fn assert_input_queue_insert( ); let inserted_batch_index = pre_account.queue_metadata.currently_processing_batch_index as usize; let mut expected_batch = pre_batches[inserted_batch_index]; - if should_be_wiped { - expected_batch.set_bloom_filter_is_wiped(); + if should_be_zeroed { + expected_batch.set_bloom_filter_to_zeroed(); } assert_eq!( merkle_tree_account.batches[inserted_batch_index], @@ -1307,19 +1307,19 @@ fn assert_merkle_tree_update( let previous_batch = batches.get_mut(previous_full_batch_index as usize).unwrap(); println!("previous_batch state: 
{:?}", previous_batch.get_state()); println!( - "previous_batch wiped?: {:?}", - previous_batch.bloom_filter_is_wiped() + "previous_batch zeroed?: {:?}", + previous_batch.bloom_filter_is_zeroed() ); let previous_batch_is_ready = previous_batch.get_state() == BatchState::Inserted - && !previous_batch.bloom_filter_is_wiped(); + && !previous_batch.bloom_filter_is_zeroed(); let batch = batches .get_mut(old_account.queue_metadata.next_full_batch_index as usize) .unwrap(); println!("previous_batch_is_ready: {:?}", previous_batch_is_ready); println!( - "batch.bloom_filter_is_wiped(): {:?}", - batch.bloom_filter_is_wiped() + "batch.bloom_filter_is_zeroed(): {:?}", + batch.bloom_filter_is_zeroed() ); println!( "batch.get_num_inserted_elements(): {:?}", @@ -1327,14 +1327,14 @@ fn assert_merkle_tree_update( ); println!("batch.batch_size: {:?}", batch.batch_size); println!(" batch.get_num_inserted_elements() >= batch.batch_size / 2 && previous_batch_is_ready: {:?}", batch.get_num_inserted_elements()+ batch.zkp_batch_size >= batch.batch_size / 2); - let wiped_batch = batch.get_num_inserted_elements() + batch.zkp_batch_size + let zeroed_batch = batch.get_num_inserted_elements() + batch.zkp_batch_size >= batch.batch_size / 2 && previous_batch_is_ready; let previous_batch = batches.get_mut(previous_full_batch_index as usize).unwrap(); - if wiped_batch { - previous_batch.set_bloom_filter_is_wiped(); - println!("set bloom filter is wiped"); + if zeroed_batch { + previous_batch.set_bloom_filter_to_zeroed(); + println!("set bloom filter is zeroed"); } (account.batches.to_vec(), batches, None, 0) }; @@ -1641,7 +1641,7 @@ async fn test_fill_queues_completely() { vec![], ) .unwrap(); - + println!("leaf {:?}", leaf); // Insert the same value twice { // copy data so that failing test doesn't affect the state of @@ -1701,12 +1701,12 @@ async fn test_fill_queues_completely() { let merkle_tree_account = &mut BatchedMerkleTreeAccount::state_from_bytes(&mut mt_account_data).unwrap(); let batch = merkle_tree_account.batches.get(0).unwrap(); - assert!(batch.bloom_filter_is_wiped()); + assert!(batch.bloom_filter_is_zeroed()); } else { let merkle_tree_account = &mut BatchedMerkleTreeAccount::state_from_bytes(&mut mt_account_data).unwrap(); let batch = merkle_tree_account.batches.get(0).unwrap(); - assert!(!batch.bloom_filter_is_wiped()); + assert!(!batch.bloom_filter_is_zeroed()); } println!( "performed input queue batched update {} created root {:?}", @@ -1734,9 +1734,9 @@ async fn test_fill_queues_completely() { for (i, batch) in merkle_tree_account.batches.iter().enumerate() { assert_eq!(batch.get_state(), BatchState::Inserted); if i == 0 { - assert!(batch.bloom_filter_is_wiped()); + assert!(batch.bloom_filter_is_zeroed()); } else { - assert!(!batch.bloom_filter_is_wiped()); + assert!(!batch.bloom_filter_is_zeroed()); } } } @@ -1930,12 +1930,12 @@ async fn test_fill_address_tree_completely() { BatchedMerkleTreeAccount::address_from_bytes(&mut mt_account_data).unwrap(); let batch = merkle_tree_account.batches.get(0).unwrap(); let batch_one = merkle_tree_account.batches.get(1).unwrap(); - assert!(!batch_one.bloom_filter_is_wiped()); + assert!(!batch_one.bloom_filter_is_zeroed()); if i >= 7 { - assert!(batch.bloom_filter_is_wiped()); + assert!(batch.bloom_filter_is_zeroed()); } else { - assert!(!batch.bloom_filter_is_wiped()); + assert!(!batch.bloom_filter_is_zeroed()); } } // assert all bloom_filters are inserted @@ -1945,43 +1945,18 @@ async fn test_fill_address_tree_completely() { for (i, batch) in 
merkle_tree_account.batches.iter().enumerate() { assert_eq!(batch.get_state(), BatchState::Inserted); if i == 0 { - assert!(batch.bloom_filter_is_wiped()); + assert!(batch.bloom_filter_is_zeroed()); } else { - assert!(!batch.bloom_filter_is_wiped()); + assert!(!batch.bloom_filter_is_zeroed()); } } } - // do one insert and expect that roots until merkle_tree_account.batches[0].root_index are zero { let merkle_tree_account = &mut BatchedMerkleTreeAccount::address_from_bytes(&mut mt_account_data).unwrap(); println!("root history {:?}", merkle_tree_account.root_history); let pre_batch_zero = *merkle_tree_account.batches.get(0).unwrap(); - // let mut address = get_rnd_bytes(&mut rng); - // address[0] = 0; - // merkle_tree_account.insert_address_into_current_batch(&address); - // { - // let post_batch = merkle_tree_account - // .batches - // .get(0) - // .unwrap() - // .clone(); - // assert_eq!(post_batch.get_state(), BatchState::Fill); - // assert_eq!(post_batch.get_num_inserted(), 1); - // let mut bloom_filter_store = merkle_tree_account - // .bloom_filter_stores - // .get_mut(0) - // .unwrap(); - // let mut bloom_filter = BloomFilter::new( - // params.bloom_filter_num_iters as usize, - // params.bloom_filter_capacity, - // bloom_filter_store.as_mut_slice(), - // ) - // .unwrap(); - // assert!(bloom_filter.contains(&address)); - // } - for root in merkle_tree_account.root_history.iter() { println!("root {:?}", root); } diff --git a/program-tests/account-compression-test/tests/batched_merkle_tree_test.rs b/program-tests/account-compression-test/tests/batched_merkle_tree_test.rs index 66de512773..d0b9caefd5 100644 --- a/program-tests/account-compression-test/tests/batched_merkle_tree_test.rs +++ b/program-tests/account-compression-test/tests/batched_merkle_tree_test.rs @@ -1611,7 +1611,7 @@ async fn test_batch_address_merkle_trees() { UpdateBatchAddressTreeTestMode::UpdateTwice, ] .iter() - .zip(vec![0, 0, 0, 1]) + .zip(vec![0, 0, 1]) { let mut mock_indexer = mock_indexer.clone(); let result = update_batch_address_tree( diff --git a/program-tests/system-test/tests/test.rs b/program-tests/system-test/tests/test.rs index 7dab138760..22d63794f5 100644 --- a/program-tests/system-test/tests/test.rs +++ b/program-tests/system-test/tests/test.rs @@ -2569,7 +2569,7 @@ pub async fn create_compressed_accounts_in_batch_merkle_tree( .unwrap(); let output_queue = BatchedQueueAccount::output_from_bytes(&mut output_queue_account.data).unwrap(); - let fullness = output_queue.get_batch_num_inserted_in_current_batch(); + let fullness = output_queue.get_num_inserted_in_current_batch(); let remaining_leaves = output_queue.get_metadata().batch_metadata.batch_size - fullness; for _ in 0..remaining_leaves { create_output_accounts(context, &payer, test_indexer, output_queue_pubkey, 1, true).await?; diff --git a/programs/system/src/invoke/verify_state_proof.rs b/programs/system/src/invoke/verify_state_proof.rs index 1d3f0d4abd..60f8d3d063 100644 --- a/programs/system/src/invoke/verify_state_proof.rs +++ b/programs/system/src/invoke/verify_state_proof.rs @@ -245,7 +245,7 @@ fn fetch_root( /// 1. prove inclusion by index in the output queue if leaf index should exist in the output queue. /// 1.1. if inclusion was proven by index, return Ok. /// 2. prove non-inclusion in the bloom filters -/// 2.1. skip wiped batches. +/// 2.1. skip cleared batches. /// 2.2. prove non-inclusion in the bloom filters for each batch. 
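// A minimal sketch of the two checks described above, using a toy bloom
// filter; every type and function here is simplified for illustration
// and is not the on-chain implementation.

use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

struct ToyBloomFilter {
    bits: Vec<bool>,
    num_iters: u64,
}

impl ToyBloomFilter {
    fn bit_index(&self, value: &[u8; 32], iter: u64) -> usize {
        let mut hasher = DefaultHasher::new();
        (value, iter).hash(&mut hasher);
        (hasher.finish() as usize) % self.bits.len()
    }
    fn insert(&mut self, value: &[u8; 32]) {
        for i in 0..self.num_iters {
            let index = self.bit_index(value, i);
            self.bits[index] = true;
        }
    }
    fn contains(&self, value: &[u8; 32]) -> bool {
        (0..self.num_iters).all(|i| self.bits[self.bit_index(value, i)])
    }
}

struct ToyInputBatch {
    bloom_filter: ToyBloomFilter,
    bloom_filter_is_zeroed: bool,
}

/// 1. Prove inclusion by index against the output queue's value vec.
fn prove_inclusion_by_index(
    values: &[[u8; 32]],
    start_index: usize,
    leaf_index: usize,
    value: &[u8; 32],
) -> bool {
    leaf_index >= start_index
        && values
            .get(leaf_index - start_index)
            .map_or(false, |stored| stored == value)
}

/// 2. Prove the value was not nullified: non-inclusion in every input
///    batch bloom filter, skipping batches that were already zeroed out.
fn prove_non_inclusion(batches: &[ToyInputBatch], value: &[u8; 32]) -> bool {
    batches
        .iter()
        .filter(|batch| !batch.bloom_filter_is_zeroed)
        .all(|batch| !batch.bloom_filter.contains(value))
}

fn main() {
    let value = [7u8; 32];
    assert!(prove_inclusion_by_index(&[value], 100, 100, &value));

    let mut bloom_filter = ToyBloomFilter {
        bits: vec![false; 1024],
        num_iters: 3,
    };
    bloom_filter.insert(&value);
    let batches = vec![ToyInputBatch {
        bloom_filter,
        bloom_filter_is_zeroed: false,
    }];
    // The nullified value fails the non-inclusion check.
    assert!(!prove_non_inclusion(&batches, &value));
    // A different value passes, up to bloom-filter false positives.
    println!("fresh value passes: {}", prove_non_inclusion(&batches, &[8u8; 32]));
}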
#[inline(always)] pub fn verify_read_only_account_inclusion_by_index<'a>( diff --git a/programs/system/src/invoke_cpi/process_cpi_context.rs b/programs/system/src/invoke_cpi/process_cpi_context.rs index 5f7490b45a..e5978f94ee 100644 --- a/programs/system/src/invoke_cpi/process_cpi_context.rs +++ b/programs/system/src/invoke_cpi/process_cpi_context.rs @@ -94,7 +94,7 @@ pub fn set_cpi_context( ) -> Result<()> { // SAFETY Assumptions: // - previous data in cpi_context_account - // -> we require the account to be wiped in the beginning of a + // -> we require the account to be cleared in the beginning of a // transaction // - leaf over data: There cannot be any leftover data in the // account since if the transaction fails the account doesn't change. diff --git a/programs/system/src/sdk/mod.rs b/programs/system/src/sdk/mod.rs index f23b96b69e..b9fe794ea8 100644 --- a/programs/system/src/sdk/mod.rs +++ b/programs/system/src/sdk/mod.rs @@ -10,7 +10,7 @@ pub struct CompressedCpiContext { /// Is set by the program that is invoking the CPI to signal that is should /// set the cpi context. pub set_context: bool, - /// Is set to wipe the cpi context since someone could have set it before + /// Is set to clear the cpi context since someone could have set it before /// with unrelated data. pub first_set_context: bool, /// Index of cpi context account in remaining accounts. diff --git a/sdk-libs/sdk/src/verify.rs b/sdk-libs/sdk/src/verify.rs index 7ddad050b8..c6ad61e8dc 100644 --- a/sdk-libs/sdk/src/verify.rs +++ b/sdk-libs/sdk/src/verify.rs @@ -26,7 +26,7 @@ pub struct CompressedCpiContext { /// Is set by the program that is invoking the CPI to signal that is should /// set the cpi context. pub set_context: bool, - /// Is set to wipe the cpi context since someone could have set it before + /// Is set to clear the cpi context since someone could have set it before /// with unrelated data. pub first_set_context: bool, /// Index of cpi context account in remaining accounts. 
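Patch 4 below reworks zero_out_roots to derive the number of overlapping roots directly from sequence numbers. For intuition, here is a minimal, self-contained sketch of that ring-buffer arithmetic; ToyTree and its fields are simplified stand-ins invented for illustration, not the crate's actual account types.

struct ToyTree {
    sequence_number: u64,        // tree's current sequence number
    oldest_root_index: usize,    // oldest slot of the cyclic root history
    root_history: Vec<[u8; 32]>, // cyclic buffer of recent roots
}

impl ToyTree {
    // `batch_sequence_number` is the sequence number recorded when the
    // batch was inserted (batch.sequence_number): from that point on the
    // batch can be cleared without zeroing any roots.
    fn zero_out_roots(&mut self, batch_sequence_number: u64) {
        // 1. No overlapping roots exist: nothing to do.
        if batch_sequence_number <= self.sequence_number {
            return;
        }
        // 2.1. Roots not yet overwritten since the batch update could still
        // prove inclusion of values nullified in that batch.
        let num_remaining_roots = batch_sequence_number - self.sequence_number;
        // 2.2. Zero them out from the oldest root onwards, skipping one
        // iteration so the first safe root itself is kept.
        let mut index = self.oldest_root_index;
        for _ in 1..num_remaining_roots {
            self.root_history[index] = [0u8; 32];
            index = (index + 1) % self.root_history.len();
        }
    }
}

fn main() {
    let mut tree = ToyTree {
        sequence_number: 10,
        oldest_root_index: 0,
        root_history: vec![[1u8; 32]; 4],
    };
    // The batch recorded sequence number 13 -> three overlapping roots,
    // two of which are zeroed out.
    tree.zero_out_roots(13);
    assert_eq!(tree.root_history[0], [0u8; 32]);
    assert_eq!(tree.root_history[1], [0u8; 32]);
    assert_eq!(tree.root_history[2], [1u8; 32]);
}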
From 422e7508c6b25cd4ac7f31338c80947b3130d355 Mon Sep 17 00:00:00 2001 From: ananas-block Date: Thu, 16 Jan 2025 22:23:34 +0000 Subject: [PATCH 4/6] refactor: zero out previous batch bloom filter --- program-libs/batched-merkle-tree/src/batch.rs | 5 ++ .../batched-merkle-tree/src/batch_metadata.rs | 22 ++++++-- .../batched-merkle-tree/src/merkle_tree.rs | 54 +++++++++++++------ 3 files changed, 62 insertions(+), 19 deletions(-) diff --git a/program-libs/batched-merkle-tree/src/batch.rs b/program-libs/batched-merkle-tree/src/batch.rs index 7f610257c2..ddd968c5f7 100644 --- a/program-libs/batched-merkle-tree/src/batch.rs +++ b/program-libs/batched-merkle-tree/src/batch.rs @@ -332,6 +332,11 @@ impl Batch { // When the batch is cleared check that sequence number is greater or equal than self.sequence_number // if not advance current root index to root index self.sequence_number = sequence_number + root_history_length as u64; + println!("root_history_length as u64: {}", root_history_length as u64); + println!("sequence_number: {}", sequence_number); + println!("recorded sequence_number: {}", self.sequence_number); + println!("current root index {}", root_index); + self.root_index = root_index; } diff --git a/program-libs/batched-merkle-tree/src/batch_metadata.rs b/program-libs/batched-merkle-tree/src/batch_metadata.rs index 19042d5eb9..153ce37e41 100644 --- a/program-libs/batched-merkle-tree/src/batch_metadata.rs +++ b/program-libs/batched-merkle-tree/src/batch_metadata.rs @@ -75,7 +75,7 @@ impl BatchMetadata { }) } - /// Increment the next full batch index if current state is inserted. + /// Increment the next full batch index if current state is BatchState::Inserted. pub fn increment_next_full_batch_index_if_inserted(&mut self, state: BatchState) { if state == BatchState::Inserted { self.next_full_batch_index += 1; @@ -83,7 +83,7 @@ impl BatchMetadata { } } - /// Increment the currently_processing_batch_index if current state is full. + /// Increment the currently_processing_batch_index if current state is BatchState::Full. 
     pub fn increment_currently_processing_batch_index_if_full(&mut self, state: BatchState) {
         if state == BatchState::Full {
             self.currently_processing_batch_index += 1;
@@ -134,7 +134,6 @@ impl BatchMetadata {

 #[test]
 fn test_increment_next_full_batch_index_if_inserted() {
-    // create a new metadata struct
     let mut metadata = BatchMetadata::new_input_queue(10, 10, 10, 2).unwrap();
     assert_eq!(metadata.next_full_batch_index, 0);
     // increment next full batch index
@@ -149,3 +148,20 @@ fn test_increment_next_full_batch_index_if_inserted() {
     metadata.increment_next_full_batch_index_if_inserted(BatchState::Full);
     assert_eq!(metadata.next_full_batch_index, 0);
 }
+
+#[test]
+fn test_increment_currently_processing_batch_index_if_full() {
+    let mut metadata = BatchMetadata::new_input_queue(10, 10, 10, 2).unwrap();
+    assert_eq!(metadata.currently_processing_batch_index, 0);
+    // increment currently_processing_batch_index
+    metadata.increment_currently_processing_batch_index_if_full(BatchState::Full);
+    assert_eq!(metadata.currently_processing_batch_index, 1);
+    // increment again; with num_batches = 2 the index wraps back to 0
+    metadata.increment_currently_processing_batch_index_if_full(BatchState::Full);
+    assert_eq!(metadata.currently_processing_batch_index, 0);
+    // try incrementing currently_processing_batch_index with states that are not Full
+    metadata.increment_currently_processing_batch_index_if_full(BatchState::Fill);
+    assert_eq!(metadata.currently_processing_batch_index, 0);
+    metadata.increment_currently_processing_batch_index_if_full(BatchState::Inserted);
+    assert_eq!(metadata.currently_processing_batch_index, 0);
+}
diff --git a/program-libs/batched-merkle-tree/src/merkle_tree.rs b/program-libs/batched-merkle-tree/src/merkle_tree.rs
index ad27fbdb08..948caec2de 100644
--- a/program-libs/batched-merkle-tree/src/merkle_tree.rs
+++ b/program-libs/batched-merkle-tree/src/merkle_tree.rs
@@ -424,6 +424,10 @@ impl<'a> BatchedMerkleTreeAccount<'a> {
             account_metadata.root_history_capacity as u64,
             account_data,
         )?;
+        // Fill the root history with zero bytes.
+        for _ in 0..root_history.capacity() {
+            root_history.push([0u8; 32]);
+        }
         if tree_type == TreeType::BatchedState {
             root_history.push(light_hasher::Poseidon::zero_bytes()[height as usize]);
         } else if tree_type == TreeType::BatchedAddress {
@@ -432,6 +436,7 @@ impl<'a> BatchedMerkleTreeAccount<'a> {
             // The initialized indexed Merkle tree contains two elements.
             account_metadata.next_index = 2;
         }
+
         let (batches, value_vecs, bloom_filter_stores, hashchain_store) = init_queue(
             &account_metadata.queue_metadata,
             QueueType::BatchedInput as u64,
@@ -534,7 +539,7 @@ impl<'a> BatchedMerkleTreeAccount<'a> {
         let full_batch_state = full_batch.mark_as_inserted_in_merkle_tree(
             self.metadata.sequence_number,
             root_index,
-            self.metadata.root_history_capacity,
+            self.root_history_capacity,
         )?;
         // 4. Increment next full batch index if inserted.
         queue_account
@@ -838,24 +843,40 @@ impl<'a> BatchedMerkleTreeAccount<'a> {
     /// 2. If yes:
     ///    2.1 Get the first safe root index.
     ///    2.2 Zero out roots from the oldest root to the first safe root.
-    fn zero_out_roots(&mut self, sequence_number: u64, root_index: u32) {
+    fn zero_out_roots(&mut self, sequence_number: u64, first_safe_root_index: u32) {
         // 1. Check whether overlapping roots exist.
         let overlapping_roots_exist = sequence_number > self.sequence_number;
         if overlapping_roots_exist {
-            let root_index = root_index as usize;
-
-            let oldest_root_index = self.root_history.first_index();
-            // 2.1. Get, index of first root inserted after input queue batch was inserted.
-            let first_safe_root_index = self.root_history.len() + root_index;
+            // let root_index = root_index as usize;
+
+            let mut oldest_root_index = self.root_history.first_index();
+            // 2.1. Get the number of remaining roots.
+            //      The remaining roots have not been updated since the previous
+            //      batch was inserted, hence they can still be used to prove
+            //      inclusion of values nullified in the previous batch.
+            let num_remaining_roots = sequence_number - self.sequence_number;
+            println!("sequence_number: {}", sequence_number);
+            println!("self.sequence_number: {}", self.sequence_number);
+            println!("oldest_root_index: {}", oldest_root_index);
+            println!("first_safe_root_index: {}", first_safe_root_index);
+            println!("num_remaining_roots: {}", num_remaining_roots);
+            println!(
+                "self.root_history.len() as u64: {}",
+                self.root_history.len() as u64
+            );
             // 2.2. Zero out roots from the oldest to the first safe root index.
-            for index in oldest_root_index..first_safe_root_index {
-                let index = index % self.root_history.len();
-                // TODO: test if needed
-                if index == root_index {
-                    break;
-                }
-                self.root_history[index] = [0u8; 32];
+            // Skip one iteration; we don't need to zero out the first safe root.
+            for _ in 1..num_remaining_roots {
+                println!("zeroing out root index: {}", oldest_root_index);
+                self.root_history[oldest_root_index] = [0u8; 32];
+                oldest_root_index += 1;
+                oldest_root_index %= self.root_history.len();
             }
+            assert_eq!(
+                oldest_root_index as u32, first_safe_root_index,
+                "Zeroing out roots failed."
+            );
         }
     }

@@ -912,8 +933,9 @@ impl<'a> BatchedMerkleTreeAccount<'a> {
             .get_mut(previous_full_batch_index)
             .ok_or(BatchedMerkleTreeError::InvalidBatchIndex)?;

-        let previous_batch_is_ready = previous_full_batch.get_state() == BatchState::Inserted
-            && !previous_full_batch.bloom_filter_is_zeroed();
+        let batch_is_inserted = previous_full_batch.get_state() == BatchState::Inserted;
+        let previous_batch_is_ready =
+            batch_is_inserted && !previous_full_batch.bloom_filter_is_zeroed();

         if previous_batch_is_ready && current_batch_is_half_full {
             // Keep for fine-grained unit test
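Editor's note: the loop above replaces an index-range walk with a count-based walk over the cyclic root history. A standalone sketch of the mechanics follows; it is not part of the patch series, and it uses a plain slice in place of the tree's cyclic vector, with illustrative indices and lengths.

// Editor's sketch of the circular zeroing walk performed by zero_out_roots.
fn zero_out_roots(
    root_history: &mut [[u8; 32]],
    mut oldest_root_index: usize,
    num_remaining_roots: u64,
) -> usize {
    // Zero out num_remaining_roots - 1 entries starting at the oldest root;
    // the iteration starts at 1 so the first safe root itself is kept.
    for _ in 1..num_remaining_roots {
        root_history[oldest_root_index] = [0u8; 32];
        oldest_root_index = (oldest_root_index + 1) % root_history.len();
    }
    // The walk ends exactly on the first safe root index (the assert_eq!
    // in the patch checks the same invariant).
    oldest_root_index
}

fn main() {
    let mut roots = [[1u8; 32]; 8];
    // Three remaining roots starting at index 6: indices 6 and 7 are zeroed,
    // then the walk wraps and stops at index 0, the first safe root.
    let first_safe = zero_out_roots(&mut roots, 6, 3);
    assert_eq!(first_safe, 0);
    assert_eq!(roots[6], [0u8; 32]);
    assert_eq!(roots[7], [0u8; 32]);
    assert_ne!(roots[0], [0u8; 32]);
}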
From 76432f318f511efab807982ca002be0bb6b14dc7 Mon Sep 17 00:00:00 2001
From: ananas-block
Date: Thu, 16 Jan 2025 22:29:42 +0000
Subject: [PATCH 5/6] refactor: move batch size checks into validate_batch_sizes

---
 .../batched-merkle-tree/src/batch_metadata.rs | 45 +++++++++++++------
 1 file changed, 31 insertions(+), 14 deletions(-)

diff --git a/program-libs/batched-merkle-tree/src/batch_metadata.rs b/program-libs/batched-merkle-tree/src/batch_metadata.rs
index 153ce37e41..d86540fa06 100644
--- a/program-libs/batched-merkle-tree/src/batch_metadata.rs
+++ b/program-libs/batched-merkle-tree/src/batch_metadata.rs
@@ -34,18 +34,28 @@ pub struct BatchMetadata {
 }

 impl BatchMetadata {
+    /// Returns the number of ZKP batches contained within a single regular batch.
     pub fn get_num_zkp_batches(&self) -> u64 {
         self.batch_size / self.zkp_batch_size
     }

-    pub fn new_output_queue(
+    /// Validates that the batch size is evenly divisible by the ZKP batch size.
+    fn validate_batch_sizes(
         batch_size: u64,
         zkp_batch_size: u64,
-        num_batches: u64,
-    ) -> Result<Self, BatchedMerkleTreeError> {
+    ) -> Result<(), BatchedMerkleTreeError> {
         if batch_size % zkp_batch_size != 0 {
             return Err(BatchedMerkleTreeError::BatchSizeNotDivisibleByZkpBatchSize);
         }
+        Ok(())
+    }
+
+    pub fn new_output_queue(
+        batch_size: u64,
+        zkp_batch_size: u64,
+        num_batches: u64,
+    ) -> Result<Self, BatchedMerkleTreeError> {
+        Self::validate_batch_sizes(batch_size, zkp_batch_size)?;
         Ok(BatchMetadata {
             num_batches,
             zkp_batch_size,
@@ -62,9 +72,8 @@ impl BatchMetadata {
         zkp_batch_size: u64,
         num_batches: u64,
     ) -> Result<Self, BatchedMerkleTreeError> {
-        if batch_size % zkp_batch_size != 0 {
-            return Err(BatchedMerkleTreeError::BatchSizeNotDivisibleByZkpBatchSize);
-        }
+        Self::validate_batch_sizes(batch_size, zkp_batch_size)?;
+
         Ok(BatchMetadata {
             num_batches,
             zkp_batch_size,
@@ -78,16 +87,15 @@ impl BatchMetadata {
     /// Increment the next full batch index if the current state is BatchState::Inserted.
     pub fn increment_next_full_batch_index_if_inserted(&mut self, state: BatchState) {
         if state == BatchState::Inserted {
-            self.next_full_batch_index += 1;
-            self.next_full_batch_index %= self.num_batches;
+            self.next_full_batch_index = (self.next_full_batch_index + 1) % self.num_batches;
         }
     }

     /// Increment the currently_processing_batch_index if the current state is BatchState::Full.
     pub fn increment_currently_processing_batch_index_if_full(&mut self, state: BatchState) {
         if state == BatchState::Full {
-            self.currently_processing_batch_index += 1;
-            self.currently_processing_batch_index %= self.num_batches;
+            self.currently_processing_batch_index =
+                (self.currently_processing_batch_index + 1) % self.num_batches;
         }
     }

@@ -97,12 +105,10 @@ impl BatchMetadata {
         batch_size: u64,
         zkp_batch_size: u64,
     ) -> Result<(), BatchedMerkleTreeError> {
+        // Check that batch size is divisible by zkp_batch_size.
+        Self::validate_batch_sizes(batch_size, zkp_batch_size)?;
         self.num_batches = num_batches;
         self.batch_size = batch_size;
-        // Check that batch size is divisible by zkp_batch_size.
-        if batch_size % zkp_batch_size != 0 {
-            return Err(BatchedMerkleTreeError::BatchSizeNotDivisibleByZkpBatchSize);
-        }
         self.zkp_batch_size = zkp_batch_size;
         Ok(())
     }
@@ -165,3 +171,14 @@ fn test_increment_currently_processing_batch_index_if_full() {
     metadata.increment_currently_processing_batch_index_if_full(BatchState::Inserted);
     assert_eq!(metadata.currently_processing_batch_index, 0);
 }
+
+#[test]
+fn test_batch_size_validation() {
+    // Invalid batch sizes: 10 is not divisible by the ZKP batch size 3.
+    assert!(BatchMetadata::new_input_queue(10, 10, 3, 2).is_err());
+    assert!(BatchMetadata::new_output_queue(10, 3, 2).is_err());
+
+    // Valid batch sizes: 9 is divisible by the ZKP batch size 3.
+    assert!(BatchMetadata::new_input_queue(9, 10, 3, 2).is_ok());
+    assert!(BatchMetadata::new_output_queue(9, 3, 2).is_ok());
+}
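Editor's note: the refactor above factors one divisibility invariant out of three constructors and rewrites the two wrap-around increments as single expressions. A standalone sketch, not part of the patch series, with the error enum reduced to the one variant used here:

// Editor's sketch of the consolidated validation and wrap-around increment.
#[derive(Debug, PartialEq)]
enum BatchedMerkleTreeError {
    BatchSizeNotDivisibleByZkpBatchSize,
}

/// Batches are processed in zkp_batch_size chunks, so the batch size must
/// divide evenly; otherwise the last chunk could never be completed.
fn validate_batch_sizes(
    batch_size: u64,
    zkp_batch_size: u64,
) -> Result<(), BatchedMerkleTreeError> {
    if batch_size % zkp_batch_size != 0 {
        return Err(BatchedMerkleTreeError::BatchSizeNotDivisibleByZkpBatchSize);
    }
    Ok(())
}

fn main() {
    assert!(validate_batch_sizes(10, 3).is_err()); // 10 % 3 != 0
    assert!(validate_batch_sizes(9, 3).is_ok()); // 9 % 3 == 0

    // The wrap-around increment: one expression instead of `+= 1` then `%=`.
    let num_batches = 2u64;
    let mut next_full_batch_index = 1u64;
    next_full_batch_index = (next_full_batch_index + 1) % num_batches;
    assert_eq!(next_full_batch_index, 0); // wraps back to batch 0
}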
From b28fcdc4ffe74b77e9b3b6056038cc23c5bb2c58 Mon Sep 17 00:00:00 2001
From: ananas-block
Date: Thu, 16 Jan 2025 23:20:57 +0000
Subject: [PATCH 6/6] chore: fix forester tests

---
 forester/tests/batched_address_test.rs | 7 +++++--
 forester/tests/batched_state_test.rs   | 6 ++++--
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/forester/tests/batched_address_test.rs b/forester/tests/batched_address_test.rs
index 03c18eb481..8b9f27f062 100644
--- a/forester/tests/batched_address_test.rs
+++ b/forester/tests/batched_address_test.rs
@@ -293,11 +293,14 @@ async fn test_address_batched() {
     let expected_sequence_number =
         initial_sequence_number + (num_zkp_batches * UPDATES_PER_BATCH);

-    let expected_root_history_len = (expected_sequence_number + 1) as usize;
+    let expected_last_root_index = expected_sequence_number as usize;

     assert_eq!(final_metadata.sequence_number, expected_sequence_number);

-    assert_eq!(merkle_tree.root_history.len(), expected_root_history_len);
+    assert_eq!(
+        merkle_tree.root_history.last_index(),
+        expected_last_root_index
+    );

     assert_ne!(
         pre_root,
diff --git a/forester/tests/batched_state_test.rs b/forester/tests/batched_state_test.rs
index d3e07275c6..6d781910f4 100644
--- a/forester/tests/batched_state_test.rs
+++ b/forester/tests/batched_state_test.rs
@@ -337,11 +337,13 @@ async fn test_state_batched() {
     let expected_sequence_number =
         initial_sequence_number + (num_zkp_batches * UPDATES_PER_BATCH);

-    let expected_root_history_len = (expected_sequence_number + 1) as usize;

     assert_eq!(final_metadata.sequence_number, expected_sequence_number);

-    assert_eq!(merkle_tree.root_history.len(), expected_root_history_len);
+    assert_eq!(
+        merkle_tree.root_history.last_index(),
+        expected_sequence_number as usize
+    );

     assert_ne!(
         pre_root,