Skip to content

Commit

Permalink
mmr: added partial mmr
Browse files Browse the repository at this point in the history
  • Loading branch information
hackaugusto committed Oct 18, 2023
1 parent 78aa714 commit 67c7422
Show file tree
Hide file tree
Showing 9 changed files with 983 additions and 91 deletions.
33 changes: 27 additions & 6 deletions src/merkle/mmr/accumulator.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::{
super::{RpoDigest, Vec, ZERO},
Felt, MmrProof, Rpo256, Word,
Felt, MmrError, MmrProof, Rpo256, Word,
};

#[derive(Debug, Clone, PartialEq)]
Expand All @@ -9,9 +9,9 @@ pub struct MmrPeaks {
/// The number of leaves is used to differentiate accumulators that have the same number of
/// peaks. This happens because the number of peaks goes up-and-down as the structure is used
/// causing existing trees to be merged and new ones to be created. As an example, every time
/// the MMR has a power-of-two number of leaves there is a single peak.
/// the [Mmr] has a power-of-two number of leaves there is a single peak.
///
/// Every tree in the MMR forest has a distinct power-of-two size, this means only the right
/// Every tree in the [Mmr] forest has a distinct power-of-two size, this means only the right
/// most tree can have an odd number of elements (e.g. `1`). Additionally this means that the bits in
/// `num_leaves` conveniently encode the size of each individual tree.
///
Expand All @@ -23,16 +23,37 @@ pub struct MmrPeaks {
/// elements and the left most has `2**2`.
/// - With 12 leaves, the binary is `0b1100`, this case also has 2 peaks, the
/// leftmost tree has `2**3=8` elements, and the right most has `2**2=4` elements.
pub num_leaves: usize,
num_leaves: usize,

/// All the peaks of every tree in the MMR forest. The peaks are always ordered by number of
/// All the peaks of every tree in the [Mmr] forest. The peaks are always ordered by number of
/// leaves, starting from the peak with most children, to the one with least.
///
/// Invariant: The length of `peaks` must be equal to the number of true bits in `num_leaves`.
pub peaks: Vec<RpoDigest>,
peaks: Vec<RpoDigest>,
}

impl MmrPeaks {
pub fn new(num_leaves: usize, peaks: Vec<RpoDigest>) -> Result<Self, MmrError> {
if num_leaves.count_ones() as usize != peaks.len() {
return Err(MmrError::InvalidPeaks);
}

Ok(Self { num_leaves, peaks })
}

// ACESSORS
// --------------------------------------------------------------------------------------------

/// Returns a count of the [Mmr]'s leaves.
pub fn num_leaves(&self) -> usize {
self.num_leaves
}

/// Returns the current peaks of the [Mmr].
pub fn peaks(&self) -> &Vec<RpoDigest> {
&self.peaks
}

/// Hashes the peaks.
///
/// The procedure will:
Expand Down
35 changes: 35 additions & 0 deletions src/merkle/mmr/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
use crate::merkle::MerkleError;
use core::fmt::{Display, Formatter};

#[cfg(feature = "std")]
use std::error::Error;

#[derive(Debug, PartialEq, Eq, Clone)]
pub enum MmrError {
InvalidPosition(usize),
InvalidPeaks,
InvalidPeak,
InvalidUpdate,
UnknownPeak,
MerkleError(MerkleError),
}

impl Display for MmrError {
fn fmt(&self, fmt: &mut Formatter<'_>) -> Result<(), core::fmt::Error> {
match self {
MmrError::InvalidPosition(pos) => write!(fmt, "Mmr does not contain position {pos}"),
MmrError::InvalidPeaks => write!(fmt, "Invalid peaks count"),
MmrError::InvalidPeak => {
write!(fmt, "Peak values does not match merkle path computed root")
}
MmrError::InvalidUpdate => write!(fmt, "Invalid mmr update"),
MmrError::UnknownPeak => {
write!(fmt, "Peak not in Mmr")
}
MmrError::MerkleError(err) => write!(fmt, "{}", err),
}
}
}

#[cfg(feature = "std")]
impl Error for MmrError {}
158 changes: 110 additions & 48 deletions src/merkle/mmr/full.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,8 @@
use super::{
super::{InnerNodeInfo, MerklePath, RpoDigest, Vec},
bit::TrueBitPositionIterator,
MmrPeaks, MmrProof, Rpo256,
MmrError, MmrPeaks, MmrProof, Rpo256,
};
use core::fmt::{Display, Formatter};

#[cfg(feature = "std")]
use std::error::Error;

// MMR
// ===============================================================================================
Expand All @@ -43,22 +39,6 @@ pub struct Mmr {
pub(super) nodes: Vec<RpoDigest>,
}

#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub enum MmrError {
InvalidPosition(usize),
}

impl Display for MmrError {
fn fmt(&self, fmt: &mut Formatter<'_>) -> Result<(), core::fmt::Error> {
match self {
MmrError::InvalidPosition(pos) => write!(fmt, "Mmr does not contain position {pos}"),
}
}
}

#[cfg(feature = "std")]
impl Error for MmrError {}

impl Default for Mmr {
fn default() -> Self {
Self::new()
Expand Down Expand Up @@ -100,21 +80,16 @@ impl Mmr {
// find the target tree responsible for the MMR position
let tree_bit =
leaf_to_corresponding_tree(pos, self.forest).ok_or(MmrError::InvalidPosition(pos))?;
let forest_target = 1usize << tree_bit;

// isolate the trees before the target
let forest_before = self.forest & high_bitmask(tree_bit + 1);
let index_offset = nodes_in_forest(forest_before);

// find the root
let index = nodes_in_forest(forest_target) - 1;

// update the value position from global to the target tree
let relative_pos = pos - forest_before;

// collect the path and the final index of the target value
let (_, path) =
self.collect_merkle_path_and_value(tree_bit, relative_pos, index_offset, index);
let (_, path) = self.collect_merkle_path_and_value(tree_bit, relative_pos, index_offset);

Ok(MmrProof {
forest: self.forest,
Expand All @@ -132,21 +107,16 @@ impl Mmr {
// find the target tree responsible for the MMR position
let tree_bit =
leaf_to_corresponding_tree(pos, self.forest).ok_or(MmrError::InvalidPosition(pos))?;
let forest_target = 1usize << tree_bit;

// isolate the trees before the target
let forest_before = self.forest & high_bitmask(tree_bit + 1);
let index_offset = nodes_in_forest(forest_before);

// find the root
let index = nodes_in_forest(forest_target) - 1;

// update the value position from global to the target tree
let relative_pos = pos - forest_before;

// collect the path and the final index of the target value
let (value, _) =
self.collect_merkle_path_and_value(tree_bit, relative_pos, index_offset, index);
let (value, _) = self.collect_merkle_path_and_value(tree_bit, relative_pos, index_offset);

Ok(value)
}
Expand Down Expand Up @@ -185,7 +155,82 @@ impl Mmr {
.map(|offset| self.nodes[offset - 1])
.collect();

MmrPeaks { num_leaves: self.forest, peaks }
// Safety: the invariant is maintained by the [Mmr]
MmrPeaks::new(self.forest, peaks).unwrap()
}

/// Compute the required update to `original_forest`.
///
/// The result is a packed sequence of the authentication elements required to update the trees
/// that have been merged together, followed by the new peaks of the [Mmr].
pub fn updates(&self, original_forest: usize) -> Result<Vec<RpoDigest>, MmrError> {
if original_forest > self.forest {
return Err(MmrError::InvalidPeaks);
}

if original_forest == self.forest {
return Ok(Vec::new());
}

let mut result = Vec::new();

// Find the largest tree in this [Mmr] which is new to `original_forest`.
let candidate_trees = self.forest ^ original_forest;
let mut new_high = 1 << candidate_trees.ilog2();

// Collect authentication nodes used for to tree merges
// ----------------------------------------------------------------------------------------

// Find the trees from `original_forest` that have been merged into `new_high`.
let mut merges = original_forest & (new_high - 1);

// Find the peaks that are common to `original_forest` and this [Mmr]
let common_trees = original_forest ^ merges;

if merges != 0 {
// Skip the smallest trees unknown to `original_forest`.
let mut target = 1 << merges.trailing_zeros();

// Collect siblings required to computed the merged tree's peak
while target < new_high {
// Computes the offset to the smallest know peak
// - common_trees: peaks unchanged in the current update, target comes after these.
// - merges: peaks that have not been merged so far, target comes after these.
// - target: tree from which to load the sibling. On the first iteration this is a
// value known by the partial mmr, on subsequent iterations this value is to be
// computed from the known peaks and provided authentication nodes.
let known = nodes_in_forest(common_trees | merges | target);
let sibling = nodes_in_forest(target);
result.push(self.nodes[known + sibling - 1]);

// Update the target and account for tree merges
target <<= 1;
while merges & target != 0 {
target <<= 1;
}
// Remove the merges done so far
merges ^= merges & (target - 1);
}
} else {
// The new high tree may not be the result of any merges, if it is smaller than all the
// trees of `original_forest`.
new_high = 0;
}

// Collect the new [Mmr] peaks
// ----------------------------------------------------------------------------------------

let mut new_peaks = self.forest ^ common_trees ^ new_high;
let old_peaks = self.forest ^ new_peaks;
let mut offset = nodes_in_forest(old_peaks);
while new_peaks != 0 {
let target = 1 << new_peaks.ilog2();
offset += nodes_in_forest(target);
result.push(self.nodes[offset - 1]);
new_peaks ^= target;
}

Ok(result)
}

/// An iterator over inner nodes in the MMR. The order of iteration is unspecified.
Expand All @@ -202,36 +247,52 @@ impl Mmr {
// ============================================================================================

/// Internal function used to collect the Merkle path of a value.
///
/// The arguments are relative to the target tree. To compute the opening of the second leaf
/// for a tree with depth 2 in the forest `0b110`:
///
/// - `tree_bit`: Depth of the target tree, e.g. 2 for the smallest tree.
/// - `relative_pos`: 0-indexed leaf position in the target tree, e.g. 1 for the second leaf.
/// - `index_offset`: Node count prior to the target tree, e.g. 7 for the tree of depth 3.
fn collect_merkle_path_and_value(
&self,
tree_bit: u32,
relative_pos: usize,
index_offset: usize,
mut index: usize,
) -> (RpoDigest, Vec<RpoDigest>) {
// collect the Merkle path
let mut tree_depth = tree_bit as usize;
let mut path = Vec::with_capacity(tree_depth + 1);
while tree_depth > 0 {
let bit = relative_pos & tree_depth;
// see documentation of `leaf_to_corresponding_tree` for details
let tree_depth = (tree_bit + 1) as usize;
let mut path = Vec::with_capacity(tree_depth);

// The tree walk below goes from the root to the leaf, compute the root index to start
let mut forest_target = 1usize << tree_bit;
let mut index = nodes_in_forest(forest_target) - 1;

// Loop until the leaf is reached
while forest_target > 1 {
// Update the depth of the tree to correspond to a subtree
forest_target >>= 1;

// compute the indeces of the right and left subtrees based on the post-order
let right_offset = index - 1;
let left_offset = right_offset - nodes_in_forest(tree_depth);
let left_offset = right_offset - nodes_in_forest(forest_target);

// Elements to the right have a higher position because they were
// added later. Therefore when the bit is true the node's path is
// to the right, and its sibling to the left.
let sibling = if bit != 0 {
let left_or_right = relative_pos & forest_target;
let sibling = if left_or_right != 0 {
// going down the right subtree, the right child becomes the new root
index = right_offset;
// and the left child is the authentication
self.nodes[index_offset + left_offset]
} else {
index = left_offset;
self.nodes[index_offset + right_offset]
};

tree_depth >>= 1;
path.push(sibling);
}

debug_assert!(path.len() == tree_depth - 1);

// the rest of the codebase has the elements going from leaf to root, adjust it here for
// easy of use/consistency sake
path.reverse();
Expand Down Expand Up @@ -340,8 +401,9 @@ impl<'a> Iterator for MmrNodes<'a> {
///
/// Note:
/// The result is a tree position `p`, it has the following interpretations. $p+1$ is the depth of
/// the tree, which corresponds to the size of a Merkle proof for that tree. $2^p$ is equal to the
/// number of leaves in this particular tree. and $2^(p+1)-1$ corresponds to size of the tree.
/// the tree. Because the root element is not part of the proof, $p$ is the length of the
/// authentication path. $2^p$ is equal to the number of leaves in this particular tree. and
/// $2^(p+1)-1$ corresponds to size of the tree.
pub(crate) const fn leaf_to_corresponding_tree(pos: usize, forest: usize) -> Option<u32> {
if pos >= forest {
None
Expand Down
Loading

0 comments on commit 67c7422

Please sign in to comment.