diff --git a/.github/workflows/psql.yml b/.github/workflows/psql.yml index 42c108c..4bd9e90 100644 --- a/.github/workflows/psql.yml +++ b/.github/workflows/psql.yml @@ -4,6 +4,7 @@ on: pull_request: paths: - '.github/workflows/psql.yml' + - 'crates/**' - 'src/**' - 'Cargo.lock' - 'Cargo.toml' @@ -16,6 +17,7 @@ on: - main paths: - '.github/workflows/psql.yml' + - 'crates/**' - 'src/**' - 'Cargo.lock' - 'Cargo.toml' diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 8b1b4e2..f026d69 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -4,6 +4,7 @@ on: pull_request: paths: - '.github/workflows/rust.yml' + - 'crates/**' - 'src/**' - 'Cargo.lock' - 'Cargo.toml' @@ -14,6 +15,7 @@ on: - main paths: - '.github/workflows/rust.yml' + - 'crates/**' - 'src/**' - 'Cargo.lock' - 'Cargo.toml' diff --git a/crates/algorithm/src/build.rs b/crates/algorithm/src/build.rs index 3db4e27..d21ab9d 100644 --- a/crates/algorithm/src/build.rs +++ b/crates/algorithm/src/build.rs @@ -81,6 +81,7 @@ pub fn build( factor_err: code.factor_err, signs: code.signs, first: pointer_of_firsts[i - 1][child as usize], + size: structures[i - 1].len() as _, }); } let tape = tape.into_inner(); @@ -96,6 +97,7 @@ pub fn build( vectors_first: vectors.first(), root_mean: pointer_of_means.last().unwrap()[0], root_first: pointer_of_firsts.last().unwrap()[0], + root_size: structures.last().unwrap().len() as _, freepage_first: freepage.first(), }); } diff --git a/crates/algorithm/src/bulkdelete.rs b/crates/algorithm/src/bulkdelete.rs index 524ec90..9d1c0bc 100644 --- a/crates/algorithm/src/bulkdelete.rs +++ b/crates/algorithm/src/bulkdelete.rs @@ -16,12 +16,14 @@ pub fn bulkdelete( let vectors_first = meta_tuple.vectors_first(); drop(meta_guard); { - type State = Vec; - let mut state: State = vec![root_first]; - let step = |state: State| { + struct State { + first: u32, + } + let mut states: Vec = vec![State { first: root_first }]; + let step = |states: Vec| { let mut results = Vec::new(); - for first in state { - let mut current = first; + for state in states { + let mut current = state.first; while current != u32::MAX { let h1_guard = index.read(current); for i in 1..=h1_guard.len() { @@ -32,7 +34,7 @@ pub fn bulkdelete( match h1_tuple { H1TupleReader::_0(h1_tuple) => { for first in h1_tuple.first().iter().copied() { - results.push(first); + results.push(State { first }); } } H1TupleReader::_1(_) => (), @@ -44,10 +46,10 @@ pub fn bulkdelete( results }; for _ in (1..height_of_root).rev() { - state = step(state); + states = step(states); } - for first in state { - let jump_guard = index.read(first); + for state in states { + let jump_guard = index.read(state.first); let jump_tuple = jump_guard .get(1) .expect("data corruption") diff --git a/crates/algorithm/src/insert.rs b/crates/algorithm/src/insert.rs index 433e600..37a3a58 100644 --- a/crates/algorithm/src/insert.rs +++ b/crates/algorithm/src/insert.rs @@ -20,6 +20,7 @@ pub fn insert(index: impl RelationWrite, payload: NonZeroU64, vecto assert_eq!(dims, vector.as_borrowed().dims(), "unmatched dimensions"); let root_mean = meta_tuple.root_mean(); let root_first = meta_tuple.root_first(); + let root_size = meta_tuple.root_size(); let vectors_first = meta_tuple.vectors_first(); drop(meta_guard); @@ -31,7 +32,11 @@ pub fn insert(index: impl RelationWrite, payload: NonZeroU64, vecto let mean = vectors::append::(index.clone(), vectors_first, vector.as_borrowed(), payload); - type State = (u32, Option<::Vector>); + struct State { + first: u32, + residual: Option, + size: u32, + } let mut state: State = { let mean = root_mean; if is_residual { @@ -43,39 +48,52 @@ pub fn insert(index: impl RelationWrite, payload: NonZeroU64, vecto O::ResidualAccessor::default(), ), ); - (root_first, Some(residual_u)) + State { + residual: Some(residual_u), + first: root_first, + size: root_size, + } } else { - (root_first, None) + State { + residual: None, + first: root_first, + size: root_size, + } } }; let step = |state: State| { - let mut results = Vec::new(); + let mut results = Vec::with_capacity(state.size as _); { - let (first, residual) = state; - let lut = if let Some(residual) = residual { + let lut = if let Some(residual) = state.residual { &O::Vector::compute_lut_block(residual.as_borrowed()) } else { default_lut_block.as_ref().unwrap() }; access_1( index.clone(), - first, + state.first, || { RAccess::new( (&lut.4, (lut.0, lut.1, lut.2, lut.3, 1.9f32)), O::Distance::block_accessor(), ) }, - |lowerbound, mean, first| { - results.push((Reverse(lowerbound), AlwaysEqual(mean), AlwaysEqual(first))); + |lowerbound, mean, first, size| { + results.push(( + Reverse(lowerbound), + AlwaysEqual(mean), + AlwaysEqual(first), + AlwaysEqual(size), + )); }, ); } let mut heap = BinaryHeap::from(results); - let mut cache = BinaryHeap::<(Reverse, _, _)>::new(); + let mut cache = BinaryHeap::<(Reverse, _)>::new(); { while !heap.is_empty() && heap.peek().map(|x| x.0) > cache.peek().map(|x| x.0) { - let (_, AlwaysEqual(mean), AlwaysEqual(first)) = heap.pop().unwrap(); + let (_, AlwaysEqual(mean), AlwaysEqual(first), AlwaysEqual(size)) = + heap.pop().unwrap(); if is_residual { let (dis_u, residual_u) = vectors::access_1::( index.clone(), @@ -90,8 +108,11 @@ pub fn insert(index: impl RelationWrite, payload: NonZeroU64, vecto ); cache.push(( Reverse(dis_u), - AlwaysEqual(first), - AlwaysEqual(Some(residual_u)), + AlwaysEqual(State { + residual: Some(residual_u), + first, + size, + }), )); } else { let dis_u = vectors::access_1::( @@ -102,19 +123,25 @@ pub fn insert(index: impl RelationWrite, payload: NonZeroU64, vecto O::DistanceAccessor::default(), ), ); - cache.push((Reverse(dis_u), AlwaysEqual(first), AlwaysEqual(None))); + cache.push(( + Reverse(dis_u), + AlwaysEqual(State { + residual: None, + first, + size, + }), + )); } } - let (_, AlwaysEqual(first), AlwaysEqual(mean)) = cache.pop().unwrap(); - (first, mean) + let (_, AlwaysEqual(state)) = cache.pop().unwrap(); + state } }; for _ in (1..height_of_root).rev() { state = step(state); } - let (first, residual) = state; - let code = if let Some(residual) = residual { + let code = if let Some(residual) = state.residual { O::Vector::code(residual.as_borrowed()) } else { O::Vector::code(vector.as_borrowed()) @@ -129,7 +156,7 @@ pub fn insert(index: impl RelationWrite, payload: NonZeroU64, vecto elements: rabitq::pack_to_u64(&code.signs), }); - let jump_guard = index.read(first); + let jump_guard = index.read(state.first); let jump_tuple = jump_guard .get(1) .expect("data corruption") diff --git a/crates/algorithm/src/maintain.rs b/crates/algorithm/src/maintain.rs index fec057b..f9c95f9 100644 --- a/crates/algorithm/src/maintain.rs +++ b/crates/algorithm/src/maintain.rs @@ -14,13 +14,15 @@ pub fn maintain(index: impl RelationWrite, check: impl Fn()) { let freepage_first = meta_tuple.freepage_first(); drop(meta_guard); - let firsts = { - type State = Vec; - let mut state: State = vec![root_first]; - let step = |state: State| { + let states = { + struct State { + first: u32, + } + let mut states: Vec = vec![State { first: root_first }]; + let step = |states: Vec| { let mut results = Vec::new(); - for first in state { - let mut current = first; + for state in states { + let mut current = state.first; while current != u32::MAX { check(); let h1_guard = index.read(current); @@ -32,7 +34,7 @@ pub fn maintain(index: impl RelationWrite, check: impl Fn()) { match h1_tuple { H1TupleReader::_0(h1_tuple) => { for first in h1_tuple.first().iter().copied() { - results.push(first); + results.push(State { first }); } } H1TupleReader::_1(_) => (), @@ -44,13 +46,13 @@ pub fn maintain(index: impl RelationWrite, check: impl Fn()) { results }; for _ in (1..height_of_root).rev() { - state = step(state); + states = step(states); } - state + states }; - for first in firsts { - let mut jump_guard = index.write(first, false); + for state in states { + let mut jump_guard = index.write(state.first, false); let mut jump_tuple = jump_guard .get_mut(1) .expect("data corruption") diff --git a/crates/algorithm/src/prewarm.rs b/crates/algorithm/src/prewarm.rs index 587f752..0a867fd 100644 --- a/crates/algorithm/src/prewarm.rs +++ b/crates/algorithm/src/prewarm.rs @@ -10,6 +10,7 @@ pub fn prewarm(index: impl RelationRead, height: i32, check: impl F let height_of_root = meta_tuple.height_of_root(); let root_mean = meta_tuple.root_mean(); let root_first = meta_tuple.root_first(); + let root_size = meta_tuple.root_size(); drop(meta_guard); let mut message = String::new(); @@ -18,25 +19,27 @@ pub fn prewarm(index: impl RelationRead, height: i32, check: impl F if prewarm_max_height > height_of_root { return message; } - type State = Vec; - let mut state: State = { - let mut nodes = Vec::new(); - { - vectors::access_1::(index.clone(), root_mean, ()); - nodes.push(root_first); - } + struct State { + first: u32, + size: u32, + } + let mut states: Vec = { + vectors::access_1::(index.clone(), root_mean, ()); writeln!(message, "------------------------").unwrap(); - writeln!(message, "number of nodes: {}", nodes.len()).unwrap(); + writeln!(message, "number of nodes: {}", 1).unwrap(); writeln!(message, "number of tuples: {}", 1).unwrap(); writeln!(message, "number of pages: {}", 1).unwrap(); - nodes + vec![State { + first: root_first, + size: root_size, + }] }; - let mut step = |state: State| { + let mut step = |states: Vec| { let mut counter_pages = 0_usize; let mut counter_tuples = 0_usize; - let mut nodes = Vec::new(); - for list in state { - let mut current = list; + let mut nodes = Vec::with_capacity(states.iter().map(|x| x.size).sum::() as _); + for state in states { + let mut current = state.first; while current != u32::MAX { counter_pages += 1; check(); @@ -52,8 +55,11 @@ pub fn prewarm(index: impl RelationRead, height: i32, check: impl F for mean in h1_tuple.mean().iter().copied() { vectors::access_1::(index.clone(), mean, ()); } - for first in h1_tuple.first().iter().copied() { - nodes.push(first); + for j in 0..h1_tuple.len() { + nodes.push(State { + first: h1_tuple.first()[j as usize], + size: h1_tuple.size()[j as usize], + }); } } H1TupleReader::_1(_) => (), @@ -69,14 +75,14 @@ pub fn prewarm(index: impl RelationRead, height: i32, check: impl F nodes }; for _ in (std::cmp::max(1, prewarm_max_height)..height_of_root).rev() { - state = step(state); + states = step(states); } if prewarm_max_height == 0 { let mut counter_pages = 0_usize; let mut counter_tuples = 0_usize; let mut counter_nodes = 0_usize; - for list in state { - let jump_guard = index.read(list); + for state in states { + let jump_guard = index.read(state.first); let jump_tuple = jump_guard .get(1) .expect("data corruption") diff --git a/crates/algorithm/src/search.rs b/crates/algorithm/src/search.rs index bff53e3..65358cb 100644 --- a/crates/algorithm/src/search.rs +++ b/crates/algorithm/src/search.rs @@ -25,6 +25,7 @@ pub fn search( assert_eq!(height_of_root as usize, 1 + probes.len(), "invalid probes"); let root_mean = meta_tuple.root_mean(); let root_first = meta_tuple.root_first(); + let root_size = meta_tuple.root_size(); drop(meta_guard); let default_lut = if !is_residual { @@ -33,8 +34,12 @@ pub fn search( None }; - type State = Vec<(u32, Option<::Vector>)>; - let mut state: State = vec![{ + struct State { + residual: Option, + first: u32, + size: u32, + } + let mut states: Vec> = vec![{ let mean = root_mean; if is_residual { let residual_u = vectors::access_1::( @@ -45,38 +50,52 @@ pub fn search( O::ResidualAccessor::default(), ), ); - (root_first, Some(residual_u)) + State { + residual: Some(residual_u), + first: root_first, + size: root_size, + } } else { - (root_first, None) + State { + residual: None, + first: root_first, + size: root_size, + } } }]; - let step = |state: State, probes| { - let mut results = Vec::new(); - for (first, residual) in state { - let lut = if let Some(residual) = residual { + let step = |states: Vec>, probes| { + let mut results = Vec::with_capacity(states.iter().map(|x| x.size).sum::() as _); + for state in states { + let lut = if let Some(residual) = state.residual { &O::Vector::compute_lut_block(residual.as_borrowed()) } else { default_lut.as_ref().map(|x| &x.0).unwrap() }; access_1( index.clone(), - first, + state.first, || { RAccess::new( (&lut.4, (lut.0, lut.1, lut.2, lut.3, epsilon)), O::Distance::block_accessor(), ) }, - |lowerbound, mean, first| { - results.push((Reverse(lowerbound), AlwaysEqual(mean), AlwaysEqual(first))); + |lowerbound, mean, first, size| { + results.push(( + Reverse(lowerbound), + AlwaysEqual(mean), + AlwaysEqual(first), + AlwaysEqual(size), + )); }, ); } let mut heap = BinaryHeap::from(results); - let mut cache = BinaryHeap::<(Reverse, _, _)>::new(); + let mut cache = BinaryHeap::<(Reverse, _)>::new(); std::iter::from_fn(|| { while !heap.is_empty() && heap.peek().map(|x| x.0) > cache.peek().map(|x| x.0) { - let (_, AlwaysEqual(mean), AlwaysEqual(first)) = heap.pop().unwrap(); + let (_, AlwaysEqual(mean), AlwaysEqual(first), AlwaysEqual(size)) = + heap.pop().unwrap(); if is_residual { let (dis_u, residual_u) = vectors::access_1::( index.clone(), @@ -91,8 +110,11 @@ pub fn search( ); cache.push(( Reverse(dis_u), - AlwaysEqual(first), - AlwaysEqual(Some(residual_u)), + AlwaysEqual(State { + residual: Some(residual_u), + first, + size, + }), )); } else { let dis_u = vectors::access_1::( @@ -103,27 +125,34 @@ pub fn search( O::DistanceAccessor::default(), ), ); - cache.push((Reverse(dis_u), AlwaysEqual(first), AlwaysEqual(None))); + cache.push(( + Reverse(dis_u), + AlwaysEqual(State { + residual: None, + first, + size, + }), + )); } } - let (_, AlwaysEqual(first), AlwaysEqual(mean)) = cache.pop()?; - Some((first, mean)) + let (_, AlwaysEqual(state)) = cache.pop()?; + Some(state) }) .take(probes as usize) .collect() }; for i in (1..height_of_root).rev() { - state = step(state, probes[i as usize - 1]); + states = step(states, probes[i as usize - 1]); } let mut results = Vec::new(); - for (first, residual) in state { - let lut = if let Some(residual) = residual.as_ref().map(|x| x.as_borrowed()) { + for state in states { + let lut = if let Some(residual) = state.residual.as_ref().map(|x| x.as_borrowed()) { &O::Vector::compute_lut(residual) } else { default_lut.as_ref().unwrap() }; - let jump_guard = index.read(first); + let jump_guard = index.read(state.first); let jump_tuple = jump_guard .get(1) .expect("data corruption") diff --git a/crates/algorithm/src/tape.rs b/crates/algorithm/src/tape.rs index edc1c4e..2015758 100644 --- a/crates/algorithm/src/tape.rs +++ b/crates/algorithm/src/tape.rs @@ -88,6 +88,7 @@ pub struct H1Branch { pub factor_err: f32, pub signs: Vec, pub first: u32, + pub size: u32, } pub struct H1TapeWriter { @@ -123,6 +124,7 @@ where factor_ip: chunk.each_ref().map(|x| x.factor_ip), factor_err: chunk.each_ref().map(|x| x.factor_err), first: chunk.each_ref().map(|x| x.first), + size: chunk.each_ref().map(|x| x.size), len: chunk.len() as _, elements: remain, }); @@ -155,6 +157,7 @@ where factor_ip: any_pack(chunk.iter().map(|x| x.factor_ip)), factor_err: any_pack(chunk.iter().map(|x| x.factor_err)), first: any_pack(chunk.iter().map(|x| x.first)), + size: any_pack(chunk.iter().map(|x| x.size)), len: chunk.len() as _, elements: remain, }); @@ -254,7 +257,7 @@ pub fn access_1( index: impl RelationRead, first: u32, make_block_accessor: impl Fn() -> A + Copy, - mut callback: impl FnMut(Distance, IndexPointer, u32), + mut callback: impl FnMut(Distance, IndexPointer, u32, u32), ) where A: for<'a> Accessor1< [u8; 16], @@ -282,6 +285,7 @@ pub fn access_1( lowerbounds[i as usize], h1_tuple.mean()[i as usize], h1_tuple.first()[i as usize], + h1_tuple.size()[i as usize], ); } } diff --git a/crates/algorithm/src/tuples.rs b/crates/algorithm/src/tuples.rs index 54c1707..554887f 100644 --- a/crates/algorithm/src/tuples.rs +++ b/crates/algorithm/src/tuples.rs @@ -7,7 +7,7 @@ use zerocopy_derive::{FromBytes, Immutable, IntoBytes, KnownLayout}; pub const ALIGN: usize = 8; pub type Tag = u64; const MAGIC: u64 = u64::from_ne_bytes(*b"vchordrq"); -const VERSION: u64 = 1; +const VERSION: u64 = 2; pub trait Tuple: 'static { type Reader<'a>: TupleReader<'a, Tuple = Self>; @@ -52,11 +52,11 @@ struct MetaTupleHeader { is_residual: Bool, _padding_0: [ZeroU8; 3], vectors_first: u32, - // raw vector root_mean: IndexPointer, - // for meta tuple, it's pointers to next level root_first: u32, + root_size: u32, freepage_first: u32, + _padding_1: [ZeroU8; 4], } pub struct MetaTuple { @@ -66,6 +66,7 @@ pub struct MetaTuple { pub vectors_first: u32, pub root_mean: IndexPointer, pub root_first: u32, + pub root_size: u32, pub freepage_first: u32, } @@ -83,7 +84,9 @@ impl Tuple for MetaTuple { vectors_first: self.vectors_first, root_mean: self.root_mean, root_first: self.root_first, + root_size: self.root_size, freepage_first: self.freepage_first, + _padding_1: Default::default(), } .as_bytes() .to_vec() @@ -134,6 +137,9 @@ impl MetaTupleReader<'_> { pub fn root_first(self) -> u32 { self.header.root_first } + pub fn root_size(self) -> u32 { + self.header.root_size + } pub fn freepage_first(self) -> u32 { self.header.freepage_first } @@ -416,6 +422,7 @@ struct H1TupleHeader0 { factor_ip: [f32; 32], factor_err: [f32; 32], first: [u32; 32], + size: [u32; 32], len: u32, _padding_0: [ZeroU8; 4], elements_s: usize, @@ -439,6 +446,7 @@ pub enum H1Tuple { factor_ip: [f32; 32], factor_err: [f32; 32], first: [u32; 32], + size: [u32; 32], len: u32, elements: Vec<[u8; 16]>, }, @@ -483,6 +491,7 @@ impl Tuple for H1Tuple { factor_ip, factor_err, first, + size, len, elements, } => { @@ -503,6 +512,7 @@ impl Tuple for H1Tuple { factor_err: *factor_err, first: *first, len: *len, + size: *size, _padding_0: Default::default(), elements_s, elements_e, @@ -591,6 +601,9 @@ impl<'a> H1TupleReader0<'a> { pub fn first(self) -> &'a [u32] { &self.header.first[..self.header.len as usize] } + pub fn size(self) -> &'a [u32] { + &self.header.size[..self.header.len as usize] + } pub fn elements(&self) -> &'a [[u8; 16]] { self.elements } diff --git a/crates/simd/cshim.c b/crates/simd/cshim.c index 1374bbf..10b5cee 100644 --- a/crates/simd/cshim.c +++ b/crates/simd/cshim.c @@ -2,11 +2,10 @@ #error "clang version must be >= 16" #endif -#include -#include - #ifdef __aarch64__ +#include +#include #include #include