diff --git a/readme.md b/readme.md index d47ea9f..b7ee180 100644 --- a/readme.md +++ b/readme.md @@ -67,7 +67,7 @@ let mut leaves = Vec::new(); for i in 0..num_leaves { leaves.push(i.to_string()); } -let tree = StandardMerkleTree::of(leaves.clone()); +let tree = StandardMerkleTree::of_sorted(leaves.clone()); for leaf in leaves.iter() { let proof = tree.get_proof(leaf); diff --git a/src/standard_binary_tree.rs b/src/standard_binary_tree.rs index 2a7f266..50e5128 100644 --- a/src/standard_binary_tree.rs +++ b/src/standard_binary_tree.rs @@ -13,7 +13,7 @@ //! for i in 0..num_leaves { //! leaves.push(DynSolValue::String(i.to_string())); //! } -//! let tree = StandardMerkleTree::of(&leaves); +//! let tree = StandardMerkleTree::of_sorted(&leaves); //! //! for leaf in leaves.iter() { //! let proof = tree.get_proof(leaf).unwrap(); @@ -73,15 +73,27 @@ impl StandardMerkleTree { Self { tree, tree_values } } - /// Constructs a [`StandardMerkleTree`] from a slice of dynamic Solidity values. pub fn of(values: &[DynSolValue]) -> Self { + Self::create(values, false) + } + + pub fn of_sorted(values: &[DynSolValue]) -> Self { + Self::create(values, true) + } + + /// Constructs a [`StandardMerkleTree`] from a slice of dynamic Solidity values. + fn create(values: &[DynSolValue], sort_leaves: bool) -> Self { // Hash each value and associate it with its index and leaf hash. - let hashed_values: Vec<(&DynSolValue, usize, B256)> = values + let mut hashed_values: Vec<(&DynSolValue, usize, B256)> = values .iter() .enumerate() .map(|(i, value)| (value, i, standard_leaf_hash(value))) .collect(); + // Sort the hashed values by their hash. + if sort_leaves { + hashed_values.sort_by(|(_, _, a), (_, _, b)| a.cmp(b)); + } // Collect the leaf hashes into a vector. let hashed_values_hash = hashed_values .iter() @@ -265,9 +277,11 @@ fn hash_pair(left: B256, right: B256) -> B256 { mod test { use crate::alloc::string::ToString; use crate::standard_binary_tree::StandardMerkleTree; + use alloc::string::String; + use alloc::vec; use alloc::vec::Vec; use alloy::dyn_abi::DynSolValue; - use alloy::primitives::{hex::FromHex, FixedBytes}; + use alloy::primitives::{address, hex, hex::FromHex, FixedBytes, U256}; /// Tests the [`StandardMerkleTree`] with string-type leaves. #[test] @@ -309,4 +323,66 @@ mod test { assert!(is_valid); } } + + /// Tests the [`StandardMerkleTree`] with a tuple leaves of hardhat addresses and amounts. + /// Equivalent to JS: const tree = StandardMerkleTree.of(values, ["address", "uint256"]); + #[test] + fn test_hardhat_tuples() { + let mut leaves = Vec::new(); + + vec![ + ( + address!("f39Fd6e51aad88F6F4ce6aB8827279cffFb92266"), + U256::from(10000), + ), + ( + address!("70997970C51812dc3A010C7d01b50e0d17dc79C8"), + U256::from(1000), + ), + ( + address!("3c44cdddb6a900fa2b585dd299e03d12fa4293bc"), + U256::from(100), + ), + ( + address!("90f79bf6eb2c4f870365e785982e1f101e93b906"), + U256::from(10), + ), + ( + address!("15d34aaf54267db7d7c367839aaf71a00a2c6a65"), + U256::from(1), + ), + ] + .iter() + .for_each(|(address, amount)| { + leaves.push(unsafe { + DynSolValue::String(String::from_utf8_unchecked( + DynSolValue::Tuple(vec![ + DynSolValue::Address(*address), + DynSolValue::Uint(*amount, 256), + ]) + .abi_encode(), + )) + }); + }); + + let tree = StandardMerkleTree::of_sorted(&leaves); + + let proof = tree.get_proof(leaves.first().unwrap()).unwrap(); + let is_valid = tree.verify_proof(leaves.first().unwrap(), proof.clone()); + assert!(is_valid); + assert_eq!( + proof, + vec![ + hex!("8ee56d16226ff6684927054c33cd505c4eee1ebabbffe198460d00cb083aaebd"), + hex!("fa31eb8d65ff2307b7026df667a06a19aade0151ed701ed2307295ae4fa48364"), + hex!("f0768f444c5a27a6bb7c9203b0b5b147e501ff7b7784e0363e5751590962b034"), + ] + ); + + let root = tree.root(); + assert_eq!( + root, + hex!("2b4b963c699c531f94ca8f8a0ef76c5d28f067d79927c035a44296190c2d8029") + ); + } }