Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: handling of dict relationships in args gen #534

Merged
merged 1 commit into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cairo/ethereum/cancun/state.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,7 @@ func close_transaction{
let transient_storage_tries_start = transient_storage_tries.value._data.value.dict_ptr_start;
let transient_storage_tries_end = transient_storage_tries.value._data.value.dict_ptr;
let parent_transient_storage_tries = transient_storage_tries.value._data.value.parent_dict;

with_attr error_message("IndexError") {
tempvar parent_transient_storage_tries_ptr = cast(parent_transient_storage_tries, felt);
if (cast(parent_transient_storage_tries_ptr, felt) == 0) {
Expand Down
20 changes: 9 additions & 11 deletions cairo/tests/ethereum/cancun/test_state.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
from typing import Optional

import pytest
Expand Down Expand Up @@ -36,7 +37,7 @@
set_transient_storage,
touch_account,
)
from ethereum.cancun.trie import Trie
from ethereum.cancun.trie import Trie, copy_trie
from tests.utils.args_gen import State, TransientStorage, Withdrawal
from tests.utils.errors import strict_raises
from tests.utils.strategies import (
Expand Down Expand Up @@ -91,7 +92,7 @@ def state_with_snapshots(draw):

# Start with base state's tries
current_main_trie = base_state._main_trie
current_storage_tries = base_state._storage_tries.copy()
current_storage_tries = copy.deepcopy(base_state._storage_tries)
snapshots = []

for _ in range(num_snapshots):
Expand All @@ -100,11 +101,8 @@ def state_with_snapshots(draw):
new_accounts = draw(
st.dictionaries(keys=address, values=st.from_type(Account), max_size=5)
)
main_trie_data = current_main_trie._data.copy()
main_trie_data.update(new_accounts)
main_trie = Trie[Address, Optional[Account]](
secured=True, default=None, _data=main_trie_data
)
main_trie_copy = copy_trie(current_main_trie)
main_trie_copy._data.update(new_accounts)

# Add up to 5 new storage tries or update existing ones
new_storage_tries = draw(
Expand All @@ -114,11 +112,11 @@ def state_with_snapshots(draw):
max_size=5,
)
)
storage_tries = current_storage_tries.copy()
storage_tries = copy.deepcopy(current_storage_tries)
storage_tries.update(new_storage_tries)

# Update current state for next iteration
current_main_trie = main_trie
current_main_trie = main_trie_copy
current_storage_tries = storage_tries

return State(
Expand All @@ -139,7 +137,7 @@ def transient_storage_with_snapshots(draw):
num_snapshots = draw(st.integers(min_value=0, max_value=5))

# Start with base transient storage tries
current_tries = base_transient_storage._tries.copy()
current_tries = copy.deepcopy(base_transient_storage._tries)
snapshots = []

for _ in range(num_snapshots):
Expand All @@ -152,7 +150,7 @@ def transient_storage_with_snapshots(draw):
max_size=5,
)
)
tries = current_tries.copy()
tries = copy.deepcopy(current_tries)
tries.update(new_tries)

# Update current tries for next iteration
Expand Down
12 changes: 11 additions & 1 deletion cairo/tests/utils/args_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,6 +803,7 @@ def generate_trie_arg(
# In case of a Trie, we need the dict to be a defaultdict with the trie.default as the default value.
dict_ptr = segments.memory.get(data)
current_ptr = segments.memory.get(data + 1)

if isinstance(dict_manager, DictManager):
dict_manager.trackers[dict_ptr.segment_index].data = defaultdict(
lambda: default, dict_manager.trackers[dict_ptr.segment_index].data
Expand Down Expand Up @@ -947,16 +948,25 @@ def generate_dict_arg(
# This is required for tests where we read data from DictAccess segments while no dict method has been used.
# Equivalent to doing an initial dict_read of all keys.
# We only hash keys if they're in tuples.

# In case of a dict update, we need to get the prev_value from the dict_tracker of the parent_ptr.
# For consistency purposes when we drop the dict and put its prev values back in the parent_ptr.
parent_dict_end_ptr = segments.memory.get(parent_ptr + 1) if parent_ptr else None
initial_data = flatten(
[
(
(poseidon_hash_many(k) if get_args(arg_type)[0] in HASHED_TYPES else k),
v,
(
dict_manager.get_tracker(parent_dict_end_ptr).data.get(k, v)
if parent_dict_end_ptr
else v
),
v,
)
for k, v in data.items()
]
)

segments.load_data(dict_ptr, initial_data)
current_ptr = dict_ptr + len(initial_data)

Expand Down
15 changes: 15 additions & 0 deletions crates/cairo-addons/src/vm/dict_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,21 @@ impl PyDictManager {
Ok(PyTrackerMapping { inner: self.inner.clone() })
}

fn get_tracker(&self, ptr: PyRelocatable) -> PyResult<PyDictTracker> {
self.inner
.borrow()
.trackers
.get(&ptr.inner.segment_index)
.cloned()
.map(|tracker| PyDictTracker { inner: tracker })
.ok_or_else(|| {
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
"segment_index {} not found",
ptr.inner.segment_index
))
})
}

fn insert(&mut self, segment_index: isize, value: &PyDictTracker) -> PyResult<()> {
if self.inner.borrow().trackers.contains_key(&segment_index) {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
Expand Down
2 changes: 1 addition & 1 deletion python/cairo-addons/src/cairo_addons/hints/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,4 +104,4 @@ def update_dict_tracker(
ap: RelocatableValue,
):
dict_tracker = dict_manager.get_tracker(ids.current_tracker_ptr)
dict_tracker.current_ptr = ids.new_tracker_ptr
dict_tracker.current_ptr = ids.new_tracker_ptr.address_
Loading