From aa5f331f9bcf6be204c28389593b9896a193427e Mon Sep 17 00:00:00 2001 From: enitrat Date: Fri, 10 Jan 2025 19:45:30 +0100 Subject: [PATCH] feat: make rust API more consistent with Cairo and accept compound dict keys --- Cargo.lock | 2 +- Cargo.toml | 2 +- .../tests/ethereum/cancun/test_fork_types.py | 4 - .../ethereum/cancun/utils/test_address.py | 3 - .../cancun/vm/instructions/test_arithmetic.py | 2 - .../cancun/vm/instructions/test_bitwise.py | 2 - .../cancun/vm/instructions/test_block.py | 2 - .../cancun/vm/instructions/test_comparison.py | 2 - .../vm/instructions/test_control_flow.py | 2 - .../cancun/vm/instructions/test_keccak.py | 2 - .../cancun/vm/instructions/test_log.py | 2 - .../instructions/test_memory_instructions.py | 2 - cairo/tests/ethereum/cancun/vm/test_memory.py | 4 - .../tests/ethereum/cancun/vm/test_runtime.py | 3 - cairo/tests/ethereum/cancun/vm/test_stack.py | 2 - cairo/tests/ethereum/crypto/test_hash.py | 3 - cairo/tests/ethereum/test_rlp.py | 2 - cairo/tests/src/test_stack.py | 5 - cairo/tests/test_serde.py | 3 - cairo/tests/utils/args_gen.py | 40 +++--- cairo/tests/utils/serde.py | 4 +- crates/cairo-addons/src/vm/dict_manager.rs | 131 ++++++++++++++++-- crates/cairo-addons/src/vm/felt.rs | 2 +- .../cairo-addons/src/vm/maybe_relocatable.rs | 2 +- crates/cairo-addons/src/vm/memory_segments.rs | 22 ++- crates/cairo-addons/src/vm/relocatable.rs | 2 +- .../cairo-addons/tests/test_dict_manager.py | 39 ++++-- .../tests/test_memory_segments.py | 5 + 28 files changed, 198 insertions(+), 98 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3513df73..6acf95a0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -145,7 +145,7 @@ checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b" [[package]] name = "cairo-vm" version = "2.0.0-rc3" -source = "git+https://github.com/ClementWalter/cairo-rs#5f29188c8570470bbebce7bf82697a6d6a5e0724" +source = "git+https://github.com/kkrt-labs/cairo-vm?rev=dd3d3f6e76248fc02395b31cbefe5fe8183222f1#dd3d3f6e76248fc02395b31cbefe5fe8183222f1" dependencies = [ "anyhow", "arbitrary", diff --git a/Cargo.toml b/Cargo.toml index 2d87b0d7..7d9ac5d6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -59,4 +59,4 @@ cairo-vm = { git = "https://github.com/lambdaclass/cairo-vm.git", tag = "v2.0.0- ] } [patch."https://github.com/lambdaclass/cairo-vm.git"] -cairo-vm = { git = "https://github.com/ClementWalter/cairo-rs" } +cairo-vm = { git = "https://github.com/kkrt-labs/cairo-vm", rev = "dd3d3f6e76248fc02395b31cbefe5fe8183222f1" } diff --git a/cairo/tests/ethereum/cancun/test_fork_types.py b/cairo/tests/ethereum/cancun/test_fork_types.py index 3a92ae88..be33a3de 100644 --- a/cairo/tests/ethereum/cancun/test_fork_types.py +++ b/cairo/tests/ethereum/cancun/test_fork_types.py @@ -1,9 +1,5 @@ -import pytest - from ethereum.cancun.fork_types import EMPTY_ACCOUNT -pytestmark = pytest.mark.python_vm - class TestForkTypes: def test_account_default(self, cairo_run): diff --git a/cairo/tests/ethereum/cancun/utils/test_address.py b/cairo/tests/ethereum/cancun/utils/test_address.py index 56f9a016..9d151b0e 100644 --- a/cairo/tests/ethereum/cancun/utils/test_address.py +++ b/cairo/tests/ethereum/cancun/utils/test_address.py @@ -1,6 +1,5 @@ from typing import Union -import pytest from ethereum_types.bytes import Bytes32 from ethereum_types.numeric import U256, Uint from hypothesis import given @@ -12,8 +11,6 @@ to_address, ) -pytestmark = pytest.mark.python_vm - class TestAddress: diff --git a/cairo/tests/ethereum/cancun/vm/instructions/test_arithmetic.py b/cairo/tests/ethereum/cancun/vm/instructions/test_arithmetic.py index 2826512a..a62ffe87 100644 --- a/cairo/tests/ethereum/cancun/vm/instructions/test_arithmetic.py +++ b/cairo/tests/ethereum/cancun/vm/instructions/test_arithmetic.py @@ -17,8 +17,6 @@ from tests.utils.args_gen import Evm from tests.utils.strategies import evm_lite -pytestmark = pytest.mark.python_vm - class TestArithmetic: @given(evm=evm_lite) diff --git a/cairo/tests/ethereum/cancun/vm/instructions/test_bitwise.py b/cairo/tests/ethereum/cancun/vm/instructions/test_bitwise.py index bbe75c7f..4ce6dbae 100644 --- a/cairo/tests/ethereum/cancun/vm/instructions/test_bitwise.py +++ b/cairo/tests/ethereum/cancun/vm/instructions/test_bitwise.py @@ -15,8 +15,6 @@ from tests.utils.args_gen import Evm from tests.utils.strategies import evm_lite -pytestmark = pytest.mark.python_vm - class TestBitwise: @given(evm=evm_lite) diff --git a/cairo/tests/ethereum/cancun/vm/instructions/test_block.py b/cairo/tests/ethereum/cancun/vm/instructions/test_block.py index 34d799fa..5d70bec4 100644 --- a/cairo/tests/ethereum/cancun/vm/instructions/test_block.py +++ b/cairo/tests/ethereum/cancun/vm/instructions/test_block.py @@ -14,8 +14,6 @@ from tests.utils.args_gen import Evm from tests.utils.strategies import evm_lite -pytestmark = pytest.mark.python_vm - class TestBlock: @given(evm=evm_lite) diff --git a/cairo/tests/ethereum/cancun/vm/instructions/test_comparison.py b/cairo/tests/ethereum/cancun/vm/instructions/test_comparison.py index df740835..5db8838f 100644 --- a/cairo/tests/ethereum/cancun/vm/instructions/test_comparison.py +++ b/cairo/tests/ethereum/cancun/vm/instructions/test_comparison.py @@ -13,8 +13,6 @@ from tests.utils.args_gen import Evm from tests.utils.strategies import evm_lite -pytestmark = pytest.mark.python_vm - class TestComparison: @given(evm=evm_lite) diff --git a/cairo/tests/ethereum/cancun/vm/instructions/test_control_flow.py b/cairo/tests/ethereum/cancun/vm/instructions/test_control_flow.py index 159a10f3..5e102e95 100644 --- a/cairo/tests/ethereum/cancun/vm/instructions/test_control_flow.py +++ b/cairo/tests/ethereum/cancun/vm/instructions/test_control_flow.py @@ -15,8 +15,6 @@ from tests.utils.args_gen import Evm from tests.utils.strategies import evm_lite -pytestmark = pytest.mark.python_vm - class TestControlFlow: @given(evm=evm_lite) diff --git a/cairo/tests/ethereum/cancun/vm/instructions/test_keccak.py b/cairo/tests/ethereum/cancun/vm/instructions/test_keccak.py index a108125c..06d0aafc 100644 --- a/cairo/tests/ethereum/cancun/vm/instructions/test_keccak.py +++ b/cairo/tests/ethereum/cancun/vm/instructions/test_keccak.py @@ -8,8 +8,6 @@ from tests.utils.args_gen import Evm from tests.utils.strategies import evm_lite, memory_access_size, memory_start_position -pytestmark = pytest.mark.python_vm - class TestKeccak: @given( diff --git a/cairo/tests/ethereum/cancun/vm/instructions/test_log.py b/cairo/tests/ethereum/cancun/vm/instructions/test_log.py index 7627297d..cc516dfb 100644 --- a/cairo/tests/ethereum/cancun/vm/instructions/test_log.py +++ b/cairo/tests/ethereum/cancun/vm/instructions/test_log.py @@ -8,8 +8,6 @@ from tests.utils.args_gen import Evm from tests.utils.strategies import evm_lite, memory_access_size, memory_start_position -pytestmark = pytest.mark.python_vm - class TestLog: @given( diff --git a/cairo/tests/ethereum/cancun/vm/instructions/test_memory_instructions.py b/cairo/tests/ethereum/cancun/vm/instructions/test_memory_instructions.py index 5e78066d..fcbf509f 100644 --- a/cairo/tests/ethereum/cancun/vm/instructions/test_memory_instructions.py +++ b/cairo/tests/ethereum/cancun/vm/instructions/test_memory_instructions.py @@ -6,8 +6,6 @@ from tests.utils.args_gen import Evm from tests.utils.strategies import evm_lite -pytestmark = pytest.mark.python_vm - class TestMemory: @given(evm=evm_lite) diff --git a/cairo/tests/ethereum/cancun/vm/test_memory.py b/cairo/tests/ethereum/cancun/vm/test_memory.py index 5ea35a2a..f88a912b 100644 --- a/cairo/tests/ethereum/cancun/vm/test_memory.py +++ b/cairo/tests/ethereum/cancun/vm/test_memory.py @@ -1,4 +1,3 @@ -import pytest from ethereum_types.bytes import Bytes from ethereum_types.numeric import U256 from hypothesis import given @@ -38,9 +37,6 @@ def memory_read_strategy(draw): return memory, start_position, size -pytestmark = pytest.mark.python_vm - - class TestMemory: @given(memory_write_strategy()) def test_memory_write(self, cairo_run, params): diff --git a/cairo/tests/ethereum/cancun/vm/test_runtime.py b/cairo/tests/ethereum/cancun/vm/test_runtime.py index 7a990f3a..871339df 100644 --- a/cairo/tests/ethereum/cancun/vm/test_runtime.py +++ b/cairo/tests/ethereum/cancun/vm/test_runtime.py @@ -1,11 +1,8 @@ -import pytest from ethereum_types.bytes import Bytes from hypothesis import given from ethereum.cancun.vm.runtime import get_valid_jump_destinations -pytestmark = pytest.mark.python_vm - class TestRuntime: @given(code=...) diff --git a/cairo/tests/ethereum/cancun/vm/test_stack.py b/cairo/tests/ethereum/cancun/vm/test_stack.py index bef281b4..3a9bd456 100644 --- a/cairo/tests/ethereum/cancun/vm/test_stack.py +++ b/cairo/tests/ethereum/cancun/vm/test_stack.py @@ -7,8 +7,6 @@ from ethereum.cancun.vm.exceptions import StackOverflowError, StackUnderflowError from ethereum.cancun.vm.stack import pop, push -pytestmark = pytest.mark.python_vm - class TestStack: def test_pop_underflow(self, cairo_run): diff --git a/cairo/tests/ethereum/crypto/test_hash.py b/cairo/tests/ethereum/crypto/test_hash.py index c954bede..59eac059 100644 --- a/cairo/tests/ethereum/crypto/test_hash.py +++ b/cairo/tests/ethereum/crypto/test_hash.py @@ -1,11 +1,8 @@ -import pytest from ethereum_types.bytes import Bytes from hypothesis import assume, given from ethereum.crypto.hash import keccak256 -pytestmark = pytest.mark.python_vm - class TestHash: diff --git a/cairo/tests/ethereum/test_rlp.py b/cairo/tests/ethereum/test_rlp.py index fa9a0c62..7099ba8a 100644 --- a/cairo/tests/ethereum/test_rlp.py +++ b/cairo/tests/ethereum/test_rlp.py @@ -23,8 +23,6 @@ ) from tests.utils.errors import cairo_error -pytestmark = pytest.mark.python_vm - class TestRlp: class TestEncode: diff --git a/cairo/tests/src/test_stack.py b/cairo/tests/src/test_stack.py index 73aa5b22..f8a266fa 100644 --- a/cairo/tests/src/test_stack.py +++ b/cairo/tests/src/test_stack.py @@ -1,8 +1,3 @@ -import pytest - -pytestmark = pytest.mark.python_vm - - class TestStack: class TestPeek: def test_should_return_stack_at_given_index__when_value_is_0(self, cairo_run): diff --git a/cairo/tests/test_serde.py b/cairo/tests/test_serde.py index 9354169d..99d89ed9 100644 --- a/cairo/tests/test_serde.py +++ b/cairo/tests/test_serde.py @@ -166,9 +166,6 @@ def single_evm_parent(b: Union[Message, Evm]) -> bool: return True -pytestmark = pytest.mark.python_vm - - class TestSerde: @given(b=...) # 20 examples per type diff --git a/cairo/tests/utils/args_gen.py b/cairo/tests/utils/args_gen.py index 3ebc78d7..c7957104 100644 --- a/cairo/tests/utils/args_gen.py +++ b/cairo/tests/utils/args_gen.py @@ -74,6 +74,7 @@ ) from cairo_addons.vm import DictTracker as RustDictTracker +from cairo_addons.vm import MemorySegmentManager as RustMemorySegmentManager from cairo_addons.vm import Relocatable as RustRelocatable from ethereum_types.bytes import ( Bytes, @@ -100,6 +101,7 @@ from starkware.cairo.lang.compiler.program import Program from starkware.cairo.lang.compiler.scoped_name import ScopedName from starkware.cairo.lang.vm.crypto import poseidon_hash_many +from starkware.cairo.lang.vm.memory_segments import MemorySegmentManager from starkware.cairo.lang.vm.relocatable import RelocatableValue from ethereum.cancun.blocks import Header, Log, Receipt, Withdrawal @@ -374,7 +376,7 @@ def gen_arg(dict_manager, segments): def _gen_arg( dict_manager, - segments, + segments: Union[MemorySegmentManager, RustMemorySegmentManager], arg_type: Type, arg: Any, annotations: Optional[Any] = None, @@ -545,19 +547,15 @@ def _gen_arg( data=data, current_ptr=current_ptr ) else: - dict_manager.insert( - dict_ptr.segment_index, - RustDictTracker( - keys=list(data.keys()), - values=list(data.values()), - current_ptr=current_ptr, - default_value=( - data.default_factory() - if isinstance(data, defaultdict) - else None - ), - ), + default_value = ( + data.default_factory() if isinstance(data, defaultdict) else None + ) + dict_manager.trackers[dict_ptr.segment_index] = RustDictTracker( + data=data, + current_ptr=current_ptr, + default_value=default_value, ) + base = segments.add() segments.load_data(base, [dict_ptr, current_ptr]) return base @@ -589,11 +587,17 @@ def _gen_arg( if arg_type_origin is Trie: # In case of a Trie, we need the dict to be a defaultdict with the trie.default as the default value. - dict_ptr = segments.memory[data[2]] - dict_manager.trackers[dict_ptr.segment_index].data = defaultdict( - lambda: data[1], dict_manager.trackers[dict_ptr.segment_index].data - ) - + dict_ptr = segments.memory.get(data[2]) + if isinstance(dict_manager, DictManager): + dict_manager.trackers[dict_ptr.segment_index].data = defaultdict( + lambda: data[1], dict_manager.trackers[dict_ptr.segment_index].data + ) + else: + dict_manager.trackers[dict_ptr.segment_index] = RustDictTracker( + data=dict_manager.trackers[dict_ptr.segment_index].data, + current_ptr=dict_ptr, + default_value=data[1], + ) return struct_ptr if arg_type in (U256, Hash32, Bytes32, Bytes256): diff --git a/cairo/tests/utils/serde.py b/cairo/tests/utils/serde.py index 784a59fb..efcacffb 100644 --- a/cairo/tests/utils/serde.py +++ b/cairo/tests/utils/serde.py @@ -105,9 +105,7 @@ def __init__( cairo_file=None, ): self.segments = segments - self.memory = ( - segments.memory if isinstance(segments, MemorySegmentManager) else segments - ) + self.memory = segments.memory self.program = program self.dict_manager = dict_manager self.cairo_file = cairo_file or Path() diff --git a/crates/cairo-addons/src/vm/dict_manager.rs b/crates/cairo-addons/src/vm/dict_manager.rs index 9b98f7fa..736bda18 100644 --- a/crates/cairo-addons/src/vm/dict_manager.rs +++ b/crates/cairo-addons/src/vm/dict_manager.rs @@ -1,14 +1,84 @@ use cairo_vm::{ hint_processor::builtin_hint_processor::dict_manager::{ - DictManager as RustDictManager, DictTracker, + DictKey as RustDictKey, DictManager as RustDictManager, DictTracker, }, types::relocatable::MaybeRelocatable, }; -use pyo3::prelude::*; +use pyo3::{prelude::*, types::PyTuple}; use std::{cell::RefCell, collections::HashMap, rc::Rc}; use super::{maybe_relocatable::PyMaybeRelocatable, relocatable::PyRelocatable}; +#[derive(FromPyObject, Eq, PartialEq, Hash)] +pub enum PyDictKey { + #[pyo3(transparent)] + Simple(PyMaybeRelocatable), + #[pyo3(transparent)] + Compound(Vec), +} + +impl IntoPy for PyDictKey { + fn into_py(self, py: Python<'_>) -> PyObject { + match self { + PyDictKey::Simple(val) => val.into_py(py), + PyDictKey::Compound(vals) => { + // Convert Vec to tuple + let elements: Vec = vals.into_iter().map(|v| v.into_py(py)).collect(); + PyTuple::new_bound(py, elements).into() + } + } + } +} + +impl From for RustDictKey { + fn from(value: PyDictKey) -> Self { + match value { + PyDictKey::Simple(val) => RustDictKey::Simple(val.into()), + PyDictKey::Compound(vals) => { + RustDictKey::Compound(vals.into_iter().map(|v| v.into()).collect()) + } + } + } +} + +impl From for PyDictKey { + fn from(value: RustDictKey) -> Self { + match value { + RustDictKey::Simple(val) => PyDictKey::Simple(val.into()), + RustDictKey::Compound(vals) => { + PyDictKey::Compound(vals.into_iter().map(|v| v.into()).collect()) + } + } + } +} + +/// Object returned by DictManager.trackers enabling access to the trackers by index and mutating +/// the trackers with manager.trackers[index] = tracker +#[pyclass(name = "TrackerMapping", unsendable)] +pub struct PyTrackerMapping { + inner: Rc>, +} + +#[pymethods] +impl PyTrackerMapping { + fn __getitem__(&self, key: isize) -> PyResult { + self.inner + .borrow() + .trackers + .get(&key) + .cloned() + .map(|tracker| PyDictTracker { inner: tracker }) + .ok_or_else(|| { + PyErr::new::(format!("Key {} not found", key)) + }) + } + + fn __setitem__(&mut self, key: isize, value: PyDictTracker) -> PyResult<()> { + self.inner.borrow_mut().trackers.insert(key, value.inner); + Ok(()) + } +} + #[pyclass(name = "DictManager", unsendable)] pub struct PyDictManager { pub inner: Rc>, @@ -16,6 +86,16 @@ pub struct PyDictManager { #[pymethods] impl PyDictManager { + #[new] + fn new() -> Self { + Self { inner: Rc::new(RefCell::new(RustDictManager::new())) } + } + + #[getter] + fn trackers(&self) -> PyResult { + Ok(PyTrackerMapping { inner: self.inner.clone() }) + } + fn insert(&mut self, segment_index: isize, value: &PyDictTracker) -> PyResult<()> { if self.inner.borrow().trackers.contains_key(&segment_index) { return Err(PyErr::new::( @@ -26,11 +106,7 @@ impl PyDictManager { Ok(()) } - fn get_value( - &self, - segment_index: isize, - key: PyMaybeRelocatable, - ) -> PyResult { + fn get_value(&self, segment_index: isize, key: PyDictKey) -> PyResult { let value = self .inner .borrow_mut() @@ -47,24 +123,22 @@ impl PyDictManager { } #[pyclass(name = "DictTracker")] +#[derive(Clone)] pub struct PyDictTracker { inner: DictTracker, } #[pymethods] impl PyDictTracker { - // Note: This is a temporary implementation, need to understand why HashMap is not working #[new] - #[pyo3(signature = (keys, values, current_ptr, default_value=None))] + #[pyo3(signature = (data, current_ptr, default_value=None))] fn new( - keys: Vec, - values: Vec, + data: HashMap, current_ptr: PyRelocatable, default_value: Option, ) -> PyResult { - let data: HashMap = - keys.into_iter().zip(values).map(|(k, v)| (k.into(), v.into())).collect(); + let data: HashMap = + data.into_iter().map(|(k, v)| (k.into(), v.into())).collect(); if let Some(default_value) = default_value { let default_value = default_value.into(); @@ -75,4 +149,33 @@ impl PyDictTracker { Ok(Self { inner: DictTracker::new_with_initial(current_ptr.inner, data) }) } } + + #[getter] + fn current_ptr(&self) -> PyRelocatable { + PyRelocatable { inner: self.inner.current_ptr } + } + + #[getter] + fn data(&self) -> HashMap { + self.inner + .get_dictionary_ref() + .iter() + .map(|(k, v)| (PyDictKey::from(k.clone()), PyMaybeRelocatable::from(v.clone()))) + .collect() + } + + fn __repr__(&self) -> PyResult { + let mut pairs: Vec<_> = self.inner.get_dictionary_ref().iter().collect(); + + // Sort by key + pairs.sort_by(|(k1, _), (k2, _)| k1.partial_cmp(k2).unwrap_or(std::cmp::Ordering::Equal)); + + let data_str = + pairs.into_iter().map(|(k, v)| format!("{}: {}", k, v)).collect::>().join(", "); + + Ok(format!( + "DictTracker(data={{{}}}, current_ptr=Relocatable(segment_index={}, offset={}))", + data_str, self.inner.current_ptr.segment_index, self.inner.current_ptr.offset + )) + } } diff --git a/crates/cairo-addons/src/vm/felt.rs b/crates/cairo-addons/src/vm/felt.rs index d02db6e6..a86f0b59 100644 --- a/crates/cairo-addons/src/vm/felt.rs +++ b/crates/cairo-addons/src/vm/felt.rs @@ -16,7 +16,7 @@ impl Felt252Input { } #[pyclass(name = "Felt")] -#[derive(Clone)] +#[derive(Clone, Eq, PartialEq, Hash)] pub struct PyFelt { pub(crate) inner: Felt252, } diff --git a/crates/cairo-addons/src/vm/maybe_relocatable.rs b/crates/cairo-addons/src/vm/maybe_relocatable.rs index dc8daab5..bc14304c 100644 --- a/crates/cairo-addons/src/vm/maybe_relocatable.rs +++ b/crates/cairo-addons/src/vm/maybe_relocatable.rs @@ -3,7 +3,7 @@ use cairo_vm::types::relocatable::MaybeRelocatable as RustMaybeRelocatable; use num_bigint::BigUint; use pyo3::{FromPyObject, IntoPy, PyObject, Python}; -#[derive(FromPyObject)] +#[derive(FromPyObject, Eq, PartialEq, Hash)] pub enum PyMaybeRelocatable { #[pyo3(transparent)] Felt(PyFelt), diff --git a/crates/cairo-addons/src/vm/memory_segments.rs b/crates/cairo-addons/src/vm/memory_segments.rs index 498687a9..eff93c01 100644 --- a/crates/cairo-addons/src/vm/memory_segments.rs +++ b/crates/cairo-addons/src/vm/memory_segments.rs @@ -10,8 +10,26 @@ pub struct PyMemorySegmentManager { pub(crate) runner: *mut RustCairoRunner, } +/// Enables syntax `segments.memory.` +#[pyclass(name = "MemoryWrapper", unsendable)] +pub struct PyMemoryWrapper { + pub(crate) runner: *mut RustCairoRunner, +} + +#[pymethods] +impl PyMemoryWrapper { + fn get(&self, key: PyRelocatable) -> Option { + unsafe { (*self.runner).vm.get_maybe(&key.inner).map(PyMaybeRelocatable::from) } + } +} + #[pymethods] impl PyMemorySegmentManager { + #[getter] + fn memory(&self) -> PyMemoryWrapper { + PyMemoryWrapper { runner: self.runner } + } + fn add(&mut self) -> PyRelocatable { unsafe { (*self.runner).vm.segments.add().into() } } @@ -49,8 +67,4 @@ impl PyMemorySegmentManager { fn compute_effective_sizes(&mut self) -> Vec { unsafe { (*self.runner).vm.segments.compute_effective_sizes().clone() } } - - fn get(&self, key: PyRelocatable) -> Option { - unsafe { (*self.runner).vm.get_maybe(&key.inner).map(PyMaybeRelocatable::from) } - } } diff --git a/crates/cairo-addons/src/vm/relocatable.rs b/crates/cairo-addons/src/vm/relocatable.rs index f2d5131d..04e6e04b 100644 --- a/crates/cairo-addons/src/vm/relocatable.rs +++ b/crates/cairo-addons/src/vm/relocatable.rs @@ -6,7 +6,7 @@ use pyo3::prelude::*; use super::maybe_relocatable::PyMaybeRelocatable; #[pyclass(name = "Relocatable")] -#[derive(Clone)] +#[derive(Clone, Eq, PartialEq, Hash)] pub struct PyRelocatable { pub(crate) inner: RustRelocatable, } diff --git a/python/cairo-addons/tests/test_dict_manager.py b/python/cairo-addons/tests/test_dict_manager.py index 326f1d9a..7dcdb4cd 100644 --- a/python/cairo-addons/tests/test_dict_manager.py +++ b/python/cairo-addons/tests/test_dict_manager.py @@ -1,5 +1,9 @@ import pytest -from cairo_addons.vm import CairoRunner, DictTracker +from cairo_addons.vm import CairoRunner +from cairo_addons.vm import DictManager as RustDictManager +from cairo_addons.vm import DictTracker as RustDictTracker +from cairo_addons.vm import Relocatable as RustRelocatable +from starkware.cairo.common.dict import DictManager, DictTracker @pytest.fixture @@ -17,9 +21,8 @@ def test_should_insert_dict(self, runner): current_ptr = dict_ptr + len(initial_data) runner.dict_manager.insert( dict_ptr.segment_index, - DictTracker( - keys=list(data.keys()), - values=list(data.values()), + RustDictTracker( + data=data, current_ptr=current_ptr, ), ) @@ -30,9 +33,7 @@ def test_should_insert_default_dict(self, runner): dict_ptr = runner.segments.add() runner.dict_manager.insert( dict_ptr.segment_index, - DictTracker( - keys=[], values=[], current_ptr=dict_ptr, default_value=0xABDE1 - ), + RustDictTracker(data={}, current_ptr=dict_ptr, default_value=0xABDE1), ) assert runner.dict_manager.get_value(dict_ptr.segment_index, 1) == 0xABDE1 @@ -40,10 +41,30 @@ def test_should_raise_existing_dict(self, runner): dict_ptr = runner.segments.add() runner.dict_manager.insert( dict_ptr.segment_index, - DictTracker(keys=[], values=[], current_ptr=dict_ptr), + RustDictTracker(data={}, current_ptr=dict_ptr), ) with pytest.raises(ValueError, match="Segment index already exists"): runner.dict_manager.insert( dict_ptr.segment_index, - DictTracker(keys=[], values=[], current_ptr=dict_ptr), + RustDictTracker(data={}, current_ptr=dict_ptr), ) + + def test_api_compatibility(self): + rust_manager = RustDictManager() + python_manager = DictManager() + data = {1: 4, 2: 5, 3: 6} + dict_ptr = RustRelocatable(segment_index=0, offset=0) + + rust_tracker = RustDictTracker( + data=data, current_ptr=dict_ptr, default_value=None + ) + python_tracker = DictTracker(data=data, current_ptr=dict_ptr) + + assert rust_tracker.current_ptr == python_tracker.current_ptr + assert rust_tracker.data == python_tracker.data + assert str(rust_tracker) == str(python_tracker) + rust_manager.trackers[dict_ptr.segment_index] = rust_tracker + python_manager.trackers[dict_ptr.segment_index] = python_tracker + assert str(rust_manager.trackers[dict_ptr.segment_index]) == str( + python_manager.trackers[dict_ptr.segment_index] + ) diff --git a/python/cairo-addons/tests/test_memory_segments.py b/python/cairo-addons/tests/test_memory_segments.py index 924e6197..cb06d7bc 100644 --- a/python/cairo-addons/tests/test_memory_segments.py +++ b/python/cairo-addons/tests/test_memory_segments.py @@ -53,3 +53,8 @@ def test_compute_effective_sizes(self, runner): assert sizes == [4] assert runner.segments.get_segment_used_size(0) == 4 assert runner.segments.get_segment_size(0) == 4 + + def test_memory_wrapper(self, runner): + ptr = runner.segments.add() + runner.segments.load_data(ptr, [Felt(1), Felt(2), Felt(3), Felt(4)]) + assert runner.segments.memory.get(ptr) == Felt(1)