fix: filter out null values when sampling for index training (#3404)
We were not filtering out null values when sampling. Because we often
call `array.values()` on Arrow arrays, which ignores the null bitmap, we
were often silently treating the nulls as zeros (or possibly undefined
values). The only thing that caught these nulls was an assertion. However,
the residualization that occurs with L2 and Cosine often meant that these
values were transformed and the null information was lost before the
assertion was reached, which is why the bug got past previous unit tests.
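
To illustrate the failure mode, here is a minimal sketch (not part of this
commit; it assumes the arrow-rs `FixedSizeListArray` API) showing that
`values()` bypasses the parent array's null bitmap:

use std::sync::Arc;

use arrow_array::{Array, ArrayRef, FixedSizeListArray, Float32Array};
use arrow_buffer::NullBuffer;
use arrow_schema::{DataType, Field};

fn main() {
    // Two 2-dim vectors backed by four zeroed floats.
    let child: ArrayRef = Arc::new(Float32Array::from(vec![0.0_f32; 4]));
    let field = Arc::new(Field::new("item", DataType::Float32, true));
    // Mark the second vector as null (true = valid in a NullBuffer).
    let nulls = NullBuffer::from(vec![true, false]);
    let vectors = FixedSizeListArray::new(field, 2, child, Some(nulls));

    assert_eq!(vectors.logical_null_count(), 1);
    // `values()` still exposes all four floats: the null vector is
    // indistinguishable from a legitimate all-zero vector.
    assert_eq!(vectors.values().len(), 4);
}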

This PR adds more assertions validating that there are no nulls, and makes
sure the sampling code handles null vectors.

Closes #3402
Closes #3400
wjones127 authored Jan 28, 2025
1 parent 6d77d14 commit bfacd7c
Showing 5 changed files with 282 additions and 7 deletions.
3 changes: 3 additions & 0 deletions rust/lance/src/index/append.rs
@@ -123,6 +123,9 @@ pub async fn merge_indices<'a>(
.with_fragments(unindexed)
.with_row_id()
.project(&[&column.name])?;
if column.nullable {
scanner.filter_expr(datafusion_expr::col(&column.name).is_not_null());
}
Some(scanner.try_into_stream().await?)
};

20 changes: 15 additions & 5 deletions rust/lance/src/index/vector/builder.rs
@@ -351,13 +351,23 @@ impl<S: IvfSubIndex + 'static, Q: Quantization + 'static> IvfIndexBuilder<S, Q>
"dataset not set before shuffling",
location!(),
))?;
let stream = dataset
.scan()
let is_nullable = dataset
.schema()
.field(&self.column)
.ok_or(Error::invalid_input(
format!("column {} not found in dataset", self.column).as_str(),
location!(),
))?
.nullable;
let mut builder = dataset.scan();
builder
.batch_readahead(get_num_compute_intensive_cpus())
.project(&[self.column.as_str()])?
.with_row_id()
.try_into_stream()
.await?;
.with_row_id();
if is_nullable {
builder.filter_expr(datafusion_expr::col(&self.column).is_not_null());
}
let stream = builder.try_into_stream().await?;
self.shuffle_data(Some(stream)).await?;
Ok(())
}
87 changes: 86 additions & 1 deletion rust/lance/src/index/vector/ivf.rs
@@ -1741,21 +1741,28 @@ mod tests {
use std::ops::Range;

use arrow_array::types::UInt64Type;
use arrow_array::{Float32Array, RecordBatchIterator, RecordBatchReader, UInt64Array};
use arrow_array::{
make_array, Float32Array, RecordBatchIterator, RecordBatchReader, UInt64Array,
};
use arrow_buffer::{BooleanBuffer, NullBuffer};
use arrow_schema::Field;
use itertools::Itertools;
use lance_core::utils::address::RowAddress;
use lance_core::ROW_ID;
use lance_datagen::{array, gen, Dimension, RowCount};
use lance_index::vector::sq::builder::SQBuildParams;
use lance_linalg::distance::l2_distance_batch;
use lance_testing::datagen::{
generate_random_array, generate_random_array_with_range, generate_random_array_with_seed,
generate_scaled_random_array, sample_without_replacement,
};
use rand::{seq::SliceRandom, thread_rng};
use rstest::rstest;
use tempfile::tempdir;

use crate::dataset::InsertBuilder;
use crate::index::prefilter::DatasetPreFilter;
use crate::index::vector::IndexFileVersion;
use crate::index::vector_index_details;
use crate::index::{vector::VectorIndexParams, DatasetIndexExt, DatasetIndexInternalExt};

@@ -2215,6 +2222,84 @@ mod tests {
.await;
}

// We test L2 and Dot, because L2 PQ uses residuals while Dot doesn't,
// so they have slightly different code paths.
#[tokio::test]
#[rstest]
#[case::ivf_pq_l2(VectorIndexParams::with_ivf_pq_params(
MetricType::L2,
IvfBuildParams::new(2),
PQBuildParams::new(2, 8),
))]
#[case::ivf_pq_dot(VectorIndexParams::with_ivf_pq_params(
MetricType::Dot,
IvfBuildParams::new(2),
PQBuildParams::new(2, 8),
))]
#[case::ivf_flat(VectorIndexParams::ivf_flat(1, MetricType::Dot))]
#[case::ivf_hnsw_pq(VectorIndexParams::with_ivf_hnsw_pq_params(
MetricType::Dot,
IvfBuildParams::new(2),
HnswBuildParams::default().num_edges(100),
PQBuildParams::new(2, 8)
))]
#[case::ivf_hnsw_sq(VectorIndexParams::with_ivf_hnsw_sq_params(
MetricType::Dot,
IvfBuildParams::new(2),
HnswBuildParams::default().num_edges(100),
SQBuildParams::default()
))]
async fn test_create_index_nulls(
#[case] mut index_params: VectorIndexParams,
#[values(IndexFileVersion::Legacy, IndexFileVersion::V3)] index_version: IndexFileVersion,
) {
index_params.version(index_version);

let nrows = 2_000;
let data = gen()
.col("vec", array::rand_vec::<Float32Type>(Dimension::from(16)))
.into_batch_rows(RowCount::from(nrows))
.unwrap();

// Make every other row null
let null_buffer = (0..nrows).map(|i| i % 2 == 0).collect::<BooleanBuffer>();
let null_buffer = NullBuffer::new(null_buffer);
let vectors = data["vec"]
.clone()
.to_data()
.into_builder()
.nulls(Some(null_buffer))
.build()
.unwrap();
let vectors = make_array(vectors);
let num_non_null = vectors.len() - vectors.logical_null_count();
let data = RecordBatch::try_new(data.schema(), vec![vectors]).unwrap();

let mut dataset = InsertBuilder::new("memory://")
.execute(vec![data])
.await
.unwrap();

// Create index
dataset
.create_index(&["vec"], IndexType::Vector, None, &index_params, false)
.await
.unwrap();

let query = vec![0.0; 16].into_iter().collect::<Float32Array>();
let results = dataset
.scan()
.nearest("vec", &query, 2_000)
.unwrap()
.ef(100_000)
.nprobs(2)
.try_into_batch()
.await
.unwrap();
assert_eq!(results.num_rows(), num_non_null);
assert_eq!(results["vec"].logical_null_count(), 0);
}

#[tokio::test]
async fn test_create_ivf_pq_cosine() {
let test_dir = tempdir().unwrap();
1 change: 1 addition & 0 deletions rust/lance/src/index/vector/pq.rs
@@ -447,6 +447,7 @@ pub async fn build_pq_model(
"Finished loading training data in {:02} seconds",
start.elapsed().as_secs_f32()
);
assert_eq!(training_data.logical_null_count(), 0);

info!(
"starting to compute partitions for PQ training, sample size: {}",
178 changes: 177 additions & 1 deletion rust/lance/src/index/vector/utils.rs
@@ -4,8 +4,13 @@
use std::sync::Arc;

use arrow_array::{cast::AsArray, FixedSizeListArray};
use futures::StreamExt;
use lance_arrow::{interleave_batches, DataTypeExt};
use lance_core::datatypes::Schema;
use log::info;
use rand::rngs::SmallRng;
use rand::seq::{IteratorRandom, SliceRandom};
use rand::SeedableRng;
use snafu::{location, Location};
use tokio::sync::Mutex;

@@ -107,17 +112,93 @@ pub async fn maybe_sample_training_data(
sample_size_hint: usize,
) -> Result<FixedSizeListArray> {
let num_rows = dataset.count_rows(None).await?;
let batch = if num_rows > sample_size_hint {

let vector_field = dataset.schema().field(column).ok_or(Error::Index {
message: format!(
"Sample training data: column {} does not exist in schema",
column
),
location: location!(),
})?;
let is_nullable = vector_field.nullable;

let batch = if num_rows > sample_size_hint && !is_nullable {
let projection = dataset.schema().project(&[column])?;
let batch = dataset.sample(sample_size_hint, &projection).await?;
info!(
"Sample training data: retrieved {} rows by sampling",
batch.num_rows()
);
batch
} else if num_rows > sample_size_hint && is_nullable {
// Use min block size + vector size to determine sample granularity
// For example, on object storage, block size is 64 KB. A 768-dim 32-bit
// vector is 3 KB. So we can sample every 64 KB / 3 KB = 21 vectors.
let block_size = dataset.object_store().block_size();
// We provide a fallback in case of multi-vector, which will have
// a variable size. We use 4 KB as a fallback.
let byte_width = vector_field
.data_type()
.byte_width_opt()
.unwrap_or(4 * 1024);

let ranges = random_ranges(num_rows, sample_size_hint, block_size, byte_width);

let mut collected = Vec::with_capacity(ranges.size_hint().0);
let mut indices = Vec::with_capacity(sample_size_hint);
let mut num_non_null = 0;

let mut scan = dataset.take_scan(
Box::pin(futures::stream::iter(ranges).map(Ok)),
Arc::new(dataset.schema().project(&[column])?),
dataset.object_store().io_parallelism(),
);

while let Some(batch) = scan.next().await {
let batch = batch?;

let array = batch.column_by_name(column).ok_or(Error::Index {
message: format!(
"Sample training data: column {} does not exist in return",
column
),
location: location!(),
})?;
let null_count = array.logical_null_count();
if null_count < array.len() {
num_non_null += array.len() - null_count;

let batch_i = collected.len();
if let Some(null_buffer) = array.nulls() {
for i in null_buffer.valid_indices() {
indices.push((batch_i, i));
}
} else {
indices.extend((0..array.len()).map(|i| (batch_i, i)));
}

collected.push(batch);
}
if num_non_null >= sample_size_hint {
break;
}
}

let batch = interleave_batches(&collected, &indices).map_err(|err| Error::Index {
message: format!("Sample training data: {}", err),
location: location!(),
})?;
info!(
"Sample training data: retrieved {} rows by sampling after filtering out nulls",
batch.num_rows()
);
batch
} else {
let mut scanner = dataset.scan();
scanner.project(&[column])?;
if is_nullable {
scanner.filter_expr(datafusion_expr::col(column).is_not_null());
}
let batch = scanner.try_into_batch().await?;
info!(
"Sample training data: retrieved {} rows scanning full datasets",
@@ -172,3 +253,98 @@ impl PartitionLoadLock {
mtx.clone()
}
}

/// Generate random ranges to sample from a dataset.
///
/// This will return an iterator of ranges that cover the whole dataset. It
/// provides an unbound iterator so that the caller can decide when to stop.
/// This is useful when the caller wants to sample a fixed number of rows, but
/// has an additional filter that must be applied.
///
/// Parameters:
/// * `num_rows`: number of rows in the dataset
/// * `sample_size_hint`: the target number of rows to be sampled in the end.
/// This is a hint for the minimum number of rows that will be consumed, but
/// the caller may consume more than this.
/// * `block_size`: the byte size of ranges that should be used.
/// * `byte_width`: the byte width of the vectors that will be sampled.
fn random_ranges(
num_rows: usize,
sample_size_hint: usize,
block_size: usize,
byte_width: usize,
) -> impl Iterator<Item = std::ops::Range<u64>> + Send {
let rows_per_batch = block_size / byte_width;
let mut rng = SmallRng::from_entropy();
let num_bins = num_rows.div_ceil(rows_per_batch);

let bins_iter: Box<dyn Iterator<Item = usize> + Send> = if sample_size_hint * 5 >= num_rows {
// It's faster to just allocate and shuffle
let mut indices = (0..num_bins).collect::<Vec<_>>();
indices.shuffle(&mut rng);
Box::new(indices.into_iter())
} else {
// If the sample is a small proportion, then we can instead use a set
// to track which bins we have seen. We start by using the sample_size_hint
// to provide an efficient start, and from there we randomly choose bins
// one by one.
let num_bins = num_rows.div_ceil(rows_per_batch);
// Start with the minimum number we will need.
let min_sample_size = sample_size_hint / rows_per_batch;
let starting_bins = (0..num_bins).choose_multiple(&mut rng, min_sample_size);
let mut seen = starting_bins
.iter()
.cloned()
.collect::<std::collections::HashSet<_>>();

let additional = std::iter::from_fn(move || loop {
if seen.len() >= num_bins {
break None;
}
let next = (0..num_bins).choose(&mut rng).unwrap();
if seen.contains(&next) {
continue;
} else {
seen.insert(next);
return Some(next);
}
});

Box::new(starting_bins.into_iter().chain(additional))
};

bins_iter.map(move |i| {
let start = (i * rows_per_batch) as u64;
let end = ((i + 1) * rows_per_batch) as u64;
let end = std::cmp::min(end, num_rows as u64);
start..end
})
}

#[cfg(test)]
mod tests {
use super::*;

#[rstest::rstest]
#[test]
fn test_random_ranges(
#[values(99, 100, 102)] num_rows: usize,
#[values(10, 100)] sample_size: usize,
) {
// We can just assert that the output when sorted is the same as the input
let block_size = 100;
let byte_width = 10;

let bin_size = block_size / byte_width;
assert_eq!(bin_size, 10);

let mut ranges =
random_ranges(num_rows, sample_size, block_size, byte_width).collect::<Vec<_>>();
ranges.sort_by_key(|r| r.start);
let expected = (0..num_rows as u64).step_by(bin_size).map(|start| {
let end = std::cmp::min(start + bin_size as u64, num_rows as u64);
start..end
});
assert_eq!(ranges, expected.collect::<Vec<_>>());
}
}
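
For illustration, a sketch (not part of the diff) of how a caller can consume
the unbounded iterator returned by `random_ranges`, mirroring how
`maybe_sample_training_data` stops once enough rows are covered;
`take_sample_ranges` is a hypothetical helper:

// A hypothetical helper (not in this commit) showing how the unbounded
// iterator from `random_ranges` is meant to be consumed: keep taking
// ranges until enough rows are covered, then stop.
fn take_sample_ranges(
    num_rows: usize,
    sample_size_hint: usize,
    block_size: usize,
    byte_width: usize,
) -> Vec<std::ops::Range<u64>> {
    let mut covered = 0u64;
    random_ranges(num_rows, sample_size_hint, block_size, byte_width)
        .take_while(|r| {
            // Stop once earlier ranges already cover the sample size;
            // the real code instead counts only the non-null rows seen.
            let more_needed = covered < sample_size_hint as u64;
            covered += r.end - r.start;
            more_needed
        })
        .collect()
}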
