From 1da310b58c3bcfa8f946da1d14bd49ee7c11a801 Mon Sep 17 00:00:00 2001 From: Levi <31335633+levifeldman@users.noreply.github.com> Date: Fri, 12 Jan 2024 16:49:12 -0500 Subject: [PATCH] feat: Serde Serialize and Deserialize traits for the RbTree in the ic-certified-map crate. (#399) * This change adds the serde Serialize and Deserialize traits to the RbTree. * Manual implementation of Serialize and Deserialize for the RbTree. * replace 'static lifetime bounds with a custom lifetime 't. * Added bincode serialization in the serde test. * CandidType for the RbTree. * update CHANGELOG with CandidType. --- library/ic-certified-map/CHANGELOG.md | 3 + library/ic-certified-map/Cargo.toml | 3 +- library/ic-certified-map/src/rbtree.rs | 166 +++++++++++++++----- library/ic-certified-map/src/rbtree/test.rs | 23 +++ 4 files changed, 154 insertions(+), 41 deletions(-) diff --git a/library/ic-certified-map/CHANGELOG.md b/library/ic-certified-map/CHANGELOG.md index e47783d12..fab88d089 100644 --- a/library/ic-certified-map/CHANGELOG.md +++ b/library/ic-certified-map/CHANGELOG.md @@ -6,6 +6,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +### Added +- Implement CandidType, Serialize, and Deserialize for the RbTree. + ## [0.4.0] - 2023-07-13 ### Changed diff --git a/library/ic-certified-map/Cargo.toml b/library/ic-certified-map/Cargo.toml index c72a6e3bf..331b29c0b 100644 --- a/library/ic-certified-map/Cargo.toml +++ b/library/ic-certified-map/Cargo.toml @@ -22,9 +22,10 @@ include = ["src", "Cargo.toml", "CHANGELOG.md", "LICENSE", "README.md"] serde.workspace = true serde_bytes.workspace = true sha2.workspace = true +candid.workspace = true [dev-dependencies] hex.workspace = true serde_cbor = "0.11" ic-cdk.workspace = true -candid.workspace = true +bincode = "1.3.3" diff --git a/library/ic-certified-map/src/rbtree.rs b/library/ic-certified-map/src/rbtree.rs index 6878c6efb..cccbd9668 100644 --- a/library/ic-certified-map/src/rbtree.rs +++ b/library/ic-certified-map/src/rbtree.rs @@ -55,7 +55,7 @@ impl AsHashTree for Hash { } } -impl, V: AsHashTree + 'static> AsHashTree for RbTree { +impl<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't> AsHashTree for RbTree { fn root_hash(&self) -> Hash { match self.root.as_ref() { None => Empty.reconstruct(), @@ -102,7 +102,7 @@ struct Node { subtree_hash: Hash, } -impl, V: AsHashTree + 'static> Node { +impl<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't> Node { fn new(key: K, value: V) -> Box> { let value_hash = value.root_hash(); let data_hash = labeled_hash(key.as_ref(), &value_hash); @@ -274,47 +274,47 @@ pub struct RbTree { root: NodeRef, } -impl PartialEq for RbTree +impl<'t, K, V> PartialEq for RbTree where - K: 'static + AsRef<[u8]> + PartialEq, - V: 'static + AsHashTree + PartialEq, + K: 't + AsRef<[u8]> + PartialEq, + V: 't + AsHashTree + PartialEq, { fn eq(&self, other: &Self) -> bool { self.iter().eq(other.iter()) } } -impl Eq for RbTree +impl<'t, K, V> Eq for RbTree where - K: 'static + AsRef<[u8]> + Eq, - V: 'static + AsHashTree + Eq, + K: 't + AsRef<[u8]> + Eq, + V: 't + AsHashTree + Eq, { } -impl PartialOrd for RbTree +impl<'t, K, V> PartialOrd for RbTree where - K: 'static + AsRef<[u8]> + PartialOrd, - V: 'static + AsHashTree + PartialOrd, + K: 't + AsRef<[u8]> + PartialOrd, + V: 't + AsHashTree + PartialOrd, { fn partial_cmp(&self, other: &Self) -> Option { self.iter().partial_cmp(other.iter()) } } -impl Ord for RbTree +impl<'t, K, V> Ord for RbTree where - K: 'static + AsRef<[u8]> + Ord, - V: 'static + AsHashTree + Ord, + K: 't + AsRef<[u8]> + Ord, + V: 't + AsHashTree + Ord, { fn cmp(&self, other: &Self) -> Ordering { self.iter().cmp(other.iter()) } } -impl std::iter::FromIterator<(K, V)> for RbTree +impl<'t, K, V> std::iter::FromIterator<(K, V)> for RbTree where - K: 'static + AsRef<[u8]>, - V: 'static + AsHashTree, + K: 't + AsRef<[u8]>, + V: 't + AsHashTree, { fn from_iter(iter: T) -> Self where @@ -328,10 +328,10 @@ where } } -impl std::fmt::Debug for RbTree +impl<'t, K, V> std::fmt::Debug for RbTree where - K: 'static + AsRef<[u8]> + std::fmt::Debug, - V: 'static + AsHashTree + std::fmt::Debug, + K: 't + AsRef<[u8]> + std::fmt::Debug, + V: 't + AsHashTree + std::fmt::Debug, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "[")?; @@ -359,7 +359,7 @@ impl RbTree { } } -impl, V: AsHashTree + 'static> RbTree { +impl<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't> RbTree { /// Looks up the key in the map and returns the associated value, if there is one. pub fn get(&self, key: &[u8]) -> Option<&V> { let mut root = self.root.as_ref(); @@ -375,7 +375,7 @@ impl, V: AsHashTree + 'static> RbTree { /// Updates the value corresponding to the specified key. pub fn modify(&mut self, key: &[u8], f: impl FnOnce(&mut V)) { - fn go, V: AsHashTree + 'static>( + fn go<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>( h: &mut NodeRef, k: &[u8], f: impl FnOnce(&mut V), @@ -506,7 +506,7 @@ impl, V: AsHashTree + 'static> RbTree { lo: KeyBound<'a>, f: fn(&'a Node) -> HashTree<'a>, ) -> HashTree<'a> { - fn go<'a, K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>( + fn go<'a, 't, K: 't + AsRef<[u8]>, V: AsHashTree + 't>( n: &'a NodeRef, lo: KeyBound<'a>, f: fn(&'a Node) -> HashTree<'a>, @@ -543,7 +543,7 @@ impl, V: AsHashTree + 'static> RbTree { hi: KeyBound<'a>, f: fn(&'a Node) -> HashTree<'a>, ) -> HashTree<'a> { - fn go<'a, K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>( + fn go<'a, 't, K: 't + AsRef<[u8]>, V: AsHashTree + 't>( n: &'a NodeRef, hi: KeyBound<'a>, f: fn(&'a Node) -> HashTree<'a>, @@ -587,7 +587,7 @@ impl, V: AsHashTree + 'static> RbTree { lo.as_ref(), hi.as_ref() ); - fn go<'a, K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>( + fn go<'a, 't, K: 't + AsRef<[u8]>, V: AsHashTree + 't>( n: &'a NodeRef, lo: KeyBound<'a>, hi: KeyBound<'a>, @@ -645,7 +645,7 @@ impl, V: AsHashTree + 'static> RbTree { } fn lower_bound(&self, key: &[u8]) -> Option> { - fn go<'a, K: 'static + AsRef<[u8]>, V>( + fn go<'a, 't, K: 't + AsRef<[u8]>, V>( n: &'a NodeRef, key: &[u8], ) -> Option> { @@ -662,7 +662,7 @@ impl, V: AsHashTree + 'static> RbTree { } fn upper_bound(&self, key: &[u8]) -> Option> { - fn go<'a, K: 'static + AsRef<[u8]>, V>( + fn go<'a, 't, K: 't + AsRef<[u8]>, V>( n: &'a NodeRef, key: &[u8], ) -> Option> { @@ -685,7 +685,7 @@ impl, V: AsHashTree + 'static> RbTree { } &x[0..p.len()] == p } - fn go<'a, K: 'static + AsRef<[u8]>, V>( + fn go<'a, 't, K: 't + AsRef<[u8]>, V>( n: &'a NodeRef, prefix: &[u8], ) -> Option> { @@ -706,7 +706,7 @@ impl, V: AsHashTree + 'static> RbTree { key: &[u8], f: impl FnOnce(&'a V) -> HashTree<'a>, ) -> Option> { - fn go<'a, K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>( + fn go<'a, 't, K: 't + AsRef<[u8]>, V: AsHashTree + 't>( n: &'a NodeRef, key: &[u8], f: impl FnOnce(&'a V) -> HashTree<'a>, @@ -740,7 +740,7 @@ impl, V: AsHashTree + 'static> RbTree { /// Inserts a key-value entry into the map. pub fn insert(&mut self, key: K, value: V) { - fn go, V: AsHashTree + 'static>( + fn go<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>( h: NodeRef, k: K, v: V, @@ -778,7 +778,7 @@ impl, V: AsHashTree + 'static> RbTree { /// Removes the specified key from the map. pub fn delete(&mut self, key: &[u8]) { - fn move_red_left, V: AsHashTree + 'static>( + fn move_red_left<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>( mut h: Box>, ) -> Box> { flip_colors(&mut h); @@ -790,7 +790,7 @@ impl, V: AsHashTree + 'static> RbTree { h } - fn move_red_right, V: AsHashTree + 'static>( + fn move_red_right<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>( mut h: Box>, ) -> Box> { flip_colors(&mut h); @@ -802,7 +802,7 @@ impl, V: AsHashTree + 'static> RbTree { } #[inline] - fn min, V: AsHashTree + 'static>( + fn min<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>( mut h: &mut Box>, ) -> &mut Box> { while h.left.is_some() { @@ -811,7 +811,7 @@ impl, V: AsHashTree + 'static> RbTree { h } - fn delete_min, V: AsHashTree + 'static>( + fn delete_min<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>( mut h: Box>, ) -> NodeRef { if h.left.is_none() { @@ -827,7 +827,7 @@ impl, V: AsHashTree + 'static> RbTree { Some(balance(h)) } - fn go, V: AsHashTree + 'static>( + fn go<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>( mut h: Box>, key: &[u8], ) -> NodeRef { @@ -888,6 +888,94 @@ impl, V: AsHashTree + 'static> RbTree { } } +use candid::CandidType; + +impl<'t, K, V> CandidType for RbTree +where + K: CandidType + AsRef<[u8]> + 't, + V: CandidType + AsHashTree + 't, +{ + fn _ty() -> candid::types::internal::Type { + as CandidType>::_ty() + } + fn idl_serialize(&self, serializer: S) -> Result<(), S::Error> { + let collect_as_vec = self.iter().collect::>(); + as CandidType>::idl_serialize(&collect_as_vec, serializer) + } +} + +use serde::{ + de::{Deserialize, Deserializer, MapAccess, Visitor}, + ser::{Serialize, SerializeMap, Serializer}, +}; +use std::marker::PhantomData; + +impl<'t, K, V> Serialize for RbTree +where + K: Serialize + AsRef<[u8]> + 't, + V: Serialize + AsHashTree + 't, +{ + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut map = serializer.serialize_map(Some(self.iter().count()))?; + for (k, v) in self.iter() { + map.serialize_entry(k, v)?; + } + map.end() + } +} + +// The PhantomData keeps the compiler from complaining about unused generic type parameters. +struct RbTreeSerdeVisitor { + marker: PhantomData RbTree>, +} + +impl RbTreeSerdeVisitor { + fn new() -> Self { + RbTreeSerdeVisitor { + marker: PhantomData, + } + } +} + +impl<'de, 't, K, V> Visitor<'de> for RbTreeSerdeVisitor +where + K: Deserialize<'de> + AsRef<[u8]> + 't, + V: Deserialize<'de> + AsHashTree + 't, +{ + type Value = RbTree; + + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + formatter.write_str("a map") + } + + fn visit_map(self, mut access: M) -> Result + where + M: MapAccess<'de>, + { + let mut t = RbTree::::new(); + while let Some((key, value)) = access.next_entry()? { + t.insert(key, value); + } + Ok(t) + } +} + +impl<'de, 't, K, V> Deserialize<'de> for RbTree +where + K: Deserialize<'de> + AsRef<[u8]> + 't, + V: Deserialize<'de> + AsHashTree + 't, +{ + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_map(RbTreeSerdeVisitor::new()) + } +} + fn three_way_fork<'a>(l: HashTree<'a>, m: HashTree<'a>, r: HashTree<'a>) -> HashTree<'a> { match (l, m, r) { (Empty, m, Empty) => m, @@ -906,9 +994,7 @@ fn is_red(x: &NodeRef) -> bool { x.as_ref().map(|h| h.color == Color::Red).unwrap_or(false) } -fn balance + 'static, V: AsHashTree + 'static>( - mut h: Box>, -) -> Box> { +fn balance<'t, K: AsRef<[u8]> + 't, V: AsHashTree + 't>(mut h: Box>) -> Box> { if is_red(&h.right) && !is_red(&h.left) { h = rotate_left(h); } @@ -922,7 +1008,7 @@ fn balance + 'static, V: AsHashTree + 'static>( } /// Make a left-leaning link lean to the right. -fn rotate_right, V: AsHashTree + 'static>( +fn rotate_right<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>( mut h: Box>, ) -> Box> { debug_assert!(is_red(&h.left)); @@ -939,7 +1025,7 @@ fn rotate_right, V: AsHashTree + 'static>( x } -fn rotate_left, V: AsHashTree + 'static>( +fn rotate_left<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>( mut h: Box>, ) -> Box> { debug_assert!(is_red(&h.right)); diff --git a/library/ic-certified-map/src/rbtree/test.rs b/library/ic-certified-map/src/rbtree/test.rs index 206906757..40fa6a13f 100644 --- a/library/ic-certified-map/src/rbtree/test.rs +++ b/library/ic-certified-map/src/rbtree/test.rs @@ -375,3 +375,26 @@ fn test_ordering() { assert_eq!(t1.cmp(&t3), Greater); assert_eq!(t1.cmp(&t4), Less); } + +#[test] +fn test_serde_serialize_and_deserialize() { + type Tree<'a> = RbTree<&'a str, Hash>; + let t1: Tree<'_> = Tree::from_iter([("hi", [1; 32]), ("hello", [2; 32]), ("world", [3; 32])]); + + // cbor test + let mut b: Vec = Vec::new(); + serde_cbor::to_writer(&mut b, &t1).unwrap(); + let t2: Tree<'_> = serde_cbor::from_slice(&b[..]).unwrap(); + assert_eq!(t1, t2); + + // bincode test + use bincode::Options; + let b = bincode::options().serialize(&t1).unwrap(); + let t3: Tree<'_> = bincode::options().deserialize(&b).unwrap(); + assert_eq!(t1, t3); + + // candid test + let b = candid::encode_one(&t1).unwrap(); + let t4: Tree<'_> = candid::decode_one(&b).unwrap(); + assert_eq!(t1, t4); +}