Skip to content

Commit

Permalink
Add rlp.decode_to_bytes
Browse files Browse the repository at this point in the history
  • Loading branch information
ClementWalter committed Nov 14, 2024
1 parent ee6512e commit 74b9b03
Show file tree
Hide file tree
Showing 13 changed files with 174 additions and 49 deletions.
53 changes: 52 additions & 1 deletion cairo/ethereum/rlp.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ from ethereum.base_types import Bytes, BytesStruct, TupleBytes, TupleBytesStruct
from ethereum.crypto.hash import keccak256, Hash32
from ethereum.utils.numeric import is_zero
from src.utils.array import reverse
from src.utils.bytes import felt_to_bytes, felt_to_bytes_little
from src.utils.bytes import felt_to_bytes, felt_to_bytes_little, bytes_to_felt
from starkware.cairo.common.alloc import alloc
from starkware.cairo.common.cairo_builtins import BitwiseBuiltin, KeccakBuiltin
from starkware.cairo.common.math_cmp import is_le, is_not_zero
from starkware.cairo.common.math import assert_not_zero
from starkware.cairo.common.memcpy import memcpy

func _encode_bytes{range_check_ptr}(dst: felt*, raw_bytes: Bytes) -> felt {
Expand Down Expand Up @@ -106,3 +107,53 @@ func rlp_hash{range_check_ptr, bitwise_ptr: BitwiseBuiltin*, keccak_ptr: KeccakB
let encoded_bytes = encode_bytes(raw_bytes);
return keccak256(encoded_bytes);
}

// @dev The reference function doesn't handle the case where encoded_bytes.len == 0
func decode_to_bytes{range_check_ptr}(encoded_bytes: Bytes) -> Bytes {
alloc_locals;
assert_not_zero(encoded_bytes.value.len);
assert [range_check_ptr] = encoded_bytes.value.len;
let range_check_ptr = range_check_ptr + 1;

let cond = is_le(encoded_bytes.value.data[0], 0x80 - 1);
if (encoded_bytes.value.len == 1 and cond != 0) {
return encoded_bytes;
}

let cond = is_le(encoded_bytes.value.data[0], 0xB7);
if (cond != 0) {
let len_raw_data = encoded_bytes.value.data[0] - 0x80;
assert [range_check_ptr] = len_raw_data;
let range_check_ptr = range_check_ptr + 1;
assert [range_check_ptr] = encoded_bytes.value.len - len_raw_data;
let range_check_ptr = range_check_ptr + 1;
let raw_data = encoded_bytes.value.data + 1;
if (len_raw_data == 1) {
assert [range_check_ptr] = raw_data[0] - 0x80;
tempvar range_check_ptr = range_check_ptr + 1;
} else {
tempvar range_check_ptr = range_check_ptr;
}
let range_check_ptr = [ap - 1];
tempvar value = new BytesStruct(raw_data, len_raw_data);
let decoded_bytes = Bytes(value);
return decoded_bytes;
}

let decoded_data_start_idx = 1 + encoded_bytes.value.data[0] - 0xB7;
assert [range_check_ptr] = encoded_bytes.value.len - decoded_data_start_idx;
let range_check_ptr = range_check_ptr + 1;
assert_not_zero(encoded_bytes.value.data[1]);
let len_decoded_data = bytes_to_felt(decoded_data_start_idx - 1, encoded_bytes.value.data + 1);
assert [range_check_ptr] = len_decoded_data - 0x38;
let range_check_ptr = range_check_ptr + 1;

let decoded_data_end_idx = decoded_data_start_idx + len_decoded_data;
assert [range_check_ptr] = encoded_bytes.value.len - decoded_data_end_idx;

let raw_data = encoded_bytes.value.data + decoded_data_start_idx;
tempvar value = new BytesStruct(raw_data, decoded_data_end_idx - decoded_data_start_idx);
let decoded_bytes = Bytes(value);

return decoded_bytes;
}
30 changes: 30 additions & 0 deletions cairo/src/utils/bytes.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,36 @@ func felt_to_bytes{range_check_ptr}(dst: felt*, value: felt) -> felt {
return bytes_len;
}

// @notice Loads a sequence of bytes into a single felt in big-endian.
// @param len: number of bytes.
// @param ptr: pointer to bytes array.
// @return: packed felt.
func bytes_to_felt(len: felt, ptr: felt*) -> felt {
if (len == 0) {
return 0;
}
tempvar current = 0;

// len, ptr, ?, ?, current
// ?, ? are intermediate steps created by the compiler to unfold the
// complex expression.
loop:
let len = [ap - 5];
let ptr = cast([ap - 4], felt*);
let current = [ap - 1];

tempvar len = len - 1;
tempvar ptr = ptr + 1;
tempvar current = current * 256 + [ptr - 1];

static_assert len == [ap - 5];
static_assert ptr == [ap - 4];
static_assert current == [ap - 1];
jmp loop if len != 0;

return current;
}

// @notice Split a felt into an array of 20 bytes, big endian
// @dev Truncate the high 12 bytes
func felt_to_bytes20{range_check_ptr}(dst: felt*, value: felt) {
Expand Down
15 changes: 14 additions & 1 deletion cairo/tests/ethereum/test_rlp.cairo
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
from ethereum.rlp import encode_bytes, get_joined_encodings, encode_sequence, rlp_hash
from ethereum.rlp import (
encode_bytes,
get_joined_encodings,
encode_sequence,
rlp_hash,
decode_to_bytes,
)
from starkware.cairo.common.cairo_builtins import BitwiseBuiltin, KeccakBuiltin
from ethereum.base_types import Bytes, TupleBytes
from ethereum.crypto.hash import Hash32
Expand Down Expand Up @@ -31,3 +37,10 @@ func test_rlp_hash{range_check_ptr, bitwise_ptr: BitwiseBuiltin*, keccak_ptr: Ke
let hash = rlp_hash(raw_bytes);
return hash;
}

func test_decode_to_bytes{range_check_ptr}() -> Bytes {
tempvar encoded_bytes: Bytes;
%{ memory[ap - 1] = gen_arg(program_input["encoded_bytes"]) %}
let decoded_bytes = decode_to_bytes(encoded_bytes);
return decoded_bytes;
}
32 changes: 31 additions & 1 deletion cairo/tests/ethereum/test_rlp.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import hypothesis.strategies as st
from hypothesis import given

from ethereum.rlp import encode_bytes, encode_sequence, get_joined_encodings, rlp_hash
from ethereum.rlp import (
decode_to_bytes,
encode_bytes,
encode_sequence,
get_joined_encodings,
rlp_hash,
)


class TestRlp:
Expand All @@ -26,3 +32,27 @@ def test_encode_sequence(self, cairo_run, raw_sequence):
@given(raw_bytes=st.binary())
def test_rlp_hash(self, cairo_run, raw_bytes):
assert rlp_hash(raw_bytes) == cairo_run("test_rlp_hash", raw_bytes=raw_bytes)

@given(raw_bytes=st.binary())
def test_decode_to_bytes(self, cairo_run, raw_bytes):
encoded_bytes = encode_bytes(raw_bytes)
assert decode_to_bytes(encoded_bytes) == cairo_run(
"test_decode_to_bytes", encoded_bytes=encoded_bytes
)

@given(encoded_bytes=st.binary())
def test_decode_to_bytes_should_raise(self, cairo_run, encoded_bytes):
"""
The cairo implementation of decode_to_bytes raises more often than the
eth-rlp implementation because this latter accepts negative
See https://github.com/ethereum/execution-specs/issues/1035
"""
decoded_bytes = None
try:
decoded_bytes = cairo_run(
"test_decode_to_bytes", encoded_bytes=encoded_bytes
)
except Exception:
pass
if decoded_bytes is not None:
assert decoded_bytes == decode_to_bytes(encoded_bytes)
6 changes: 3 additions & 3 deletions cairo/tests/src/precompiles/test_ec_recover.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import pytest
from ethereum.base_types import U256
from ethereum.crypto.elliptic_curve import SECP256K1N, secp256k1_recover
from ethereum.crypto.hash import Hash32, keccak256
from hypothesis import given
from hypothesis import strategies as st

from ethereum.base_types import U256
from ethereum.crypto.elliptic_curve import SECP256K1N, secp256k1_recover
from ethereum.crypto.hash import Hash32, keccak256
from ethereum.utils.byte import left_pad_zero_bytes
from tests.utils.helpers import ec_sign, generate_random_private_key

Expand Down
12 changes: 12 additions & 0 deletions cairo/tests/src/utils/test_bytes.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ from src.utils.bytes import (
uint256_to_bytes,
uint256_to_bytes32,
bytes_to_bytes8_little_endian,
bytes_to_felt,
)

func test__felt_to_ascii{range_check_ptr}(output_ptr: felt*) {
Expand Down Expand Up @@ -101,3 +102,14 @@ func test__bytes_to_bytes8_little_endian{range_check_ptr}() -> felt* {

return bytes8;
}

func test__bytes_to_felt() -> felt {
tempvar len;
let (ptr) = alloc();
%{
ids.len = len(program_input["data"])
segments.write_arg(ids.ptr, program_input["data"])
%}
let res = bytes_to_felt(len, ptr);
return res;
}
7 changes: 7 additions & 0 deletions cairo/tests/src/utils/test_bytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,10 @@ def test_should_return_bytes8(self, cairo_run, data):
output = cairo_run("test__bytes_to_bytes8_little_endian", bytes=data)

assert bytes8_little_endian == output

class TestBytesToFelt:

@given(data=binary(min_size=0, max_size=35))
def test_should_convert_bytes_to_felt_with_overflow(self, cairo_run, data):
output = cairo_run("test__bytes_to_felt", data=list(data))
assert output == int.from_bytes(data, byteorder="big") % DEFAULT_PRIME
11 changes: 0 additions & 11 deletions cairo/tests/src/utils/test_utils.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -197,14 +197,3 @@ func test__split_word_little{range_check_ptr}() -> felt* {
Helpers.split_word_little(value, len, dst);
return dst;
}

func test__bytes_to_felt() -> felt {
tempvar len;
let (ptr) = alloc();
%{
ids.len = len(program_input["data"])
segments.write_arg(ids.ptr, program_input["data"])
%}
let res = Helpers.bytes_to_felt(len, ptr);
return res;
}
9 changes: 0 additions & 9 deletions cairo/tests/src/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import pytest
from hypothesis import given, settings
from hypothesis import strategies as st
from starkware.cairo.lang.cairo_constants import DEFAULT_PRIME

from tests.utils.errors import cairo_error
from tests.utils.helpers import pack_calldata
Expand Down Expand Up @@ -227,11 +226,3 @@ def test_should_raise_when_len_ge_32_split_word_little(
):
with cairo_error("len must be < 32"):
cairo_run("test__split_word_little", value=value, length=length)


class TestBytesToFelt:

@given(data=st.binary(min_size=0, max_size=35))
def test_should_convert_bytes_to_felt_with_overflow(self, cairo_run, data):
output = cairo_run("test__bytes_to_felt", data=list(data))
assert output == int.from_bytes(data, byteorder="big") % DEFAULT_PRIME
14 changes: 7 additions & 7 deletions cairo/tests/test_serde.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
import pytest
from hypothesis import given
from hypothesis import strategies as st
from starkware.cairo.common.dict import DictManager
from starkware.cairo.lang.cairo_constants import DEFAULT_PRIME
from starkware.cairo.lang.vm.memory_dict import MemoryDict
from starkware.cairo.lang.vm.memory_segments import MemorySegmentManager

from ethereum.base_types import (
U64,
U256,
Expand All @@ -11,13 +18,6 @@
Uint,
)
from ethereum.cancun.fork_types import Address, Bloom, Root, VersionedHash
from hypothesis import given
from hypothesis import strategies as st
from starkware.cairo.common.dict import DictManager
from starkware.cairo.lang.cairo_constants import DEFAULT_PRIME
from starkware.cairo.lang.vm.memory_dict import MemoryDict
from starkware.cairo.lang.vm.memory_segments import MemorySegmentManager

from tests.utils.hints import gen_arg as _gen_arg
from tests.utils.serde import Serde
from tests.utils.serde import get_cairo_type as _get_cairo_type
Expand Down
6 changes: 3 additions & 3 deletions cairo/tests/utils/hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
from typing import Dict, Iterable, Tuple, Union
from unittest.mock import patch

from ethereum.base_types import U256, Bytes, Bytes0, Bytes8, Bytes20, Bytes32, Bytes256
from ethereum.cancun.blocks import Header, Log, Withdrawal
from ethereum.crypto.hash import Hash32
from starkware.cairo.common.dict import DictTracker
from starkware.cairo.lang.compiler.program import CairoHint
from starkware.cairo.lang.vm.relocatable import MaybeRelocatable

from ethereum.base_types import U256, Bytes, Bytes0, Bytes8, Bytes20, Bytes32, Bytes256
from ethereum.cancun.blocks import Header, Log, Withdrawal
from ethereum.crypto.hash import Hash32
from src.utils.uint256 import int_to_uint256
from tests.utils.helpers import flatten

Expand Down
23 changes: 12 additions & 11 deletions cairo/tests/utils/serde.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
from typing import Optional

from eth_utils.address import to_checksum_address
from starkware.cairo.lang.compiler.ast.cairo_types import (
TypeFelt,
TypePointer,
TypeStruct,
TypeTuple,
)
from starkware.cairo.lang.compiler.identifier_definition import (
StructDefinition,
TypeDefinition,
)
from starkware.cairo.lang.compiler.identifier_manager import MissingIdentifierError

from ethereum.base_types import (
U64,
U256,
Expand All @@ -13,17 +25,6 @@
)
from ethereum.cancun.blocks import Block, Header, Log, Receipt, Withdrawal
from ethereum.cancun.fork_types import Account
from starkware.cairo.lang.compiler.ast.cairo_types import (
TypeFelt,
TypePointer,
TypeStruct,
TypeTuple,
)
from starkware.cairo.lang.compiler.identifier_definition import (
StructDefinition,
TypeDefinition,
)
from starkware.cairo.lang.compiler.identifier_manager import MissingIdentifierError


def get_cairo_type(program, name):
Expand Down
5 changes: 3 additions & 2 deletions cairo/tests/utils/strategies.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from hypothesis import strategies as st
from starkware.cairo.lang.cairo_constants import DEFAULT_PRIME

from ethereum.base_types import U64, U256, Bytes8, Bytes32, Uint
from ethereum.cancun.blocks import Block, Header, Log, Receipt, Withdrawal
from ethereum.cancun.fork_types import Account, Address, Bloom, Root
from ethereum.cancun.trie import Trie
from ethereum.crypto.hash import Hash32
from hypothesis import strategies as st
from starkware.cairo.lang.cairo_constants import DEFAULT_PRIME

# Base types
uint20 = st.integers(min_value=0, max_value=2**20 - 1)
Expand Down

0 comments on commit 74b9b03

Please sign in to comment.