Skip to content

Commit

Permalink
Add rlp.encode_sequence
Browse files Browse the repository at this point in the history
  • Loading branch information
ClementWalter committed Nov 13, 2024
1 parent 63cf11f commit ba7bd85
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 3 deletions.
31 changes: 30 additions & 1 deletion cairo/ethereum/rlp.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ from starkware.cairo.common.alloc import alloc
from starkware.cairo.common.memcpy import memcpy
from ethereum.base_types import Bytes, BytesStruct, TupleBytes, TupleBytesStruct
from ethereum.utils.numeric import is_zero
from src.utils.bytes import felt_to_bytes
from src.utils.bytes import felt_to_bytes, felt_to_bytes_little
from src.utils.array import reverse

func _encode_bytes{range_check_ptr}(dst: felt*, raw_bytes: Bytes) -> felt {
alloc_locals;
Expand Down Expand Up @@ -68,3 +69,31 @@ func get_joined_encodings{range_check_ptr}(raw_sequence: TupleBytes) -> Bytes {
let encoded_bytes = Bytes(value);
return encoded_bytes;
}

// @notice: Encodes a sequence of RLP encodable objects (`raw_sequence`) using RLP.
// @dev: The standard implementation assumes that the length fits in at most 9 bytes
// since the leading byte is Bytes([0xF7 + len(len_joined_encodings_as_be)]).
// In total, it means that the sequence starts at most at dst + 10.
// To avoid a memcpy, we start using the allocated memory at dst + 10.
func encode_sequence{range_check_ptr}(raw_sequence: TupleBytes) -> Bytes {
alloc_locals;
let (dst) = alloc();
let len = _get_joined_encodings(dst + 10, raw_sequence.value.value, raw_sequence.value.len);
let cond = is_le(len, 0x38 - 1);
if (cond != 0) {
assert [dst + 9] = 0xC0 + len;
tempvar value = new BytesStruct(dst + 9, len + 1);
let encoded_bytes = Bytes(value);
return encoded_bytes;
}

let (len_joined_encodings_as_le: felt*) = alloc();
let len_joined_encodings_as_le_len = felt_to_bytes_little(len_joined_encodings_as_le, len);
let dst = dst + 10 - len_joined_encodings_as_le_len - 1;
reverse(dst + 1, len_joined_encodings_as_le_len, len_joined_encodings_as_le);
assert [dst] = 0xF7 + len_joined_encodings_as_le_len;

tempvar value = new BytesStruct(dst, len + 1 + len_joined_encodings_as_le_len);
let encoded_bytes = Bytes(value);
return encoded_bytes;
}
9 changes: 8 additions & 1 deletion cairo/tests/ethereum/test_rlp.cairo
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from ethereum.rlp import encode_bytes, get_joined_encodings
from ethereum.rlp import encode_bytes, get_joined_encodings, encode_sequence
from ethereum.base_types import Bytes, TupleBytes

func test_encode_bytes{range_check_ptr}() -> Bytes {
Expand All @@ -14,3 +14,10 @@ func test_get_joined_encodings{range_check_ptr}() -> Bytes {
let encoded_bytes = get_joined_encodings(raw_sequence);
return encoded_bytes;
}

func test_encode_sequence{range_check_ptr}() -> Bytes {
tempvar raw_sequence: TupleBytes;
%{ memory[ap - 1] = gen_arg(program_input["raw_sequence"]) %}
let encoded_bytes = encode_sequence(raw_sequence);
return encoded_bytes;
}
8 changes: 7 additions & 1 deletion cairo/tests/ethereum/test_rlp.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import hypothesis.strategies as st
from hypothesis import given

from ethereum.rlp import encode_bytes, get_joined_encodings
from ethereum.rlp import encode_bytes, encode_sequence, get_joined_encodings


class TestRlp:
Expand All @@ -16,3 +16,9 @@ def test_get_joined_encodings(self, cairo_run, raw_sequence):
assert get_joined_encodings(raw_sequence) == cairo_run(
"test_get_joined_encodings", raw_sequence=raw_sequence
)

@given(raw_sequence=st.tuples(st.binary()))
def test_encode_sequence(self, cairo_run, raw_sequence):
assert encode_sequence(raw_sequence) == cairo_run(
"test_encode_sequence", raw_sequence=raw_sequence
)

0 comments on commit ba7bd85

Please sign in to comment.