Skip to content

Commit

Permalink
impl high-level api & iterators
Browse files Browse the repository at this point in the history
  • Loading branch information
ielm committed Mar 9, 2024
1 parent fd2aa98 commit 2f43690
Show file tree
Hide file tree
Showing 2 changed files with 275 additions and 3 deletions.
17 changes: 15 additions & 2 deletions src/avl_tree/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ pub type OptNode<K, V> = Option<NonNull<Node<K, V>>>;
/// It is of type `i32`.
#[derive(Debug)]
pub struct Node<K: Ord, V> {
key: K,
value: V,
pub key: K,
pub value: V,
parent: OptNode<K, V>,
left: OptNode<K, V>,
right: OptNode<K, V>,
Expand All @@ -90,6 +90,19 @@ impl<K: Ord, V> Node<K, V> {
}
}

#[inline]
// pub fn into_element(n: Box<Node<K, V>>) -> (K, V) {
pub fn into_element(node: OptNode<K, V>) -> (K, V) {
let n = node.unwrap();

unsafe {
(
std::ptr::read(&(*n.as_ptr()).key),
std::ptr::read(&(*n.as_ptr()).value),
)
}
}

#[inline]
pub fn get_value(node: OptNode<K, V>) -> Option<V> {
if node.is_none() {
Expand Down
261 changes: 260 additions & 1 deletion src/avl_tree/tree.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::constants::BALANCE_THRESHOLD;
use super::node::{Node, OptNode};
use std::cmp::Ordering;
use std::collections::VecDeque;
use std::collections::{HashSet, VecDeque};
use std::marker::PhantomData;
use std::ptr::NonNull;

Expand Down Expand Up @@ -32,6 +32,114 @@ pub struct AvlTree<K: Ord, V> {
_marker: PhantomData<Box<Node<K, V>>>,
}

impl<K: Ord, V> AvlTree<K, V> {
pub fn new() -> Self {
AvlTree {
root: None,
len: 0,
_marker: PhantomData,
}
}

pub fn insert(&mut self, k: K, v: V) {
self._insert_kv(k, v);
}

pub fn get(&self, k: &K) -> Option<V> {
self._get_node(k).and_then(|n| Node::get_value(Some(n)))
}

pub fn get_mut(&mut self, k: &K) -> Option<&'static mut V> {
self._get_mut_value(k)
}

pub fn remove(&mut self, k: &K) -> Option<V> {
self._remove_node(k).and_then(|n| Node::get_value(Some(n)))
}

pub fn contains(&self, k: &K) -> bool {
self._get_node(k).is_some()
}

pub fn peek_root(&self) -> Option<(K, V)> {
self.root.map(|n| unsafe {
let node = Box::from_raw(n.as_ptr());
(node.key, node.value)
})
}

pub fn pop_max(&mut self) -> Option<(K, V)> {
self._pop_max().map(|n| unsafe {
let node = Box::from_raw(n.as_ptr());
(node.key, node.value)
})
}

pub fn pop_max_boxed(&mut self) -> Option<Box<Node<K, V>>> {
self._pop_max_boxed()
}

pub fn pop_min(&mut self) -> Option<(K, V)> {
self._pop_min().map(|n| unsafe {
let node = Box::from_raw(n.as_ptr());
(node.key, node.value)
})
}

pub fn pop_min_boxed(&mut self) -> Option<Box<Node<K, V>>> {
self._pop_min_boxed()
}

pub fn iter(&self) -> Iter<'_, K, V> {
Iter {
next_nodes: vec![self.root],
seen: HashSet::new(),
next_back_nodes: vec![self.root],
seen_back: HashSet::new(),
_marker: PhantomData,
}
}

pub fn len(&self) -> isize {
self.len
}

pub fn is_empty(&self) -> bool {
self.len == 0
}

pub fn is_balanced(&self) -> bool {
self._is_balanced()
}

pub fn clear(&mut self) {
*self = Self::new();
}
}

impl<K: Ord, V> Default for AvlTree<K, V> {
fn default() -> Self {
Self::new()
}
}

impl<K: Ord, V> Drop for AvlTree<K, V> {
fn drop(&mut self) {
struct DropGuard<'a, K: Ord, V>(&'a mut AvlTree<K, V>);
impl<'a, K: Ord, V> Drop for DropGuard<'a, K, V> {
fn drop(&mut self) {
self.0.clear();
}
}

while let Some(b) = self._pop_min_boxed() {
let guard = DropGuard(self);
drop(b);
std::mem::forget(guard);
}
}
}

impl<K: Ord, V> AvlTree<K, V> {
/// Inserts a key-value pair into the AVL tree.
///
Expand Down Expand Up @@ -679,3 +787,154 @@ impl<K: Ord, V> AvlTree<K, V> {
}
}
}

impl<K: Ord, V> IntoIterator for AvlTree<K, V> {
type Item = (K, V);
type IntoIter = IntoIter<K, V>;
fn into_iter(self) -> Self::IntoIter {
IntoIter(self)
}
}

impl<K: Ord, V> FromIterator<(K, V)> for AvlTree<K, V> {
fn from_iter<T: IntoIterator<Item = (K, V)>>(iter: T) -> Self {
let mut tree = AvlTree::new();
for (k, v) in iter {
tree.insert(k, v);
}
tree
}
}

unsafe impl<K: Ord + Send, V: Send> Send for AvlTree<K, V> {}

unsafe impl<K: Ord + Sync, V: Sync> Sync for AvlTree<K, V> {}

pub struct IntoIter<K: Ord, V>(AvlTree<K, V>);

impl<K: Ord, V> Iterator for IntoIter<K, V> {
type Item = (K, V);
fn next(&mut self) -> Option<Self::Item> {
self.0._pop_min().map(|n| Node::into_element(Some(n)))
}
}

impl<K: Ord, V> DoubleEndedIterator for IntoIter<K, V> {
fn next_back(&mut self) -> Option<Self::Item> {
self.0._pop_max().map(|n| Node::into_element(Some(n)))
}
}

impl<K: Ord, V> Drop for IntoIter<K, V> {
fn drop(&mut self) {
struct DropGuard<'a, K: Ord, V>(&'a mut IntoIter<K, V>);

impl<'a, K: Ord, V> Drop for DropGuard<'a, K, V> {
fn drop(&mut self) {
for _ in self.0.by_ref() {}
}
}

while let Some(d) = self.next() {
let guard = DropGuard(self);
drop(d);
std::mem::forget(guard);
}
}
}

pub struct Iter<'a, K: Ord, V> {
next_nodes: Vec<OptNode<K, V>>,
seen: HashSet<NonNull<Node<K, V>>>,
next_back_nodes: Vec<OptNode<K, V>>,
seen_back: HashSet<NonNull<Node<K, V>>>,
_marker: PhantomData<&'a Node<K, V>>,
}

impl<'a, K: Ord, V> Iter<'a, K, V> {
fn next_ascending(&mut self) -> OptNode<K, V> {
while let Some(node) = self.next_nodes.pop() {
let left = node
.as_ref()
.and_then(|n| Node::get_left(Some(*n)))
.filter(|n| !self.seen.contains(n));
let right = node
.as_ref()
.and_then(|n| Node::get_right(Some(*n)))
.filter(|n| !self.seen.contains(n));

if left.is_some() && right.is_some() {
self.next_nodes.push(node);
self.next_nodes.push(left);
} else if left.is_some() {
self.next_nodes.push(node);
} else {
if right.is_some() {
self.next_nodes.push(right);
if let Some(n) = node {
self.seen.insert(n);
}
return node;
}
if let Some(n) = node {
self.seen.insert(n);
}
return node;
}
}
None
}

fn next_descending(&mut self) -> OptNode<K, V> {
while let Some(node) = self.next_back_nodes.pop() {
let left = node
.as_ref()
.and_then(|n| Node::get_left(Some(*n)))
.filter(|n| !self.seen_back.contains(n));
let right = node
.as_ref()
.and_then(|n| Node::get_right(Some(*n)))
.filter(|n| !self.seen_back.contains(n));

if left.is_some() && right.is_some() {
self.next_back_nodes.push(node);
self.next_back_nodes.push(right);
} else {
if left.is_some() {
self.next_back_nodes.push(left);
if let Some(n) = node {
self.seen_back.insert(n);
}
return node;
}
if right.is_some() {
self.next_back_nodes.push(node);
self.next_back_nodes.push(right);
} else {
if let Some(n) = node {
self.seen.insert(n);
}
return node;
}
}
}
None
}
}

impl<'a, K: Ord, V> Iterator for Iter<'a, K, V> {
type Item = (&'a K, &'a V);
fn next(&mut self) -> Option<Self::Item> {
self.next_ascending()
.as_ref()
.map(|n| unsafe { (&(*n.as_ptr()).key, &(*n.as_ptr()).value) })
}
}

impl<'a, K: Ord, V> DoubleEndedIterator for Iter<'a, K, V> {
fn next_back(&mut self) -> Option<Self::Item> {
self.next_descending()
.as_ref()
.map(|n| unsafe { (&(*n.as_ptr()).key, &(*n.as_ptr()).value) })
}
}

0 comments on commit 2f43690

Please sign in to comment.