Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce trace processing for multiple storage tries #157

Closed
wants to merge 11 commits into from
4 changes: 2 additions & 2 deletions .github/CODEOWNERS
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
* @muursh @Nashtare
/evm_arithmetization/ @wborgeaud @muursh @Nashtare
* @muursh @Nashtare @cpubot
/evm_arithmetization/ @wborgeaud @muursh @Nashtare @cpubot
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]

### Changed
- Add a few QoL useability functions to the interface ([#169](https://github.com/0xPolygonZero/zk_evm/pull/169))

## [0.3.1] - 2024-04-22

Expand Down
44 changes: 26 additions & 18 deletions evm_arithmetization/src/cpu/kernel/asm/core/access_lists.asm
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,22 @@ global init_access_lists:
POP
%endmacro

// Multiply the ptr at the top of the stack by 2
// and abort if 2*ptr - @SEGMENT_ACCESSED_ADDRESSES >= @GLOBAL_METADATA_ACCESSED_ADDRESSES_LEN
// In this way ptr must be pointing to the begining of a node.
// Multiply the value at the top of the stack, denoted by ptr/2, by 2
// and abort if ptr/2 >= mem[@GLOBAL_METADATA_ACCESSED_ADDRESSES_LEN]/2
// In this way 2*ptr/2 must be pointing to the begining of a node.
%macro get_valid_addr_ptr
// stack: ptr
// stack: ptr/2
DUP1
// stack: ptr/2, ptr/2
%mload_global_metadata(@GLOBAL_METADATA_ACCESSED_ADDRESSES_LEN)
// @GLOBAL_METADATA_ACCESSED_ADDRESSES_LEN must be an even number because
// both @SEGMENT_ACCESSED_ADDRESSES and the unscaled access addresses list len
// must be even numbers
%div_const(2)
// stack: scaled_len/2, ptr/2, ptr/2
%assert_gt
%mul_const(2)
PUSH @SEGMENT_ACCESSED_ADDRESSES
DUP2
SUB
%assert_lt_const(@GLOBAL_METADATA_ACCESSED_ADDRESSES_LEN)
// stack: 2*ptr
// stack: ptr
%endmacro


Expand Down Expand Up @@ -205,17 +210,20 @@ global remove_accessed_addresses:
// stack: cold_access, value_ptr
%endmacro

// Multiply the ptr at the top of the stack by 4
// and abort if 4*ptr - SEGMENT_ACCESSED_STORAGE_KEYS >= @GLOBAL_METADATA_ACCESSED_STORAGE_KEYS_LEN
// In this way ptr must be pointing to the beginning of a node.
// Multiply the ptr at the top of the stack, denoted by ptr/4, by 4
// and abort if ptr/4 >= @GLOBAL_METADATA_ACCESSED_STORAGE_KEYS_LEN/4
// In this way 4*ptr/4 be pointing to the beginning of a node.
%macro get_valid_storage_ptr
// stack: ptr
// stack: ptr/4
DUP1
%mload_global_metadata(@GLOBAL_METADATA_ACCESSED_STORAGE_KEYS_LEN)
// By construction, both @SEGMENT_ACCESSED_STORAGE_KEYS and the unscaled list len
// must be multiples of 4
%div_const(4)
// stack: scaled_len/4, ptr/4, ptr/4
%assert_gt
%mul_const(4)
PUSH @SEGMENT_ACCESSED_STORAGE_KEYS
DUP2
SUB
%assert_lt_const(@GLOBAL_METADATA_ACCESSED_STORAGE_KEYS_LEN)
// stack: 2*ptr
// stack: ptr
%endmacro

/// Inserts the storage key into the access list if it is not already present.
Expand Down
10 changes: 4 additions & 6 deletions evm_arithmetization/src/cpu/kernel/asm/util/assertions.asm
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ global panic:
%endmacro

%macro assert_lt(ret)
GE
%assert_zero($ret)
LT
%assert_nonzero($ret)
%endmacro

%macro assert_le
Expand All @@ -56,10 +56,8 @@ global panic:
%endmacro

%macro assert_gt
// %assert_zero is cheaper than %assert_nonzero, so we will leverage the
// fact that (x > y) == !(x <= y).
LE
%assert_zero
GT
%assert_nonzero
%endmacro

%macro assert_gt(ret)
Expand Down
58 changes: 57 additions & 1 deletion mpt_trie/src/nibbles.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ pub enum BytesToNibblesError {
TooManyBytes(usize),
}

#[derive(Debug, Error)]
#[derive(Clone, Debug, Error)]
/// Errors encountered when converting to hex prefix encoding to nibbles.
pub enum FromHexPrefixError {
#[error("Tried to convert a hex prefix byte string into `Nibbles` with invalid flags at the start: {0:#04b}")]
Expand Down Expand Up @@ -118,6 +118,7 @@ macro_rules! impl_as_u64s_for_primitive {
};
}

impl_as_u64s_for_primitive!(usize);
impl_as_u64s_for_primitive!(u8);
impl_as_u64s_for_primitive!(u16);
impl_as_u64s_for_primitive!(u32);
Expand Down Expand Up @@ -178,6 +179,7 @@ macro_rules! impl_to_nibbles {
};
}

impl_to_nibbles!(usize);
impl_to_nibbles!(u8);
impl_to_nibbles!(u16);
impl_to_nibbles!(u32);
Expand Down Expand Up @@ -908,6 +910,23 @@ impl Nibbles {
}
}

/// Returns a slice of the internal bytes of packed nibbles.
/// Only the relevant bytes (up to `count` nibbles) are considered valid.
pub fn as_byte_slice(&self) -> &[u8] {
// Calculate the number of full bytes needed to cover 'count' nibbles
let bytes_needed = (self.count + 1) / 2; // each nibble is half a byte

// Safe because we are ensuring the slice size does not exceed the bounds of the
// array
unsafe {
// Convert the pointer to `packed` to a pointer to `u8`
let packed_ptr = self.packed.0.as_ptr() as *const u8;

// Create a slice from this pointer and the number of needed bytes
std::slice::from_raw_parts(packed_ptr, bytes_needed)
}
}

const fn nibble_append_safety_asserts(&self, n: Nibble) {
assert!(
self.count < 64,
Expand Down Expand Up @@ -1616,6 +1635,12 @@ mod tests {
format!("{:x}", 0x1234_u64.to_nibbles_byte_padded()),
"0x1234"
);

assert_eq!(format!("{:x}", 0x1234_usize.to_nibbles()), "0x1234");
assert_eq!(
format!("{:x}", 0x1234_usize.to_nibbles_byte_padded()),
"0x1234"
);
}

#[test]
Expand All @@ -1627,4 +1652,35 @@ mod tests {

Nibbles::from_hex_prefix_encoding(&buf).unwrap();
}

#[test]
fn nibbles_as_byte_slice_works() -> Result<(), StrToNibblesError> {
let cases = [
(0x0, vec![]),
(0x1, vec![0x01]),
(0x12, vec![0x12]),
(0x123, vec![0x23, 0x01]),
];

for case in cases.iter() {
let nibbles = Nibbles::from(case.0 as u64);
let byte_vec = nibbles.as_byte_slice().to_vec();
assert_eq!(byte_vec, case.1.clone(), "Failed for input 0x{:X}", case.0);
}

let input = "3ab76c381c0f8ea617ea96780ffd1e165c754b28a41a95922f9f70682c581351";
let nibbles = Nibbles::from_str(input)?;

let byte_vec = nibbles.as_byte_slice().to_vec();
let mut expected_vec: Vec<u8> = hex::decode(input).expect("Invalid hex string");
expected_vec.reverse();
assert_eq!(
byte_vec,
expected_vec.clone(),
"Failed for input 0x{}",
input
);

Ok(())
}
}
19 changes: 19 additions & 0 deletions mpt_trie/src/partial_trie.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,11 @@ pub trait PartialTrie:
/// Returns an iterator over the trie that returns all values for every
/// `Leaf` and `Hash` node.
fn values(&self) -> impl Iterator<Item = ValOrHash>;

/// Returns `true` if the trie contains an element with the given key.
fn contains<K>(&self, k: K) -> bool
where
K: Into<Nibbles>;
}

/// Part of the trait that is not really part of the public interface but
Expand Down Expand Up @@ -261,6 +266,13 @@ impl PartialTrie for StandardTrie {
fn values(&self) -> impl Iterator<Item = ValOrHash> {
self.0.trie_values()
}

fn contains<K>(&self, k: K) -> bool
where
K: Into<Nibbles>,
{
self.0.trie_has_item_by_key(k)
}
}

impl TrieNodeIntern for StandardTrie {
Expand Down Expand Up @@ -381,6 +393,13 @@ impl PartialTrie for HashedPartialTrie {
fn values(&self) -> impl Iterator<Item = ValOrHash> {
self.node.trie_values()
}

fn contains<K>(&self, k: K) -> bool
where
K: Into<Nibbles>,
{
self.node.trie_has_item_by_key(k)
}
}

impl TrieNodeIntern for HashedPartialTrie {
Expand Down
34 changes: 32 additions & 2 deletions mpt_trie/src/trie_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use crate::{
pub type TrieOpResult<T> = Result<T, TrieOpError>;

/// An error type for trie operation.
#[derive(Debug, Error)]
#[derive(Clone, Debug, Error)]
pub enum TrieOpError {
/// An error that occurs when a hash node is found during an insert
/// operation.
Expand Down Expand Up @@ -364,7 +364,7 @@ impl<T: PartialTrie> Node<T> {
where
K: Into<Nibbles>,
{
let k = k.into();
let k: Nibbles = k.into();
trace!("Deleting a leaf node with key {} if it exists", k);

delete_intern(&self.clone(), k)?.map_or(Ok(None), |(updated_root, deleted_val)| {
Expand All @@ -391,6 +391,14 @@ impl<T: PartialTrie> Node<T> {
pub(crate) fn trie_values(&self) -> impl Iterator<Item = ValOrHash> {
self.trie_items().map(|(_, v)| v)
}

pub(crate) fn trie_has_item_by_key<K>(&self, k: K) -> bool
where
K: Into<Nibbles>,
{
let k = k.into();
self.trie_items().any(|(key, _)| key == k)
}
}

fn insert_into_trie_rec<N: PartialTrie>(
Expand Down Expand Up @@ -1105,6 +1113,28 @@ mod tests {
Ok(())
}

#[test]
fn existent_node_key_contains_returns_true() -> TrieOpResult<()> {
common_setup();

let mut trie = StandardTrie::default();
trie.insert(0x1234, vec![91])?;
assert!(trie.contains(0x1234));

Ok(())
}

#[test]
fn non_existent_node_key_contains_returns_false() -> TrieOpResult<()> {
common_setup();

let mut trie = StandardTrie::default();
trie.insert(0x1234, vec![91])?;
assert!(!trie.contains(0x5678));

Ok(())
}

#[test]
fn deleting_from_an_empty_trie_returns_none() -> TrieOpResult<()> {
common_setup();
Expand Down
18 changes: 18 additions & 0 deletions mpt_trie/src/trie_subsets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -805,6 +805,24 @@ mod tests {
Ok(())
}

#[test]
fn sub_trie_existent_key_contains_returns_true() {
let trie = create_trie_with_large_entry_nodes(&[0x0]).unwrap();

let partial_trie = create_trie_subset(&trie, [0x1234]).unwrap();

assert!(partial_trie.contains(0x0));
}

#[test]
fn sub_trie_non_existent_key_contains_returns_false() {
let trie = create_trie_with_large_entry_nodes(&[0x0]).unwrap();

let partial_trie = create_trie_subset(&trie, [0x1234]).unwrap();

assert!(!partial_trie.contains(0x1));
}

fn assert_all_keys_do_not_exist(trie: &TrieType, ks: impl Iterator<Item = Nibbles>) {
for k in ks {
assert!(trie.get(k).is_none());
Expand Down
4 changes: 3 additions & 1 deletion trace_decoder/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ thiserror = { workspace = true }

# Local dependencies
mpt_trie = { version = "0.2.1", path = "../mpt_trie" }
evm_arithmetization = { version = "0.1.3", path = "../evm_arithmetization" }

evm_arithmetization_mpt = { package = "evm_arithmetization", version = "0.1.3", path = "../evm_arithmetization" }
evm_arithmetization_smt = { package = "evm_arithmetization", version = "0.1.2", git = "https://github.com/0xPolygonZero/zk_evm.git", branch = "feat_type2_hack" }

[dev-dependencies]
pretty_env_logger = "0.5.0"
48 changes: 48 additions & 0 deletions trace_decoder/src/aliased_crate_types.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
//! This library links against two versions of the dependencies in `zk_evm`, and
//! this module handles the logic to work with both at once.
//!
//! Currently (this may change in the future), but SMT support is currently on
//! its own separate branch in `zk_evm`. We want to be able to support both
//! `MPT` (on the `main` branch) and SMT (on the `feat/type2` branch) in a
//! single binary. Because `feat/type2` modifies existing types that `main`
//! uses, we can not just simply use imports from both branches at the same
//! time. Instead, we need to make each version of the packages their own
//! separate dependency. This module just aliases the types to make them a bit
//! more readable, while also making it easier to merge the libraries together
//! later if the `feat/type2` eventually gets merged back into main.

macro_rules! create_aliased_type {
($alias_name:ident, $path:path) => {
pub(crate) type $alias_name = $path;
};
}

// MPT imports
create_aliased_type!(
MptAccountRlp,
evm_arithmetization_mpt::generation::mpt::AccountRlp
);
create_aliased_type!(MptBlockHashes, evm_arithmetization_mpt::proof::BlockHashes);
create_aliased_type!(
MptBlockMetadata,
evm_arithmetization_mpt::proof::BlockMetadata
);
create_aliased_type!(
MptExtraBlockData,
evm_arithmetization_mpt::proof::ExtraBlockData
);
create_aliased_type!(
MptGenerationInputs,
evm_arithmetization_mpt::generation::GenerationInputs
);
create_aliased_type!(
MptTrieInputs,
evm_arithmetization_mpt::generation::TrieInputs
);
create_aliased_type!(MptTrieRoots, evm_arithmetization_mpt::proof::TrieRoots);

// SMT imports
create_aliased_type!(
SmtGenerationInputs,
evm_arithmetization_smt::generation::GenerationInputs
);
Loading