Skip to content

Commit

Permalink
remove Encoded type (#568)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dan Laine authored Feb 28, 2024
1 parent 983765c commit 3530acc
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 114 deletions.
98 changes: 24 additions & 74 deletions firewood/src/merkle/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ use bitflags::bitflags;
use bytemuck::{CheckedBitPattern, NoUninit, Pod, Zeroable};
use enum_as_inner::EnumAsInner;
use serde::{
de::DeserializeOwned,
ser::{SerializeSeq, SerializeTuple},
Deserialize, Serialize,
};
Expand All @@ -25,6 +24,7 @@ use std::{
atomic::{AtomicBool, Ordering},
OnceLock,
},
vec,
};

mod branch;
Expand Down Expand Up @@ -69,35 +69,6 @@ impl Data {
}
}

#[derive(Serialize, Deserialize, Debug)]
enum Encoded<T> {
Raw(T),
Data(T),
}

impl Default for Encoded<Vec<u8>> {
fn default() -> Self {
// This is the default serialized empty vector
Encoded::Data(vec![0])
}
}

impl<T: DeserializeOwned + AsRef<[u8]>> Encoded<T> {
pub fn decode(self) -> Result<T, bincode::Error> {
match self {
Encoded::Raw(raw) => Ok(raw),
Encoded::Data(data) => bincode::DefaultOptions::new().deserialize(data.as_ref()),
}
}

pub fn deserialize<De: BinarySerde>(self) -> Result<T, De::DeserializeError> {
match self {
Encoded::Raw(raw) => Ok(raw),
Encoded::Data(data) => De::deserialize(data.as_ref()),
}
}
}

#[derive(PartialEq, Eq, Clone, Debug, EnumAsInner)]
pub enum NodeType {
Branch(Box<BranchNode>),
Expand All @@ -106,22 +77,22 @@ pub enum NodeType {

impl NodeType {
pub fn decode(buf: &[u8]) -> Result<NodeType, Error> {
let items: Vec<Encoded<Vec<u8>>> = bincode::DefaultOptions::new().deserialize(buf)?;
let items: Vec<Vec<u8>> = bincode::DefaultOptions::new().deserialize(buf)?;

match items.len() {
LEAF_NODE_SIZE => {
let mut items = items.into_iter();

#[allow(clippy::unwrap_used)]
let decoded_key: Vec<u8> = items.next().unwrap().decode()?;
let decoded_key: Vec<u8> = items.next().unwrap();

let decoded_key_nibbles = Nibbles::<0>::new(&decoded_key);

let cur_key_path = PartialPath::from_nibbles(decoded_key_nibbles.into_iter()).0;

let cur_key = cur_key_path.into_inner();
#[allow(clippy::unwrap_used)]
let data: Vec<u8> = items.next().unwrap().decode()?;
let data: Vec<u8> = items.next().unwrap();

Ok(NodeType::Leaf(LeafNode::new(cur_key, data)))
}
Expand Down Expand Up @@ -633,13 +604,11 @@ impl<'de> Deserialize<'de> for EncodedNode<PlainCodec> {
// Note that the serializer passed in should always be the same type as T in EncodedNode<T>.
impl Serialize for EncodedNode<Bincode> {
fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
use serde::ser::Error;

match &self.node {
EncodedNodeType::Leaf(n) => {
let list = [
Encoded::Raw(from_nibbles(&n.path.encode(true)).collect()),
Encoded::Raw(n.data.to_vec()),
from_nibbles(&n.path.encode(true)).collect(),
n.data.to_vec(),
];
let mut seq = serializer.serialize_seq(Some(list.len()))?;
for e in list {
Expand All @@ -653,7 +622,7 @@ impl Serialize for EncodedNode<Bincode> {
children,
value,
} => {
let mut list = <[Encoded<Vec<u8>>; BranchNode::MAX_CHILDREN + 2]>::default();
let mut list = <[Vec<u8>; BranchNode::MAX_CHILDREN + 2]>::default();
let children = children
.iter()
.enumerate()
Expand All @@ -662,27 +631,19 @@ impl Serialize for EncodedNode<Bincode> {
#[allow(clippy::indexing_slicing)]
for (i, child) in children {
if child.len() >= TRIE_HASH_LEN {
let serialized_hash =
Bincode::serialize(&Keccak256::digest(child).to_vec())
.map_err(|e| S::Error::custom(format!("bincode error: {e}")))?;
list[i] = Encoded::Data(serialized_hash);
let serialized_hash = Keccak256::digest(child).to_vec();
list[i] = serialized_hash;
} else {
list[i] = Encoded::Raw(child.to_vec());
list[i] = child.to_vec();
}
}

list[BranchNode::MAX_CHILDREN] = if let Some(Data(val)) = &value {
let serialized_val = Bincode::serialize(val)
.map_err(|e| S::Error::custom(format!("bincode error: {e}")))?;

Encoded::Data(serialized_val)
} else {
Encoded::default()
};
if let Some(Data(val)) = &value {
list[BranchNode::MAX_CHILDREN] = val.clone();
}

let serialized_path = from_nibbles(&path.encode(true)).collect();

list[BranchNode::MAX_CHILDREN + 1] = Encoded::Raw(serialized_path);
list[BranchNode::MAX_CHILDREN + 1] = serialized_path;

let mut seq = serializer.serialize_seq(Some(list.len()))?;

Expand All @@ -703,18 +664,18 @@ impl<'de> Deserialize<'de> for EncodedNode<Bincode> {
{
use serde::de::Error;

let mut items: Vec<Encoded<Vec<u8>>> = Deserialize::deserialize(deserializer)?;
let mut items: Vec<Vec<u8>> = Deserialize::deserialize(deserializer)?;
let len = items.len();

match len {
LEAF_NODE_SIZE => {
let mut items = items.into_iter();
let Some(Encoded::Raw(path)) = items.next() else {
let Some(path) = items.next() else {
return Err(D::Error::custom(
"incorrect encoded type for leaf node path",
));
};
let Some(Encoded::Raw(data)) = items.next() else {
let Some(data) = items.next() else {
return Err(D::Error::custom(
"incorrect encoded type for leaf node data",
));
Expand All @@ -728,30 +689,19 @@ impl<'de> Deserialize<'de> for EncodedNode<Bincode> {
}

BranchNode::MSIZE => {
let path = items
.pop()
.unwrap_or_default()
.deserialize::<Bincode>()
.map_err(D::Error::custom)?;
let path = items.pop().expect("length was checked above");
let path = PartialPath::from_nibbles(Nibbles::<0>::new(&path).into_iter()).0;

let value = items
.pop()
.unwrap_or_default()
.deserialize::<Bincode>()
.map_err(D::Error::custom)
.map(Data)
.map(Some)?
.filter(|data| !data.is_empty());
let value = items.pop().expect("length was checked above");
let value = if value.is_empty() {
None
} else {
Some(Data(value))
};

let mut children: [Option<Vec<u8>>; BranchNode::MAX_CHILDREN] = Default::default();

for (i, chd) in items.into_iter().enumerate() {
let chd = match chd {
Encoded::Raw(chd) => chd,
Encoded::Data(chd) => Bincode::deserialize(chd.as_ref())
.map_err(|e| D::Error::custom(format!("bincode error: {e}")))?,
};
#[allow(clippy::indexing_slicing)]
(children[i] = Some(chd).filter(|chd| !chd.is_empty()));
}
Expand Down
48 changes: 14 additions & 34 deletions firewood/src/merkle/node/branch.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
// Copyright (C) 2023, Ava Labs, Inc. All rights reserved.
// See the file LICENSE.md for licensing terms.

use super::{Data, Encoded, Node};
use super::{Data, Node};
use crate::{
merkle::{from_nibbles, to_nibble_array, PartialPath, TRIE_HASH_LEN},
merkle::{from_nibbles, to_nibble_array, PartialPath},
nibbles::Nibbles,
shale::{DiskAddress, ShaleError, ShaleStore, Storable},
};
Expand Down Expand Up @@ -96,18 +96,15 @@ impl BranchNode {
}

pub(super) fn decode(buf: &[u8]) -> Result<Self, Error> {
let mut items: Vec<Encoded<Vec<u8>>> = bincode::DefaultOptions::new().deserialize(buf)?;
let mut items: Vec<Vec<u8>> = bincode::DefaultOptions::new().deserialize(buf)?;

let path = items
.pop()
.ok_or(Error::custom("Invalid Branch Node"))?
.decode()?;
let path = items.pop().ok_or(Error::custom("Invalid Branch Node"))?;
let path = Nibbles::<0>::new(&path);
let (path, _term) = PartialPath::from_nibbles(path.into_iter());

// we've already validated the size, that's why we can safely unwrap
#[allow(clippy::unwrap_used)]
let data = items.pop().unwrap().decode()?;
let data = items.pop().unwrap();
// Extract the value of the branch node and set to None if it's an empty Vec
let value = Some(data).filter(|data| !data.is_empty());

Expand All @@ -116,9 +113,8 @@ impl BranchNode {

// we popped the last element, so their should only be NBRANCH items left
for (i, chd) in items.into_iter().enumerate() {
let data = chd.decode()?;
#[allow(clippy::indexing_slicing)]
(chd_encoded[i] = Some(data).filter(|data| !data.is_empty()));
(chd_encoded[i] = Some(chd).filter(|data| !data.is_empty()));
}

Ok(BranchNode::new(
Expand All @@ -131,7 +127,7 @@ impl BranchNode {

pub(super) fn encode<S: ShaleStore<Node>>(&self, store: &S) -> Vec<u8> {
// path + children + value
let mut list = <[Encoded<Vec<u8>>; Self::MSIZE]>::default();
let mut list = <[Vec<u8>; Self::MSIZE]>::default();

for (i, c) in self.children.iter().enumerate() {
match c {
Expand All @@ -142,21 +138,17 @@ impl BranchNode {
#[allow(clippy::unwrap_used)]
if c_ref.is_encoded_longer_than_hash_len::<S>(store) {
#[allow(clippy::indexing_slicing)]
(list[i] = Encoded::Data(
bincode::DefaultOptions::new()
.serialize(&&(*c_ref.get_root_hash::<S>(store))[..])
.unwrap(),
));
(list[i] = c_ref.get_root_hash::<S>(store).to_vec());

// See struct docs for ordering requirements
if c_ref.is_dirty() {
c_ref.write(|_| {}).unwrap();
c_ref.set_dirty(false);
}
} else {
let child_encoded = &c_ref.get_encoded::<S>(store);
let child_encoded = c_ref.get_encoded::<S>(store);
#[allow(clippy::indexing_slicing)]
(list[i] = Encoded::Raw(child_encoded.to_vec()));
(list[i] = child_encoded.to_vec());
}
}

Expand All @@ -171,34 +163,22 @@ impl BranchNode {
// can happen when manually constructing a trie from proof.
#[allow(clippy::indexing_slicing)]
if let Some(v) = &self.children_encoded[i] {
if v.len() == TRIE_HASH_LEN {
#[allow(clippy::indexing_slicing, clippy::unwrap_used)]
(list[i] = Encoded::Data(
bincode::DefaultOptions::new().serialize(v).unwrap(),
));
} else {
#[allow(clippy::indexing_slicing)]
(list[i] = Encoded::Raw(v.clone()));
}
#[allow(clippy::indexing_slicing)]
(list[i] = v.clone());
}
}
};
}

#[allow(clippy::unwrap_used)]
if let Some(Data(val)) = &self.value {
list[Self::MAX_CHILDREN] =
Encoded::Data(bincode::DefaultOptions::new().serialize(val).unwrap());
list[Self::MAX_CHILDREN] = val.clone();
}

#[allow(clippy::unwrap_used)]
let path = from_nibbles(&self.path.encode(false)).collect::<Vec<_>>();

list[Self::MAX_CHILDREN + 1] = Encoded::Data(
bincode::DefaultOptions::new()
.serialize(&path)
.expect("serializing raw bytes to always succeed"),
);
list[Self::MAX_CHILDREN + 1] = path;

bincode::DefaultOptions::new()
.serialize(list.as_slice())
Expand Down
11 changes: 5 additions & 6 deletions firewood/src/merkle/node/leaf.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (C) 2023, Ava Labs, Inc. All rights reserved.
// See the file LICENSE.md for licensing terms.

use super::{Data, Encoded};
use super::Data;
use crate::{
merkle::{from_nibbles, PartialPath},
nibbles::Nibbles,
Expand Down Expand Up @@ -53,8 +53,8 @@ impl LeafNode {
bincode::DefaultOptions::new()
.serialize(
[
Encoded::Raw(from_nibbles(&self.path.encode(true)).collect()),
Encoded::Raw(self.data.to_vec()),
from_nibbles(&self.path.encode(true)).collect(),
self.data.to_vec(),
]
.as_slice(),
)
Expand Down Expand Up @@ -156,9 +156,8 @@ mod tests {
let data = vec![5, 6, 7, 8];

let serialized_path = [vec![prefix], path.clone()].concat();
// 0 represents Encoded::Raw
let serialized_path = [vec![0, serialized_path.len() as u8], serialized_path].concat();
let serialized_data = [vec![0, data.len() as u8], data.clone()].concat();
let serialized_path = [vec![serialized_path.len() as u8], serialized_path].concat();
let serialized_data = [vec![data.len() as u8], data.clone()].concat();

let serialized = [vec![2], serialized_path, serialized_data].concat();

Expand Down

0 comments on commit 3530acc

Please sign in to comment.