Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: reserve capacity #176

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/algorithm/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ pub fn build<O: Operator, T: HeapRelation<O>, R: Reporter>(
factor_err: code.factor_err,
signs: code.signs,
first: pointer_of_firsts[i - 1][child as usize],
size: structures[i - 1].len() as _,
});
}
let tape = tape.into_inner();
Expand All @@ -144,6 +145,7 @@ pub fn build<O: Operator, T: HeapRelation<O>, R: Reporter>(
vectors_first: vectors.first(),
root_mean: pointer_of_means.last().unwrap()[0],
root_first: pointer_of_firsts.last().unwrap()[0],
root_size: structures.last().unwrap().len() as _,
freepage_first: freepage.first(),
});
}
Expand Down
65 changes: 46 additions & 19 deletions src/algorithm/insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pub fn insert<O: Operator>(
assert_eq!(dims, vector.as_borrowed().dims(), "unmatched dimensions");
let root_mean = meta_tuple.root_mean();
let root_first = meta_tuple.root_first();
let root_size = meta_tuple.root_size();
let vectors_first = meta_tuple.vectors_first();
drop(meta_guard);

Expand All @@ -42,7 +43,11 @@ pub fn insert<O: Operator>(
payload,
);

type State<O> = (u32, Option<<O as Operator>::Vector>);
struct State<O: Operator> {
first: u32,
residual: Option<O::Vector>,
size: u32,
}
let mut state: State<O> = {
let mean = root_mean;
if is_residual {
Expand All @@ -54,39 +59,52 @@ pub fn insert<O: Operator>(
O::ResidualAccessor::default(),
),
);
(root_first, Some(residual_u))
State {
residual: Some(residual_u),
first: root_first,
size: root_size,
}
} else {
(root_first, None)
State {
residual: None,
first: root_first,
size: root_size,
}
}
};
let step = |state: State<O>| {
let mut results = Vec::new();
let mut results = Vec::with_capacity(state.size as _);
{
let (first, residual) = state;
let lut = if let Some(residual) = residual {
let lut = if let Some(residual) = state.residual {
&O::Vector::compute_lut_block(residual.as_borrowed())
} else {
default_lut_block.as_ref().unwrap()
};
read_h1_tape(
relation.clone(),
first,
state.first,
|| {
RAccess::new(
(&lut.4, (lut.0, lut.1, lut.2, lut.3, 1.9f32)),
O::Distance::block_accessor(),
)
},
|lowerbound, mean, first| {
results.push((Reverse(lowerbound), AlwaysEqual(mean), AlwaysEqual(first)));
|lowerbound, mean, first, size| {
results.push((
Reverse(lowerbound),
AlwaysEqual(mean),
AlwaysEqual(first),
AlwaysEqual(size),
));
},
);
}
let mut heap = BinaryHeap::from(results);
let mut cache = BinaryHeap::<(Reverse<Distance>, _, _)>::new();
let mut cache = BinaryHeap::<(Reverse<Distance>, _)>::new();
{
while !heap.is_empty() && heap.peek().map(|x| x.0) > cache.peek().map(|x| x.0) {
let (_, AlwaysEqual(mean), AlwaysEqual(first)) = heap.pop().unwrap();
let (_, AlwaysEqual(mean), AlwaysEqual(first), AlwaysEqual(size)) =
heap.pop().unwrap();
if is_residual {
let (dis_u, residual_u) = vectors::vector_access_1::<O, _>(
relation.clone(),
Expand All @@ -101,8 +119,11 @@ pub fn insert<O: Operator>(
);
cache.push((
Reverse(dis_u),
AlwaysEqual(first),
AlwaysEqual(Some(residual_u)),
AlwaysEqual(State {
residual: Some(residual_u),
first,
size,
}),
));
} else {
let dis_u = vectors::vector_access_1::<O, _>(
Expand All @@ -113,19 +134,25 @@ pub fn insert<O: Operator>(
O::DistanceAccessor::default(),
),
);
cache.push((Reverse(dis_u), AlwaysEqual(first), AlwaysEqual(None)));
cache.push((
Reverse(dis_u),
AlwaysEqual(State {
residual: None,
first,
size,
}),
));
}
}
let (_, AlwaysEqual(first), AlwaysEqual(mean)) = cache.pop().unwrap();
(first, mean)
let (_, AlwaysEqual(state)) = cache.pop().unwrap();
state
}
};
for _ in (1..height_of_root).rev() {
state = step(state);
}

let (first, residual) = state;
let code = if let Some(residual) = residual {
let code = if let Some(residual) = state.residual {
O::Vector::code(residual.as_borrowed())
} else {
O::Vector::code(vector.as_borrowed())
Expand All @@ -140,7 +167,7 @@ pub fn insert<O: Operator>(
elements: rabitq::pack_to_u64(&code.signs),
});

let jump_guard = relation.read(first);
let jump_guard = relation.read(state.first);
let jump_tuple = jump_guard
.get(1)
.expect("data corruption")
Expand Down
42 changes: 24 additions & 18 deletions src/algorithm/prewarm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub fn prewarm<O: Operator>(relation: impl RelationRead + Clone, height: i32) ->
let height_of_root = meta_tuple.height_of_root();
let root_mean = meta_tuple.root_mean();
let root_first = meta_tuple.root_first();
let root_size = meta_tuple.root_size();
drop(meta_guard);

let mut message = String::new();
Expand All @@ -19,25 +20,27 @@ pub fn prewarm<O: Operator>(relation: impl RelationRead + Clone, height: i32) ->
if prewarm_max_height > height_of_root {
return message;
}
type State = Vec<u32>;
let mut state: State = {
let mut nodes = Vec::new();
{
vectors::vector_access_1::<O, _>(relation.clone(), root_mean, ());
nodes.push(root_first);
}
struct State {
first: u32,
size: u32,
}
let mut states: Vec<State> = {
vectors::vector_access_1::<O, _>(relation.clone(), root_mean, ());
writeln!(message, "------------------------").unwrap();
writeln!(message, "number of nodes: {}", nodes.len()).unwrap();
writeln!(message, "number of nodes: {}", 1).unwrap();
writeln!(message, "number of tuples: {}", 1).unwrap();
writeln!(message, "number of pages: {}", 1).unwrap();
nodes
vec![State {
first: root_first,
size: root_size,
}]
};
let mut step = |state: State| {
let mut step = |states: Vec<State>| {
let mut counter_pages = 0_usize;
let mut counter_tuples = 0_usize;
let mut nodes = Vec::new();
for list in state {
let mut current = list;
let mut nodes = Vec::with_capacity(states.iter().map(|x| x.size).sum::<u32>() as _);
for state in states {
let mut current = state.first;
while current != u32::MAX {
counter_pages += 1;
pgrx::check_for_interrupts!();
Expand All @@ -53,8 +56,11 @@ pub fn prewarm<O: Operator>(relation: impl RelationRead + Clone, height: i32) ->
for mean in h1_tuple.mean().iter().copied() {
vectors::vector_access_1::<O, _>(relation.clone(), mean, ());
}
for first in h1_tuple.first().iter().copied() {
nodes.push(first);
for j in 0..h1_tuple.len() {
nodes.push(State {
first: h1_tuple.first()[j as usize],
size: h1_tuple.size()[j as usize],
});
}
}
H1TupleReader::_1(_) => (),
Expand All @@ -70,14 +76,14 @@ pub fn prewarm<O: Operator>(relation: impl RelationRead + Clone, height: i32) ->
nodes
};
for _ in (std::cmp::max(1, prewarm_max_height)..height_of_root).rev() {
state = step(state);
states = step(states);
}
if prewarm_max_height == 0 {
let mut counter_pages = 0_usize;
let mut counter_tuples = 0_usize;
let mut counter_nodes = 0_usize;
for list in state {
let jump_guard = relation.read(list);
for state in states {
let jump_guard = relation.read(state.first);
let jump_tuple = jump_guard
.get(1)
.expect("data corruption")
Expand Down
73 changes: 51 additions & 22 deletions src/algorithm/scan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ pub fn scan<O: Operator>(
assert_eq!(height_of_root as usize, 1 + probes.len(), "invalid probes");
let root_mean = meta_tuple.root_mean();
let root_first = meta_tuple.root_first();
let root_size = meta_tuple.root_size();
drop(meta_guard);

let default_lut = if !is_residual {
Expand All @@ -37,8 +38,12 @@ pub fn scan<O: Operator>(
None
};

type State<O> = Vec<(u32, Option<<O as Operator>::Vector>)>;
let mut state: State<O> = vec![{
struct State<O: Operator> {
residual: Option<O::Vector>,
first: u32,
size: u32,
}
let mut states: Vec<State<O>> = vec![{
let mean = root_mean;
if is_residual {
let residual_u = vectors::vector_access_1::<O, _>(
Expand All @@ -49,38 +54,52 @@ pub fn scan<O: Operator>(
O::ResidualAccessor::default(),
),
);
(root_first, Some(residual_u))
State {
residual: Some(residual_u),
first: root_first,
size: root_size,
}
} else {
(root_first, None)
State {
residual: None,
first: root_first,
size: root_size,
}
}
}];
let step = |state: State<O>, probes| {
let mut results = Vec::new();
for (first, residual) in state {
let lut = if let Some(residual) = residual {
let step = |states: Vec<State<O>>, probes| {
let mut results = Vec::with_capacity(states.iter().map(|x| x.size).sum::<u32>() as _);
for state in states {
let lut = if let Some(residual) = state.residual {
&O::Vector::compute_lut_block(residual.as_borrowed())
} else {
default_lut.as_ref().map(|x| &x.0).unwrap()
};
read_h1_tape(
relation.clone(),
first,
state.first,
|| {
RAccess::new(
(&lut.4, (lut.0, lut.1, lut.2, lut.3, epsilon)),
O::Distance::block_accessor(),
)
},
|lowerbound, mean, first| {
results.push((Reverse(lowerbound), AlwaysEqual(mean), AlwaysEqual(first)));
|lowerbound, mean, first, size| {
results.push((
Reverse(lowerbound),
AlwaysEqual(mean),
AlwaysEqual(first),
AlwaysEqual(size),
));
},
);
}
let mut heap = BinaryHeap::from(results);
let mut cache = BinaryHeap::<(Reverse<Distance>, _, _)>::new();
let mut cache = BinaryHeap::<(Reverse<Distance>, _)>::new();
std::iter::from_fn(|| {
while !heap.is_empty() && heap.peek().map(|x| x.0) > cache.peek().map(|x| x.0) {
let (_, AlwaysEqual(mean), AlwaysEqual(first)) = heap.pop().unwrap();
let (_, AlwaysEqual(mean), AlwaysEqual(first), AlwaysEqual(size)) =
heap.pop().unwrap();
if is_residual {
let (dis_u, residual_u) = vectors::vector_access_1::<O, _>(
relation.clone(),
Expand All @@ -95,8 +114,11 @@ pub fn scan<O: Operator>(
);
cache.push((
Reverse(dis_u),
AlwaysEqual(first),
AlwaysEqual(Some(residual_u)),
AlwaysEqual(State {
residual: Some(residual_u),
first,
size,
}),
));
} else {
let dis_u = vectors::vector_access_1::<O, _>(
Expand All @@ -107,27 +129,34 @@ pub fn scan<O: Operator>(
O::DistanceAccessor::default(),
),
);
cache.push((Reverse(dis_u), AlwaysEqual(first), AlwaysEqual(None)));
cache.push((
Reverse(dis_u),
AlwaysEqual(State {
residual: None,
first,
size,
}),
));
}
}
let (_, AlwaysEqual(first), AlwaysEqual(mean)) = cache.pop()?;
Some((first, mean))
let (_, AlwaysEqual(state)) = cache.pop()?;
Some(state)
})
.take(probes as usize)
.collect()
};
for i in (1..height_of_root).rev() {
state = step(state, probes[i as usize - 1]);
states = step(states, probes[i as usize - 1]);
}

let mut results = Vec::new();
for (first, residual) in state {
let lut = if let Some(residual) = residual.as_ref().map(|x| x.as_borrowed()) {
for state in states {
let lut = if let Some(residual) = state.residual.as_ref().map(|x| x.as_borrowed()) {
&O::Vector::compute_lut(residual)
} else {
default_lut.as_ref().unwrap()
};
let jump_guard = relation.read(first);
let jump_guard = relation.read(state.first);
let jump_tuple = jump_guard
.get(1)
.expect("data corruption")
Expand Down
Loading
Loading