Skip to content

Commit

Permalink
feat: trie copy (#385)
Browse files Browse the repository at this point in the history
Closes #361
  • Loading branch information
enitrat authored Jan 13, 2025
1 parent 6e0e8a4 commit 9cb4ba4
Show file tree
Hide file tree
Showing 8 changed files with 265 additions and 116 deletions.
6 changes: 6 additions & 0 deletions cairo/ethereum/cancun/fork_types.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ struct AddressAccountDictAccess {
struct MappingAddressAccountStruct {
dict_ptr_start: AddressAccountDictAccess*,
dict_ptr: AddressAccountDictAccess*,
// In case this is a copy of a previous dict,
// this field points to the address of the original mapping.
original_mapping: MappingAddressAccountStruct*,
}

struct MappingAddressAccount {
Expand All @@ -105,6 +108,9 @@ struct Bytes32U256DictAccess {
struct MappingBytes32U256Struct {
dict_ptr_start: Bytes32U256DictAccess*,
dict_ptr: Bytes32U256DictAccess*,
// In case this is a copy of a previous dict,
// this field points to the address of the original mapping.
original_mapping: MappingBytes32U256Struct*,
}

struct MappingBytes32U256 {
Expand Down
3 changes: 3 additions & 0 deletions cairo/ethereum/cancun/state.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ struct AddressTrieBytes32U256DictAccess {
struct MappingAddressTrieBytes32U256Struct {
dict_ptr_start: AddressTrieBytes32U256DictAccess*,
dict_ptr: AddressTrieBytes32U256DictAccess*,
// In case this is a copy of a previous dict,
// this field points to the address of the original mapping.
original_mapping: MappingAddressTrieBytes32U256Struct*,
}

struct MappingAddressTrieBytes32U256 {
Expand Down
103 changes: 81 additions & 22 deletions cairo/ethereum/cancun/trie.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -331,24 +331,67 @@ func encode_node{range_check_ptr, bitwise_ptr: BitwiseBuiltin*, keccak_ptr: Kecc
return encoded;
}

// func copy_trie(trie: Trie[K, V]) -> Trie[K, V] {
// // Implementation:
// // return Trie(trie.secured, trie.default, copy.copy(trie._data))
// }
// @notice Copies the trie to a new segment.
// @dev This function simply creates a new segment for the new dict and associates it with the
// dict_tracker of the source dict.
func copy_trieAddressAccount{range_check_ptr, trie: TrieAddressAccount}() -> TrieAddressAccount {
alloc_locals;
// TODO: soundness
// We need to ensure it is sound when finalizing that copy.
// The full design is:
// - We create a new segment for the new dict
// - We copy the python dict tracker and associate it with that new segment
// - When interacting with the copied trie, we use the new segment with the new dict_ptr
// - If the state reverts, then upon squashing that copy, we:
// - copy all the prev_keys in the new segment to the main segment (as if they were read from the new segment)
// - delete the new segment
// - This ensures that when squashing the main segment, we ensure that the data read in the new segment matched the data from the main segment.

local new_dict_ptr: AddressAccountDictAccess*;
tempvar original_mapping = trie.value._data.value;
%{
dict_tracker = __dict_manager.get_tracker(ids.original_mapping.dict_ptr)
copied_data = dict_tracker.data
ids.new_dict_ptr = __dict_manager.new_dict(segments, copied_data)
%}

// func trie_set(trie: Trie[K, V], key: K, value: V) {
// // Implementation:
// // if value == trie.default:
// // if key in trie._data:
// // del trie._data[key]
// // else:
// // trie._data[key] = value
// // if key in trie._data:
// // del trie._data[key]
// // del trie._data[key]
// // else:
// // trie._data[key] = value
// }
tempvar res = TrieAddressAccount(
new TrieAddressAccountStruct(
trie.value.secured,
trie.value.default,
MappingAddressAccount(
new MappingAddressAccountStruct(new_dict_ptr, new_dict_ptr, original_mapping)
),
),
);
return res;
}

func copy_trieBytes32U256{range_check_ptr, trie: TrieBytes32U256}() -> TrieBytes32U256 {
alloc_locals;
// TODO: same as above

local new_dict_ptr: Bytes32U256DictAccess*;
tempvar original_mapping = trie.value._data.value;
%{
from starkware.cairo.lang.vm.crypto import poseidon_hash_many
dict_tracker = __dict_manager.get_tracker(ids.original_mapping.dict_ptr)
copied_data = dict_tracker.data
ids.new_dict_ptr = __dict_manager.new_dict(segments, copied_data)
%}

tempvar res = TrieBytes32U256(
new TrieBytes32U256Struct(
trie.value.secured,
trie.value.default,
MappingBytes32U256(
new MappingBytes32U256Struct(new_dict_ptr, new_dict_ptr, original_mapping)
),
),
);
return res;
}

func trie_get_TrieAddressAccount{poseidon_ptr: PoseidonBuiltin*, trie: TrieAddressAccount}(
key: Address
Expand All @@ -363,8 +406,11 @@ func trie_get_TrieAddressAccount{poseidon_ptr: PoseidonBuiltin*, trie: TrieAddre
let (pointer) = hashdict_read(1, &key.value);
}
let new_dict_ptr = cast(dict_ptr, AddressAccountDictAccess*);
let original_mapping = trie.value._data.value.original_mapping;
tempvar mapping = MappingAddressAccount(
new MappingAddressAccountStruct(trie.value._data.value.dict_ptr_start, new_dict_ptr)
new MappingAddressAccountStruct(
trie.value._data.value.dict_ptr_start, new_dict_ptr, original_mapping
),
);
tempvar trie = TrieAddressAccount(
new TrieAddressAccountStruct(trie.value.secured, trie.value.default, mapping)
Expand All @@ -382,8 +428,11 @@ func trie_get_TrieBytes32U256{poseidon_ptr: PoseidonBuiltin*, trie: TrieBytes32U
let (pointer) = hashdict_read(2, cast(key.value, felt*));
}
let new_dict_ptr = cast(dict_ptr, Bytes32U256DictAccess*);
let original_mapping = trie.value._data.value.original_mapping;
tempvar mapping = MappingBytes32U256(
new MappingBytes32U256Struct(trie.value._data.value.dict_ptr_start, new_dict_ptr)
new MappingBytes32U256Struct(
trie.value._data.value.dict_ptr_start, new_dict_ptr, original_mapping
),
);
tempvar trie = TrieBytes32U256(
new TrieBytes32U256Struct(trie.value.secured, trie.value.default, mapping)
Expand Down Expand Up @@ -418,7 +467,11 @@ func trie_set_TrieAddressAccount{poseidon_ptr: PoseidonBuiltin*, trie: TrieAddre
}
let new_dict_ptr = cast(dict_ptr, AddressAccountDictAccess*);
tempvar mapping = MappingAddressAccount(
new MappingAddressAccountStruct(trie.value._data.value.dict_ptr_start, new_dict_ptr)
new MappingAddressAccountStruct(
trie.value._data.value.dict_ptr_start,
new_dict_ptr,
trie.value._data.value.original_mapping,
),
);
tempvar trie = TrieAddressAccount(
new TrieAddressAccountStruct(trie.value.secured, trie.value.default, mapping)
Expand Down Expand Up @@ -449,7 +502,11 @@ func trie_set_TrieBytes32U256{poseidon_ptr: PoseidonBuiltin*, trie: TrieBytes32U
}
let new_dict_ptr = cast(dict_ptr, Bytes32U256DictAccess*);
tempvar mapping = MappingBytes32U256(
new MappingBytes32U256Struct(trie.value._data.value.dict_ptr_start, new_dict_ptr)
new MappingBytes32U256Struct(
trie.value._data.value.dict_ptr_start,
new_dict_ptr,
trie.value._data.value.original_mapping,
),
);
tempvar trie = TrieBytes32U256(
new TrieBytes32U256Struct(trie.value.secured, trie.value.default, mapping)
Expand Down Expand Up @@ -777,7 +834,9 @@ func _get_branch_for_nibble_at_level{poseidon_ptr: PoseidonBuiltin*}(
obj.value.dict_ptr_start, dict_ptr_stop, branch_start, nibble, level, empty_value
);

tempvar result = MappingBytesBytes(new MappingBytesBytesStruct(branch_start, branch_ptr));
tempvar result = MappingBytesBytes(
new MappingBytesBytesStruct(branch_start, branch_ptr, cast(0, MappingBytesBytesStruct*))
);

return (result, value);
}
Expand Down
3 changes: 3 additions & 0 deletions cairo/ethereum_types/bytes.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ struct BytesBytesDictAccess {
struct MappingBytesBytesStruct {
dict_ptr_start: BytesBytesDictAccess*,
dict_ptr: BytesBytesDictAccess*,
// In case this is a copy of a previous dict,
// this field points to the address of the original mapping.
original_mapping: MappingBytesBytesStruct*,
}

struct MappingBytesBytes {
Expand Down
16 changes: 16 additions & 0 deletions cairo/tests/ethereum/cancun/test_trie.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Trie,
bytes_to_nibble_list,
common_prefix_length,
copy_trie,
encode_internal_node,
encode_node,
nibble_list_to_compact,
Expand Down Expand Up @@ -186,3 +187,18 @@ def test_trie_set_TrieBytes32U256(
cairo_trie = cairo_run("trie_set_TrieBytes32U256", trie, key, value)
trie_set(trie, key, value)
assert cairo_trie == trie

@given(trie=...)
def test_copy_trie_AddressAccount(
self, cairo_run, trie: Trie[Address, Optional[Account]]
):
[original_trie, copied_trie] = cairo_run("copy_trieAddressAccount", trie)
trie_copy_py = copy_trie(trie)
assert original_trie == trie
assert copied_trie == trie_copy_py

@given(trie=...)
def test_copy_trie_Bytes32U256(self, cairo_run, trie: Trie[Bytes32, U256]):
[original_trie, copied_trie] = cairo_run("copy_trieBytes32U256", trie)
copy_trie(trie)
assert original_trie == trie
11 changes: 10 additions & 1 deletion cairo/tests/utils/args_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,16 @@ def _gen_arg(
),
)
base = segments.add()
segments.load_data(base, [dict_ptr, current_ptr])

# The last element is the original_segment_stop pointer.
# Because this is a new dict, this is 0 (null ptr).
# This does not apply to stack and memory (hash_mode=False), in which case there's only 2 elements.
data_to_load = (
[dict_ptr, current_ptr, 0]
if (hash_mode is not False)
else [dict_ptr, current_ptr]
)
segments.load_data(base, data_to_load)
return base

if arg_type in (Union[int, RustRelocatable], Union[int, RelocatableValue]):
Expand Down
Loading

0 comments on commit 9cb4ba4

Please sign in to comment.