Skip to content

Commit

Permalink
fix: reserve capacity
Browse files (browse the repository at this point in the history)
Signed-off-by: usamoi <[email protected]>
  • Loading branch information
usamoi committed Jan 26, 2025
1 parent a540079 commit 7c58c3b
Show file tree
Hide file tree
Showing 7 changed files with 167 additions and 82 deletions.
2 changes: 2 additions & 0 deletions src/algorithm/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ pub fn build<O: Operator, T: HeapRelation<O>, R: Reporter>(
factor_err: code.factor_err,
signs: code.signs,
first: pointer_of_firsts[i - 1][child as usize],
size: structures[i - 1].len() as _,
});
}
let tape = tape.into_inner();
Expand All @@ -144,6 +145,7 @@ pub fn build<O: Operator, T: HeapRelation<O>, R: Reporter>(
vectors_first: vectors.first(),
root_mean: pointer_of_means.last().unwrap()[0],
root_first: pointer_of_firsts.last().unwrap()[0],
root_size: structures.last().unwrap().len() as _,
freepage_first: freepage.first(),
});
}
Expand Down
65 changes: 46 additions & 19 deletions src/algorithm/insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pub fn insert<O: Operator>(
assert_eq!(dims, vector.as_borrowed().dims(), "unmatched dimensions");
let root_mean = meta_tuple.root_mean();
let root_first = meta_tuple.root_first();
let root_size = meta_tuple.root_size();
let vectors_first = meta_tuple.vectors_first();
drop(meta_guard);

Expand All @@ -42,7 +43,11 @@ pub fn insert<O: Operator>(
payload,
);

type State<O> = (u32, Option<<O as Operator>::Vector>);
// Descent state carried from one `step` call to the next while inserting.
struct State<O: Operator> {
// Handed to `read_h1_tape` inside `step`, and to `relation.read` once the
// descent loop has finished.
first: u32,
// `Some` only on the `is_residual` path. When present it is the input to
// `compute_lut_block` and `code`; otherwise the default LUT block and the
// inserted vector are used instead.
residual: Option<O::Vector>,
// Capacity hint for the `results` vector allocated at the top of `step`.
// NOTE(review): presumably the number of entries reachable from `first` —
// confirm against the values written in build.rs.
size: u32,
}
let mut state: State<O> = {
let mean = root_mean;
if is_residual {
Expand All @@ -54,39 +59,52 @@ pub fn insert<O: Operator>(
O::ResidualAccessor::default(),
),
);
(root_first, Some(residual_u))
State {
residual: Some(residual_u),
first: root_first,
size: root_size,
}
} else {
(root_first, None)
State {
residual: None,
first: root_first,
size: root_size,
}
}
};
let step = |state: State<O>| {
let mut results = Vec::new();
let mut results = Vec::with_capacity(state.size as _);
{
let (first, residual) = state;
let lut = if let Some(residual) = residual {
let lut = if let Some(residual) = state.residual {
&O::Vector::compute_lut_block(residual.as_borrowed())
} else {
default_lut_block.as_ref().unwrap()
};
read_h1_tape(
relation.clone(),
first,
state.first,
|| {
RAccess::new(
(&lut.4, (lut.0, lut.1, lut.2, lut.3, 1.9f32)),
O::Distance::block_accessor(),
)
},
|lowerbound, mean, first| {
results.push((Reverse(lowerbound), AlwaysEqual(mean), AlwaysEqual(first)));
|lowerbound, mean, first, size| {
results.push((
Reverse(lowerbound),
AlwaysEqual(mean),
AlwaysEqual(first),
AlwaysEqual(size),
));
},
);
}
let mut heap = BinaryHeap::from(results);
let mut cache = BinaryHeap::<(Reverse<Distance>, _, _)>::new();
let mut cache = BinaryHeap::<(Reverse<Distance>, _)>::new();
{
while !heap.is_empty() && heap.peek().map(|x| x.0) > cache.peek().map(|x| x.0) {
let (_, AlwaysEqual(mean), AlwaysEqual(first)) = heap.pop().unwrap();
let (_, AlwaysEqual(mean), AlwaysEqual(first), AlwaysEqual(size)) =
heap.pop().unwrap();
if is_residual {
let (dis_u, residual_u) = vectors::vector_access_1::<O, _>(
relation.clone(),
Expand All @@ -101,8 +119,11 @@ pub fn insert<O: Operator>(
);
cache.push((
Reverse(dis_u),
AlwaysEqual(first),
AlwaysEqual(Some(residual_u)),
AlwaysEqual(State {
residual: Some(residual_u),
first,
size,
}),
));
} else {
let dis_u = vectors::vector_access_1::<O, _>(
Expand All @@ -113,19 +134,25 @@ pub fn insert<O: Operator>(
O::DistanceAccessor::default(),
),
);
cache.push((Reverse(dis_u), AlwaysEqual(first), AlwaysEqual(None)));
cache.push((
Reverse(dis_u),
AlwaysEqual(State {
residual: None,
first,
size,
}),
));
}
}
let (_, AlwaysEqual(first), AlwaysEqual(mean)) = cache.pop().unwrap();
(first, mean)
let (_, AlwaysEqual(state)) = cache.pop().unwrap();
state
}
};
for _ in (1..height_of_root).rev() {
state = step(state);
}

let (first, residual) = state;
let code = if let Some(residual) = residual {
let code = if let Some(residual) = state.residual {
O::Vector::code(residual.as_borrowed())
} else {
O::Vector::code(vector.as_borrowed())
Expand All @@ -140,7 +167,7 @@ pub fn insert<O: Operator>(
elements: rabitq::pack_to_u64(&code.signs),
});

let jump_guard = relation.read(first);
let jump_guard = relation.read(state.first);
let jump_tuple = jump_guard
.get(1)
.expect("data corruption")
Expand Down
42 changes: 24 additions & 18 deletions src/algorithm/prewarm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub fn prewarm<O: Operator>(relation: impl RelationRead + Clone, height: i32) ->
let height_of_root = meta_tuple.height_of_root();
let root_mean = meta_tuple.root_mean();
let root_first = meta_tuple.root_first();
let root_size = meta_tuple.root_size();
drop(meta_guard);

let mut message = String::new();
Expand All @@ -19,25 +20,27 @@ pub fn prewarm<O: Operator>(relation: impl RelationRead + Clone, height: i32) ->
if prewarm_max_height > height_of_root {
return message;
}
type State = Vec<u32>;
let mut state: State = {
let mut nodes = Vec::new();
{
vectors::vector_access_1::<O, _>(relation.clone(), root_mean, ());
nodes.push(root_first);
}
// One entry of the work list that `step` consumes while prewarming.
struct State {
// Initial value of `current` for the loop in `step`, which runs until
// `current` equals `u32::MAX`; also passed to `relation.read` in the
// `prewarm_max_height == 0` branch.
first: u32,
// Summed across all queued states to pre-size the `nodes` vector that
// `step` returns.
// NOTE(review): presumably the child count under `first` — confirm.
size: u32,
}
let mut states: Vec<State> = {
vectors::vector_access_1::<O, _>(relation.clone(), root_mean, ());
writeln!(message, "------------------------").unwrap();
writeln!(message, "number of nodes: {}", nodes.len()).unwrap();
writeln!(message, "number of nodes: {}", 1).unwrap();
writeln!(message, "number of tuples: {}", 1).unwrap();
writeln!(message, "number of pages: {}", 1).unwrap();
nodes
vec![State {
first: root_first,
size: root_size,
}]
};
let mut step = |state: State| {
let mut step = |states: Vec<State>| {
let mut counter_pages = 0_usize;
let mut counter_tuples = 0_usize;
let mut nodes = Vec::new();
for list in state {
let mut current = list;
let mut nodes = Vec::with_capacity(states.iter().map(|x| x.size).sum::<u32>() as _);
for state in states {
let mut current = state.first;
while current != u32::MAX {
counter_pages += 1;
pgrx::check_for_interrupts!();
Expand All @@ -53,8 +56,11 @@ pub fn prewarm<O: Operator>(relation: impl RelationRead + Clone, height: i32) ->
for mean in h1_tuple.mean().iter().copied() {
vectors::vector_access_1::<O, _>(relation.clone(), mean, ());
}
for first in h1_tuple.first().iter().copied() {
nodes.push(first);
for j in 0..h1_tuple.len() {
nodes.push(State {
first: h1_tuple.first()[j as usize],
size: h1_tuple.size()[j as usize],
});
}
}
H1TupleReader::_1(_) => (),
Expand All @@ -70,14 +76,14 @@ pub fn prewarm<O: Operator>(relation: impl RelationRead + Clone, height: i32) ->
nodes
};
for _ in (std::cmp::max(1, prewarm_max_height)..height_of_root).rev() {
state = step(state);
states = step(states);
}
if prewarm_max_height == 0 {
let mut counter_pages = 0_usize;
let mut counter_tuples = 0_usize;
let mut counter_nodes = 0_usize;
for list in state {
let jump_guard = relation.read(list);
for state in states {
let jump_guard = relation.read(state.first);
let jump_tuple = jump_guard
.get(1)
.expect("data corruption")
Expand Down
73 changes: 51 additions & 22 deletions src/algorithm/scan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ pub fn scan<O: Operator>(
assert_eq!(height_of_root as usize, 1 + probes.len(), "invalid probes");
let root_mean = meta_tuple.root_mean();
let root_first = meta_tuple.root_first();
let root_size = meta_tuple.root_size();
drop(meta_guard);

let default_lut = if !is_residual {
Expand All @@ -37,8 +38,12 @@ pub fn scan<O: Operator>(
None
};

type State<O> = Vec<(u32, Option<<O as Operator>::Vector>)>;
let mut state: State<O> = vec![{
// One candidate carried between levels of the scan; `step` takes a Vec of
// these and returns the next Vec.
struct State<O: Operator> {
// `Some` only on the `is_residual` path. When present it feeds
// `compute_lut_block` in `step` and `compute_lut` in the final loop;
// otherwise `default_lut` is used.
residual: Option<O::Vector>,
// Handed to `read_h1_tape` inside `step`, and to `relation.read` in the
// final loop over the remaining states.
first: u32,
// Summed across all incoming states to pre-size the `results` vector in
// `step`.
// NOTE(review): presumably the number of entries reachable from `first` —
// confirm against the values written in build.rs.
size: u32,
}
let mut states: Vec<State<O>> = vec![{
let mean = root_mean;
if is_residual {
let residual_u = vectors::vector_access_1::<O, _>(
Expand All @@ -49,38 +54,52 @@ pub fn scan<O: Operator>(
O::ResidualAccessor::default(),
),
);
(root_first, Some(residual_u))
State {
residual: Some(residual_u),
first: root_first,
size: root_size,
}
} else {
(root_first, None)
State {
residual: None,
first: root_first,
size: root_size,
}
}
}];
let step = |state: State<O>, probes| {
let mut results = Vec::new();
for (first, residual) in state {
let lut = if let Some(residual) = residual {
let step = |states: Vec<State<O>>, probes| {
let mut results = Vec::with_capacity(states.iter().map(|x| x.size).sum::<u32>() as _);
for state in states {
let lut = if let Some(residual) = state.residual {
&O::Vector::compute_lut_block(residual.as_borrowed())
} else {
default_lut.as_ref().map(|x| &x.0).unwrap()
};
read_h1_tape(
relation.clone(),
first,
state.first,
|| {
RAccess::new(
(&lut.4, (lut.0, lut.1, lut.2, lut.3, epsilon)),
O::Distance::block_accessor(),
)
},
|lowerbound, mean, first| {
results.push((Reverse(lowerbound), AlwaysEqual(mean), AlwaysEqual(first)));
|lowerbound, mean, first, size| {
results.push((
Reverse(lowerbound),
AlwaysEqual(mean),
AlwaysEqual(first),
AlwaysEqual(size),
));
},
);
}
let mut heap = BinaryHeap::from(results);
let mut cache = BinaryHeap::<(Reverse<Distance>, _, _)>::new();
let mut cache = BinaryHeap::<(Reverse<Distance>, _)>::new();
std::iter::from_fn(|| {
while !heap.is_empty() && heap.peek().map(|x| x.0) > cache.peek().map(|x| x.0) {
let (_, AlwaysEqual(mean), AlwaysEqual(first)) = heap.pop().unwrap();
let (_, AlwaysEqual(mean), AlwaysEqual(first), AlwaysEqual(size)) =
heap.pop().unwrap();
if is_residual {
let (dis_u, residual_u) = vectors::vector_access_1::<O, _>(
relation.clone(),
Expand All @@ -95,8 +114,11 @@ pub fn scan<O: Operator>(
);
cache.push((
Reverse(dis_u),
AlwaysEqual(first),
AlwaysEqual(Some(residual_u)),
AlwaysEqual(State {
residual: Some(residual_u),
first,
size,
}),
));
} else {
let dis_u = vectors::vector_access_1::<O, _>(
Expand All @@ -107,27 +129,34 @@ pub fn scan<O: Operator>(
O::DistanceAccessor::default(),
),
);
cache.push((Reverse(dis_u), AlwaysEqual(first), AlwaysEqual(None)));
cache.push((
Reverse(dis_u),
AlwaysEqual(State {
residual: None,
first,
size,
}),
));
}
}
let (_, AlwaysEqual(first), AlwaysEqual(mean)) = cache.pop()?;
Some((first, mean))
let (_, AlwaysEqual(state)) = cache.pop()?;
Some(state)
})
.take(probes as usize)
.collect()
};
for i in (1..height_of_root).rev() {
state = step(state, probes[i as usize - 1]);
states = step(states, probes[i as usize - 1]);
}

let mut results = Vec::new();
for (first, residual) in state {
let lut = if let Some(residual) = residual.as_ref().map(|x| x.as_borrowed()) {
for state in states {
let lut = if let Some(residual) = state.residual.as_ref().map(|x| x.as_borrowed()) {
&O::Vector::compute_lut(residual)
} else {
default_lut.as_ref().unwrap()
};
let jump_guard = relation.read(first);
let jump_guard = relation.read(state.first);
let jump_tuple = jump_guard
.get(1)
.expect("data corruption")
Expand Down
Loading

0 comments on commit 7c58c3b

Please sign in to comment.