Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for hashmaps in Smt and SimpleSmt #363

Merged
merged 17 commits into from
Jan 2, 2025
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
- Fixed a bug in the implementation of `draw_integers` for `RpoRandomCoin` (#343).
- [BREAKING] Refactor error messages and use `thiserror` to derive errors (#344).
- [BREAKING] Updated Winterfell dependency to v0.11 (#346).
- Added support for hashmaps in `Smt` and `SimpleSmt` which gives up to 10x boost in some operations (#363).


## 0.12.0 (2024-10-30)
Expand Down
31 changes: 31 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ harness = false
concurrent = ["dep:rayon"]
default = ["std", "concurrent"]
executable = ["dep:clap", "dep:rand-utils", "std"]
smt_hashmaps = []
bobbinth marked this conversation as resolved.
Show resolved Hide resolved
internal = []
serde = ["dep:serde", "serde?/alloc", "winter-math/serde"]
std = [
Expand All @@ -63,6 +64,7 @@ std = [
[dependencies]
blake3 = { version = "1.5", default-features = false }
clap = { version = "4.5", optional = true, features = ["derive"] }
hashbrown = { version = "0.15", features = ["serde"] }
num = { version = "0.4", default-features = false, features = ["alloc", "libm"] }
num-complex = { version = "0.4", default-features = false }
rand = { version = "0.8", default-features = false }
Expand Down
14 changes: 13 additions & 1 deletion src/hash/rescue/rpo/digest.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
use alloc::string::String;
use core::{cmp::Ordering, fmt::Display, ops::Deref, slice};
use core::{
cmp::Ordering,
fmt::Display,
hash::{Hash, Hasher},
ops::Deref,
slice,
};

use thiserror::Error;

Expand Down Expand Up @@ -55,6 +61,12 @@ impl RpoDigest {
}
}

impl Hash for RpoDigest {
fn hash<H: Hasher>(&self, state: &mut H) {
state.write(&self.as_bytes());
}
}

impl Digest for RpoDigest {
fn as_bytes(&self) -> [u8; DIGEST_BYTES] {
let mut result = [0; DIGEST_BYTES];
Expand Down
1 change: 1 addition & 0 deletions src/merkle/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use super::RpoDigest;
/// Representation of a node with two children used for iterating over containers.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[cfg_attr(test, derive(PartialOrd, Ord))]
pub struct InnerNodeInfo {
pub value: RpoDigest,
pub left: RpoDigest,
Expand Down
30 changes: 12 additions & 18 deletions src/merkle/smt/full/mod.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
use alloc::{
collections::{BTreeMap, BTreeSet},
string::ToString,
vec::Vec,
};
use alloc::{collections::BTreeSet, string::ToString, vec::Vec};

use super::{
EmptySubtreeRoots, Felt, InnerNode, InnerNodeInfo, LeafIndex, MerkleError, MerklePath,
MutationSet, NodeIndex, Rpo256, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD,
EmptySubtreeRoots, Felt, InnerNode, InnerNodeInfo, InnerNodes, LeafIndex, MerkleError,
MerklePath, MutationSet, NodeIndex, Rpo256, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD,
};

mod error;
Expand All @@ -30,6 +26,8 @@ pub const SMT_DEPTH: u8 = 64;
// SMT
// ================================================================================================

type Leaves = super::Leaves<SmtLeaf>;

/// Sparse Merkle tree mapping 256-bit keys to 256-bit values. Both keys and values are represented
/// by 4 field elements.
///
Expand All @@ -43,8 +41,8 @@ pub const SMT_DEPTH: u8 = 64;
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct Smt {
root: RpoDigest,
leaves: BTreeMap<u64, SmtLeaf>,
inner_nodes: BTreeMap<NodeIndex, InnerNode>,
inner_nodes: InnerNodes,
leaves: Leaves,
}

impl Smt {
Expand All @@ -64,8 +62,8 @@ impl Smt {

Self {
root,
leaves: BTreeMap::new(),
inner_nodes: BTreeMap::new(),
inner_nodes: Default::default(),
leaves: Default::default(),
}
}

Expand Down Expand Up @@ -148,11 +146,7 @@ impl Smt {
/// # Panics
/// With debug assertions on, this function panics if `root` does not match the root node in
/// `inner_nodes`.
pub fn from_raw_parts(
inner_nodes: BTreeMap<NodeIndex, InnerNode>,
leaves: BTreeMap<u64, SmtLeaf>,
root: RpoDigest,
) -> Self {
pub fn from_raw_parts(inner_nodes: InnerNodes, leaves: Leaves, root: RpoDigest) -> Self {
// Our particular implementation of `from_raw_parts()` never returns `Err`.
<Self as SparseMerkleTree<SMT_DEPTH>>::from_raw_parts(inner_nodes, leaves, root).unwrap()
}
Expand Down Expand Up @@ -339,8 +333,8 @@ impl SparseMerkleTree<SMT_DEPTH> for Smt {
const EMPTY_ROOT: RpoDigest = *EmptySubtreeRoots::entry(SMT_DEPTH, 0);

fn from_raw_parts(
inner_nodes: BTreeMap<NodeIndex, InnerNode>,
leaves: BTreeMap<u64, SmtLeaf>,
inner_nodes: InnerNodes,
leaves: Leaves,
root: RpoDigest,
) -> Result<Self, MerkleError> {
if cfg!(debug_assertions) {
Expand Down
30 changes: 15 additions & 15 deletions src/merkle/smt/full/tests.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use alloc::{collections::BTreeMap, vec::Vec};
use alloc::vec::Vec;

use super::{Felt, LeafIndex, NodeIndex, Rpo256, RpoDigest, Smt, SmtLeaf, EMPTY_WORD, SMT_DEPTH};
use crate::{
merkle::{
smt::{NodeMutation, SparseMerkleTree},
smt::{types::UnorderedMap, NodeMutation, SparseMerkleTree},
EmptySubtreeRoots, MerkleStore, MutationSet,
},
utils::{Deserializable, Serializable},
Expand Down Expand Up @@ -420,7 +420,7 @@ fn test_prospective_insertion() {
assert_eq!(revert.root(), root_empty, "reverse mutations new root did not match");
assert_eq!(
revert.new_pairs,
BTreeMap::from_iter([(key_1, EMPTY_WORD)]),
UnorderedMap::from_iter([(key_1, EMPTY_WORD)]),
"reverse mutations pairs did not match"
);
assert_eq!(
Expand All @@ -440,7 +440,7 @@ fn test_prospective_insertion() {
assert_eq!(revert.root(), old_root, "reverse mutations new root did not match");
assert_eq!(
revert.new_pairs,
BTreeMap::from_iter([(key_2, EMPTY_WORD), (key_3, EMPTY_WORD)]),
UnorderedMap::from_iter([(key_2, EMPTY_WORD), (key_3, EMPTY_WORD)]),
"reverse mutations pairs did not match"
);

Expand All @@ -454,7 +454,7 @@ fn test_prospective_insertion() {
assert_eq!(revert.root(), old_root, "reverse mutations new root did not match");
assert_eq!(
revert.new_pairs,
BTreeMap::from_iter([(key_3, value_3)]),
UnorderedMap::from_iter([(key_3, value_3)]),
"reverse mutations pairs did not match"
);

Expand All @@ -474,7 +474,7 @@ fn test_prospective_insertion() {
assert_eq!(revert.root(), old_root, "reverse mutations new root did not match");
assert_eq!(
revert.new_pairs,
BTreeMap::from_iter([(key_1, value_1), (key_2, value_2), (key_3, value_3)]),
UnorderedMap::from_iter([(key_1, value_1), (key_2, value_2), (key_3, value_3)]),
"reverse mutations pairs did not match"
);

Expand Down Expand Up @@ -603,21 +603,21 @@ fn test_smt_get_value() {
/// Tests that `entries()` works as expected
#[test]
fn test_smt_entries() {
let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, ONE]);
let key_2: RpoDigest = RpoDigest::from([2_u32, 2_u32, 2_u32, 2_u32]);
let key_1 = RpoDigest::from([ONE, ONE, ONE, ONE]);
let key_2 = RpoDigest::from([2_u32, 2_u32, 2_u32, 2_u32]);

let value_1 = [ONE; WORD_SIZE];
let value_2 = [2_u32.into(); WORD_SIZE];
let entries = [(key_1, value_1), (key_2, value_2)];

let smt = Smt::with_entries([(key_1, value_1), (key_2, value_2)]).unwrap();
let smt = Smt::with_entries(entries).unwrap();

let mut entries = smt.entries();
let mut expected = Vec::from_iter(entries);
expected.sort_by_key(|(k, _)| *k);
let mut actual: Vec<_> = smt.entries().cloned().collect();
actual.sort_by_key(|(k, _)| *k);

// Note: for simplicity, we assume the order `(k1,v1), (k2,v2)`. If a new implementation
// switches the order, it is OK to modify the order here as well.
assert_eq!(&(key_1, value_1), entries.next().unwrap());
assert_eq!(&(key_2, value_2), entries.next().unwrap());
assert!(entries.next().is_none());
assert_eq!(actual, expected);
}

/// Tests that `EMPTY_ROOT` constant generated in the `Smt` equals to the root of the empty tree of
Expand Down
Loading
Loading