From 21ddb93069b2cee67c0b49967ba5f89987a8f01c Mon Sep 17 00:00:00 2001 From: Jonathan LEI Date: Mon, 21 Oct 2024 18:13:42 +0800 Subject: [PATCH] feat: implement `Encode` and `Decode` for primitive collection types (#670) --- examples/serde.rs | 18 +++- starknet-core/src/codec.rs | 183 ++++++++++++++++++++++++++++++++++++- 2 files changed, 197 insertions(+), 4 deletions(-) diff --git a/examples/serde.rs b/examples/serde.rs index 4c80d530..63f9c843 100644 --- a/examples/serde.rs +++ b/examples/serde.rs @@ -10,14 +10,16 @@ use starknet::{ struct CairoType { a: Felt, b: Option, - c: bool, + c: Vec, + d: [u8; 2], } fn main() { let instance = CairoType { a: felt!("123456789"), b: Some(100), - c: false, + c: vec![false, true], + d: [3, 4], }; let mut serialized = vec![]; @@ -25,7 +27,17 @@ fn main() { assert_eq!( serialized, - [felt!("123456789"), felt!("0"), felt!("100"), felt!("0")] + [ + felt!("123456789"), + felt!("0"), + felt!("100"), + felt!("2"), + felt!("0"), + felt!("1"), + felt!("2"), + felt!("3"), + felt!("4"), + ] ); let restored = CairoType::decode(&serialized).unwrap(); diff --git a/starknet-core/src/codec.rs b/starknet-core/src/codec.rs index 0cc60f6e..b5ab08fb 100644 --- a/starknet-core/src/codec.rs +++ b/starknet-core/src/codec.rs @@ -1,5 +1,5 @@ use alloc::{boxed::Box, fmt::Formatter, format, string::*, vec::*}; -use core::fmt::Display; +use core::{fmt::Display, mem::MaybeUninit}; use num_traits::ToPrimitive; @@ -139,6 +139,51 @@ where } } +impl Encode for Vec +where + T: Encode, +{ + fn encode(&self, writer: &mut W) -> Result<(), Error> { + writer.write(Felt::from(self.len())); + + for item in self { + item.encode(writer)?; + } + + Ok(()) + } +} + +impl Encode for [T; N] +where + T: Encode, +{ + fn encode(&self, writer: &mut W) -> Result<(), Error> { + writer.write(Felt::from(N)); + + for item in self { + item.encode(writer)?; + } + + Ok(()) + } +} + +impl Encode for [T] +where + T: Encode, +{ + fn encode(&self, writer: &mut W) -> Result<(), Error> { + writer.write(Felt::from(self.len())); + + for item in self { + item.encode(writer)?; + } + + Ok(()) + } +} + impl<'a> Decode<'a> for Felt { fn decode_iter(iter: &mut T) -> Result where @@ -263,6 +308,56 @@ where } } +impl<'a, T> Decode<'a> for Vec +where + T: Decode<'a>, +{ + fn decode_iter(iter: &mut I) -> Result + where + I: Iterator, + { + let length = iter.next().ok_or_else(Error::input_exhausted)?; + let length = length + .to_usize() + .ok_or_else(|| Error::value_out_of_range(length, "usize"))?; + + let mut result = Self::with_capacity(length); + + for _ in 0..length { + result.push(T::decode_iter(iter)?); + } + + Ok(result) + } +} + +impl<'a, T, const N: usize> Decode<'a> for [T; N] +where + T: Decode<'a> + Sized, +{ + fn decode_iter(iter: &mut I) -> Result + where + I: Iterator, + { + let length = iter.next().ok_or_else(Error::input_exhausted)?; + let length = length + .to_usize() + .ok_or_else(|| Error::value_out_of_range(length, "usize"))?; + + if length != N { + return Err(Error::length_mismatch(N, length)); + } + + let mut result: [MaybeUninit; N] = unsafe { MaybeUninit::uninit().assume_init() }; + + for elem in &mut result[..] { + *elem = MaybeUninit::new(T::decode_iter(iter)?); + } + + Ok(unsafe { core::mem::transmute_copy::<_, [T; N]>(&result) }) + } +} + impl Error { /// Creates an [`Error`] which indicates that the input stream has ended prematurely. pub fn input_exhausted() -> Self { @@ -273,6 +368,14 @@ impl Error { } } + /// Creates an [`Error`] which indicates that the length (likely prefix) is different from the + /// expected value. + pub fn length_mismatch(expected: usize, actual: usize) -> Self { + Self { + repr: format!("expecting length `{}` but got `{}`", expected, actual).into_boxed_str(), + } + } + /// Creates an [`Error`] which indicates that the input value is out of range. pub fn value_out_of_range(value: V, type_name: &str) -> Self where @@ -426,6 +529,54 @@ mod tests { assert_eq!(serialized, vec![Felt::from_str("1").unwrap()]); } + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + fn test_encode_vec() { + let mut serialized = Vec::::new(); + vec![Some(10u32), None].encode(&mut serialized).unwrap(); + assert_eq!( + serialized, + vec![ + Felt::from_str("2").unwrap(), + Felt::from_str("0").unwrap(), + Felt::from_str("10").unwrap(), + Felt::from_str("1").unwrap() + ] + ); + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + fn test_encode_array() { + let mut serialized = Vec::::new(); + <[Option; 2]>::encode(&[Some(10u32), None], &mut serialized).unwrap(); + assert_eq!( + serialized, + vec![ + Felt::from_str("2").unwrap(), + Felt::from_str("0").unwrap(), + Felt::from_str("10").unwrap(), + Felt::from_str("1").unwrap() + ] + ); + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + fn test_encode_slice() { + let mut serialized = Vec::::new(); + <[Option]>::encode(&[Some(10u32), None], &mut serialized).unwrap(); + assert_eq!( + serialized, + vec![ + Felt::from_str("2").unwrap(), + Felt::from_str("0").unwrap(), + Felt::from_str("10").unwrap(), + Felt::from_str("1").unwrap() + ] + ); + } + #[test] #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] fn test_derive_encode_struct_named() { @@ -639,6 +790,36 @@ mod tests { ); } + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + fn test_decode_vec() { + assert_eq!( + vec![Some(10u32), None], + Vec::>::decode(&[ + Felt::from_str("2").unwrap(), + Felt::from_str("0").unwrap(), + Felt::from_str("10").unwrap(), + Felt::from_str("1").unwrap() + ]) + .unwrap() + ); + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + fn test_decode_array() { + assert_eq!( + [Some(10u32), None], + <[Option; 2]>::decode(&[ + Felt::from_str("2").unwrap(), + Felt::from_str("0").unwrap(), + Felt::from_str("10").unwrap(), + Felt::from_str("1").unwrap() + ]) + .unwrap() + ); + } + #[test] #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] fn test_derive_decode_struct_named() {