From 2f43690ccd269296d2f5ab7c219081dd77d9ae27 Mon Sep 17 00:00:00 2001 From: ielm Date: Sat, 9 Mar 2024 18:44:57 -0500 Subject: [PATCH] impl high-level api & iterators --- src/avl_tree/node.rs | 17 ++- src/avl_tree/tree.rs | 261 ++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 275 insertions(+), 3 deletions(-) diff --git a/src/avl_tree/node.rs b/src/avl_tree/node.rs index 71b89d5..e55bc18 100644 --- a/src/avl_tree/node.rs +++ b/src/avl_tree/node.rs @@ -70,8 +70,8 @@ pub type OptNode = Option>>; /// It is of type `i32`. #[derive(Debug)] pub struct Node { - key: K, - value: V, + pub key: K, + pub value: V, parent: OptNode, left: OptNode, right: OptNode, @@ -90,6 +90,19 @@ impl Node { } } + #[inline] + // pub fn into_element(n: Box>) -> (K, V) { + pub fn into_element(node: OptNode) -> (K, V) { + let n = node.unwrap(); + + unsafe { + ( + std::ptr::read(&(*n.as_ptr()).key), + std::ptr::read(&(*n.as_ptr()).value), + ) + } + } + #[inline] pub fn get_value(node: OptNode) -> Option { if node.is_none() { diff --git a/src/avl_tree/tree.rs b/src/avl_tree/tree.rs index 074dd5c..a3fdc4f 100644 --- a/src/avl_tree/tree.rs +++ b/src/avl_tree/tree.rs @@ -1,7 +1,7 @@ use super::constants::BALANCE_THRESHOLD; use super::node::{Node, OptNode}; use std::cmp::Ordering; -use std::collections::VecDeque; +use std::collections::{HashSet, VecDeque}; use std::marker::PhantomData; use std::ptr::NonNull; @@ -32,6 +32,114 @@ pub struct AvlTree { _marker: PhantomData>>, } +impl AvlTree { + pub fn new() -> Self { + AvlTree { + root: None, + len: 0, + _marker: PhantomData, + } + } + + pub fn insert(&mut self, k: K, v: V) { + self._insert_kv(k, v); + } + + pub fn get(&self, k: &K) -> Option { + self._get_node(k).and_then(|n| Node::get_value(Some(n))) + } + + pub fn get_mut(&mut self, k: &K) -> Option<&'static mut V> { + self._get_mut_value(k) + } + + pub fn remove(&mut self, k: &K) -> Option { + self._remove_node(k).and_then(|n| Node::get_value(Some(n))) + } + + pub fn contains(&self, k: &K) -> bool { + self._get_node(k).is_some() + } + + pub fn peek_root(&self) -> Option<(K, V)> { + self.root.map(|n| unsafe { + let node = Box::from_raw(n.as_ptr()); + (node.key, node.value) + }) + } + + pub fn pop_max(&mut self) -> Option<(K, V)> { + self._pop_max().map(|n| unsafe { + let node = Box::from_raw(n.as_ptr()); + (node.key, node.value) + }) + } + + pub fn pop_max_boxed(&mut self) -> Option>> { + self._pop_max_boxed() + } + + pub fn pop_min(&mut self) -> Option<(K, V)> { + self._pop_min().map(|n| unsafe { + let node = Box::from_raw(n.as_ptr()); + (node.key, node.value) + }) + } + + pub fn pop_min_boxed(&mut self) -> Option>> { + self._pop_min_boxed() + } + + pub fn iter(&self) -> Iter<'_, K, V> { + Iter { + next_nodes: vec![self.root], + seen: HashSet::new(), + next_back_nodes: vec![self.root], + seen_back: HashSet::new(), + _marker: PhantomData, + } + } + + pub fn len(&self) -> isize { + self.len + } + + pub fn is_empty(&self) -> bool { + self.len == 0 + } + + pub fn is_balanced(&self) -> bool { + self._is_balanced() + } + + pub fn clear(&mut self) { + *self = Self::new(); + } +} + +impl Default for AvlTree { + fn default() -> Self { + Self::new() + } +} + +impl Drop for AvlTree { + fn drop(&mut self) { + struct DropGuard<'a, K: Ord, V>(&'a mut AvlTree); + impl<'a, K: Ord, V> Drop for DropGuard<'a, K, V> { + fn drop(&mut self) { + self.0.clear(); + } + } + + while let Some(b) = self._pop_min_boxed() { + let guard = DropGuard(self); + drop(b); + std::mem::forget(guard); + } + } +} + impl AvlTree { /// Inserts a key-value pair into the AVL tree. /// @@ -679,3 +787,154 @@ impl AvlTree { } } } + +impl IntoIterator for AvlTree { + type Item = (K, V); + type IntoIter = IntoIter; + fn into_iter(self) -> Self::IntoIter { + IntoIter(self) + } +} + +impl FromIterator<(K, V)> for AvlTree { + fn from_iter>(iter: T) -> Self { + let mut tree = AvlTree::new(); + for (k, v) in iter { + tree.insert(k, v); + } + tree + } +} + +unsafe impl Send for AvlTree {} + +unsafe impl Sync for AvlTree {} + +pub struct IntoIter(AvlTree); + +impl Iterator for IntoIter { + type Item = (K, V); + fn next(&mut self) -> Option { + self.0._pop_min().map(|n| Node::into_element(Some(n))) + } +} + +impl DoubleEndedIterator for IntoIter { + fn next_back(&mut self) -> Option { + self.0._pop_max().map(|n| Node::into_element(Some(n))) + } +} + +impl Drop for IntoIter { + fn drop(&mut self) { + struct DropGuard<'a, K: Ord, V>(&'a mut IntoIter); + + impl<'a, K: Ord, V> Drop for DropGuard<'a, K, V> { + fn drop(&mut self) { + for _ in self.0.by_ref() {} + } + } + + while let Some(d) = self.next() { + let guard = DropGuard(self); + drop(d); + std::mem::forget(guard); + } + } +} + +pub struct Iter<'a, K: Ord, V> { + next_nodes: Vec>, + seen: HashSet>>, + next_back_nodes: Vec>, + seen_back: HashSet>>, + _marker: PhantomData<&'a Node>, +} + +impl<'a, K: Ord, V> Iter<'a, K, V> { + fn next_ascending(&mut self) -> OptNode { + while let Some(node) = self.next_nodes.pop() { + let left = node + .as_ref() + .and_then(|n| Node::get_left(Some(*n))) + .filter(|n| !self.seen.contains(n)); + let right = node + .as_ref() + .and_then(|n| Node::get_right(Some(*n))) + .filter(|n| !self.seen.contains(n)); + + if left.is_some() && right.is_some() { + self.next_nodes.push(node); + self.next_nodes.push(left); + } else if left.is_some() { + self.next_nodes.push(node); + } else { + if right.is_some() { + self.next_nodes.push(right); + if let Some(n) = node { + self.seen.insert(n); + } + return node; + } + if let Some(n) = node { + self.seen.insert(n); + } + return node; + } + } + None + } + + fn next_descending(&mut self) -> OptNode { + while let Some(node) = self.next_back_nodes.pop() { + let left = node + .as_ref() + .and_then(|n| Node::get_left(Some(*n))) + .filter(|n| !self.seen_back.contains(n)); + let right = node + .as_ref() + .and_then(|n| Node::get_right(Some(*n))) + .filter(|n| !self.seen_back.contains(n)); + + if left.is_some() && right.is_some() { + self.next_back_nodes.push(node); + self.next_back_nodes.push(right); + } else { + if left.is_some() { + self.next_back_nodes.push(left); + if let Some(n) = node { + self.seen_back.insert(n); + } + return node; + } + if right.is_some() { + self.next_back_nodes.push(node); + self.next_back_nodes.push(right); + } else { + if let Some(n) = node { + self.seen.insert(n); + } + return node; + } + } + } + None + } +} + +impl<'a, K: Ord, V> Iterator for Iter<'a, K, V> { + type Item = (&'a K, &'a V); + fn next(&mut self) -> Option { + self.next_ascending() + .as_ref() + .map(|n| unsafe { (&(*n.as_ptr()).key, &(*n.as_ptr()).value) }) + } +} + +impl<'a, K: Ord, V> DoubleEndedIterator for Iter<'a, K, V> { + fn next_back(&mut self) -> Option { + self.next_descending() + .as_ref() + .map(|n| unsafe { (&(*n.as_ptr()).key, &(*n.as_ptr()).value) }) + } +}