Skip to content

Commit

Permalink
fix some hints
Browse files Browse the repository at this point in the history
  • Loading branch information
enitrat committed Jan 24, 2025
1 parent 6b29052 commit 1514230
Show file tree
Hide file tree
Showing 9 changed files with 96 additions and 22 deletions.
9 changes: 5 additions & 4 deletions cairo/src/utils/dict.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -274,10 +274,6 @@ func squash_and_update{range_check_ptr}(
return dst;
}

let dict_ptr = squashed_src_end;
let parent_dict_end = dst;
%{ merge_dict_tracker_with_parent %}

// Loop on all keys and write the new_value to the dst dict.
tempvar squashed_src = squashed_src_start;
tempvar dst_end = dst;
Expand Down Expand Up @@ -305,6 +301,11 @@ func squash_and_update{range_check_ptr}(
jmp loop;

done:
// Merge
let dict_ptr = squashed_src_end;
let parent_dict_end = dst;
%{ merge_dict_tracker_with_parent %}

let current_tracker_ptr = dst;
let new_tracker_ptr = cast([ap - 5], DictAccess*);
%{ update_dict_tracker %}
Expand Down
1 change: 1 addition & 0 deletions cairo/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def pytest_addoption(parser):
deadline=None,
max_examples=300,
phases=[Phase.explicit, Phase.reuse, Phase.generate, Phase.target],
print_blob=True,
derandomize=True,
)
settings.register_profile(
Expand Down
17 changes: 14 additions & 3 deletions cairo/tests/src/utils/test_dict.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@ from ethereum.cancun.state import (
ListTupleAddressBytes32,
ListTupleAddressBytes32Struct,
)
from src.utils.dict import prev_values, dict_update, get_keys_for_address_prefix, squash_and_update
from src.utils.dict import (
prev_values,
dict_update,
get_keys_for_address_prefix,
squash_and_update,
dict_squash,
)

func test_prev_values{range_check_ptr}() -> (prev_values_start_ptr: felt*) {
alloc_locals;
Expand Down Expand Up @@ -97,10 +103,15 @@ func test_squash_and_update{range_check_ptr}(
cast(src_start, DictAccess*), cast(src_end, DictAccess*), cast(dst, DictAccess*)
);

// Squash the dict another time to ensure that the update was done correctly
let (final_start, final_end) = dict_squash(
cast(dst_dict.value.dict_ptr_start, DictAccess*), new_dst_end
);

tempvar new_dst_dict = MappingTupleAddressBytes32U256(
new MappingTupleAddressBytes32U256Struct(
dict_ptr_start=cast(dst_dict.value.dict_ptr_start, TupleAddressBytes32U256DictAccess*),
dict_ptr=cast(new_dst_end, TupleAddressBytes32U256DictAccess*),
dict_ptr_start=cast(final_start, TupleAddressBytes32U256DictAccess*),
dict_ptr=cast(final_end, TupleAddressBytes32U256DictAccess*),
parent_dict=dst_dict.value.parent_dict,
),
);
Expand Down
7 changes: 6 additions & 1 deletion cairo/tests/src/utils/test_dict.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import defaultdict
from typing import List, Mapping, Tuple

from cairo_addons.hints.decorator import register_hint
Expand Down Expand Up @@ -121,6 +122,10 @@ def test_squash_and_update(
src_dict: Mapping[Tuple[Address, Bytes32], U256],
dst_dict: Mapping[Tuple[Address, Bytes32], U256],
):
new_dst_dict = cairo_run("test_squash_and_update", src_dict, dst_dict)
new_dst_dict = cairo_run(
"test_squash_and_update",
defaultdict(lambda: U256(0), src_dict),
defaultdict(lambda: U256(0), dst_dict),
)
dst_dict.update(src_dict)
assert new_dst_dict == dst_dict
52 changes: 50 additions & 2 deletions crates/cairo-addons/src/vm/hint_definitions/dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ use std::collections::HashMap;

use cairo_vm::{
hint_processor::{
builtin_hint_processor::hint_utils::{
get_ptr_from_var_name, insert_value_from_var_name, insert_value_into_ap,
builtin_hint_processor::{
dict_manager::DictTracker,
hint_utils::{get_ptr_from_var_name, insert_value_from_var_name, insert_value_into_ap},
},
hint_processor_definition::HintReference,
},
Expand All @@ -18,6 +19,7 @@ use crate::vm::hints::Hint;
pub const HINTS: &[fn() -> Hint] = &[
dict_new_empty,
dict_squash,
dict_copy,
copy_dict_segment,
merge_dict_tracker_with_parent,
update_dict_tracker,
Expand Down Expand Up @@ -67,6 +69,52 @@ pub fn dict_squash() -> Hint {
)
}

pub fn dict_copy() -> Hint {
Hint::new(
String::from("dict_copy"),
|vm: &mut VirtualMachine,
exec_scopes: &mut ExecutionScopes,
ids_data: &HashMap<String, HintReference>,
ap_tracking: &ApTracking,
_constants: &HashMap<String, Felt252>|
-> Result<(), HintError> {
// Get the new_start and dict_start pointers from ids
let new_start = get_ptr_from_var_name("new_start", vm, ids_data, ap_tracking)?;
let dict_start = get_ptr_from_var_name("dict_start", vm, ids_data, ap_tracking)?;
let new_end = get_ptr_from_var_name("new_end", vm, ids_data, ap_tracking)?;

let dict_manager_ref = exec_scopes.get_dict_manager()?;
let mut dict_manager = dict_manager_ref.borrow_mut();

// Check if new segment already exists in trackers
// Get and copy data from the source dictionary
let source_tracker = dict_manager.trackers.get(&dict_start.segment_index).ok_or(
HintError::CustomHint(Box::from(format!(
"Segment {} already exists in dict_manager.trackers",
new_start.segment_index
))),
)?;
let copied_data = source_tracker.get_dictionary_copy();
let default_value = source_tracker.get_default_value().cloned();

// Create new tracker with copied data
if let Some(default_value) = default_value {
dict_manager.trackers.insert(
new_end.segment_index,
DictTracker::new_default_dict(new_end, &default_value, Some(copied_data)),
);
} else {
dict_manager.trackers.insert(
new_end.segment_index,
DictTracker::new_with_initial(new_end, copied_data),
);
}

Ok(())
},
)
}

pub fn copy_dict_segment() -> Hint {
Hint::new(
String::from("copy_dict_segment"),
Expand Down
14 changes: 9 additions & 5 deletions crates/cairo-addons/src/vm/hint_definitions/hashdict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,10 +255,14 @@ pub fn hashdict_read_from_key() -> Hint {
// keys.
let simple_key = DictKey::Simple(hashed_key.into());
let preimage =
get_preimage_for_hashed_key(hashed_key, tracker).unwrap_or(&simple_key).clone();
_get_preimage_for_hashed_key(hashed_key, tracker).unwrap_or(&simple_key).clone();
let value = tracker
.get_value(&preimage)
.map_err(|_| HintError::CustomHint("No value found for preimage".into()))?
.map_err(|_| {
HintError::CustomHint(
format!("No value found for preimage {}", preimage).into(),
)
})?
.clone();

// Set the value
Expand Down Expand Up @@ -286,7 +290,7 @@ pub fn get_preimage_for_key() -> Hint {
let tracker = dict.get_tracker(dict_ptr)?;

// Find matching preimage
let preimage = get_preimage_for_hashed_key(hashed_key, tracker)?;
let preimage = _get_preimage_for_hashed_key(hashed_key, tracker)?;

// Write preimage data to memory
let preimage_data_ptr =
Expand Down Expand Up @@ -330,7 +334,7 @@ pub fn copy_hashdict_tracker_entry() -> Hint {

// Find matching preimage from source tracker data
let key_hash = get_integer_from_var_name("source_key", vm, ids_data, ap_tracking)?;
let preimage = get_preimage_for_hashed_key(key_hash, source_tracker)?.clone();
let preimage = _get_preimage_for_hashed_key(key_hash, source_tracker)?.clone();
let value = source_tracker
.get_value(&preimage)
.map_err(|_| {
Expand Down Expand Up @@ -367,7 +371,7 @@ fn build_compound_key(
}

/// Helper function to find a preimage in a tracker's dictionary given a hashed key
fn get_preimage_for_hashed_key(
fn _get_preimage_for_hashed_key(
hashed_key: Felt252,
tracker: &DictTracker,
) -> Result<&DictKey, HintError> {
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ dev-dependencies = [
"eth-account>=0.13.3",
"eth-keys>=0.5.1",
"eth-utils>=5.0.0",
"hypothesis>=6.123.17",
"hypothesis>=6.124.3",
"ipykernel>=6.29.5",
"pytest-xdist>=3.6.1",
"pytest>=8.3.3",
Expand Down
8 changes: 6 additions & 2 deletions python/cairo-addons/src/cairo_addons/hints/hashdict.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,12 @@ def hashdict_read_from_key(
from cairo_addons.hints.hashdict import _get_preimage_for_hashed_key

dict_tracker = dict_manager.get_tracker(ids.dict_ptr_stop)
preimage = _get_preimage_for_hashed_key(ids.key, dict_tracker) or ids.key
ids.value = dict_tracker.data[preimage]
try:
preimage = _get_preimage_for_hashed_key(ids.key, dict_tracker) or ids.key
except Exception:
ids.value = dict_tracker.data.default_factory()
else:
ids.value = dict_tracker.data[preimage]


@register_hint
Expand Down
8 changes: 4 additions & 4 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 1514230

Please sign in to comment.