fix: filter out null values when sampling for index training #3404

Open · wants to merge 4 commits into base: main
88 changes: 87 additions & 1 deletion rust/lance/src/index/vector/ivf.rs

@@ -1741,21 +1741,28 @@ mod tests {
use std::ops::Range;

use arrow_array::types::UInt64Type;
use arrow_array::{Float32Array, RecordBatchIterator, RecordBatchReader, UInt64Array};
use arrow_array::{
make_array, Float32Array, RecordBatchIterator, RecordBatchReader, UInt64Array,
};
use arrow_buffer::{BooleanBuffer, NullBuffer};
use arrow_schema::Field;
use itertools::Itertools;
use lance_core::utils::address::RowAddress;
use lance_core::ROW_ID;
use lance_datagen::{array, gen, Dimension, RowCount};
use lance_index::vector::sq::builder::SQBuildParams;
use lance_linalg::distance::l2_distance_batch;
use lance_testing::datagen::{
generate_random_array, generate_random_array_with_range, generate_random_array_with_seed,
generate_scaled_random_array, sample_without_replacement,
};
use rand::{seq::SliceRandom, thread_rng};
use rstest::rstest;
use tempfile::tempdir;

use crate::dataset::InsertBuilder;
use crate::index::prefilter::DatasetPreFilter;
use crate::index::vector::IndexFileVersion;
use crate::index::vector_index_details;
use crate::index::{vector::VectorIndexParams, DatasetIndexExt, DatasetIndexInternalExt};

@@ -2215,6 +2222,85 @@ mod tests {
.await;
}

#[rstest]
#[tokio::test]
async fn test_create_index_nulls(
Contributor: I'm thinking, should we add some tests for verifying recall? Then we can know whether flat search handles nulls well. It might be good to modify this test (https://github.com/lancedb/lance/blob/main/rust/lance/src/index/vector/ivf/v2.rs) so that half the rows contain nulls.

Contributor Author: I was thinking, is there a way to count the rows that are present in the index? I assume if a vector is null then we don't write it to the index file, right?

Contributor Author: I have updated the test so it asserts we can use search to get all the non-null vectors back. But I am not getting the results I expect. I could use your advice on what the expected behavior of these indices should be when there are lots of null vectors.

Contributor: It seems there is no way to count that right now; it could be easy for the v3 index by counting the number of rows in the storage file.

Contributor Author: @BubbleCal Could you help me make sense of the output of this test? https://github.com/lancedb/lance/actions/runs/12918160780/job/36026117407?pr=3404 I was expecting search to only return non-null rows, but it seems like we are getting some null vectors in the results.

// We test L2 and Dot, because L2 PQ uses residuals while Dot doesn't,
// so they have slightly different code paths.
#[values(
VectorIndexParams::with_ivf_pq_params(
MetricType::L2,
IvfBuildParams::new(2),
PQBuildParams::new(2, 4),
),
VectorIndexParams::with_ivf_pq_params(
MetricType::Dot,
IvfBuildParams::new(2),
PQBuildParams::new(2, 4),
),
VectorIndexParams::ivf_flat(1, MetricType::Dot),
VectorIndexParams::with_ivf_hnsw_pq_params(
MetricType::Dot,
IvfBuildParams::new(2),
HnswBuildParams::default(),
PQBuildParams::new(2, 4)
),
VectorIndexParams::with_ivf_hnsw_sq_params(
MetricType::Dot,
IvfBuildParams::new(2),
HnswBuildParams::default(),
SQBuildParams::default()
)
)]
mut index_params: VectorIndexParams,
#[values(IndexFileVersion::Legacy, IndexFileVersion::V3)] index_version: IndexFileVersion,
) {
index_params.version(index_version);

let nrows = 2_000;
let data = gen()
.col("vec", array::rand_vec::<Float32Type>(Dimension::from(16)))
.into_batch_rows(RowCount::from(nrows))
.unwrap();

// Make every other row null
let null_buffer = (0..nrows).map(|i| i % 2 == 0).collect::<BooleanBuffer>();
let null_buffer = NullBuffer::new(null_buffer);
let vectors = data["vec"]
.clone()
.to_data()
.into_builder()
.nulls(Some(null_buffer))
.build()
.unwrap();
let vectors = make_array(vectors);
let num_non_null = vectors.len() - vectors.logical_null_count();
let data = RecordBatch::try_new(data.schema(), vec![vectors]).unwrap();

let mut dataset = InsertBuilder::new("memory://")
.execute(vec![data])
.await
.unwrap();

// Create index
dataset
.create_index(&["vec"], IndexType::Vector, None, &index_params, false)
.await
.unwrap();

let query = vec![0.0; 16].into_iter().collect::<Float32Array>();
let results = dataset
.scan()
.nearest("vec", &query, 2_000)
.unwrap()
.nprobs(2)
.try_into_batch()
.await
.unwrap();
assert_eq!(results.num_rows(), num_non_null);
assert_eq!(results["vec"].logical_null_count(), 0);
}
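The recall check the reviewer suggests can be sketched independently of Lance. This is a minimal, std-only illustration of the metric itself, not the project's test code; the `recall` helper and the example row ids are hypothetical:

```rust
use std::collections::HashSet;

/// Recall@k: the fraction of exact (flat-search) neighbors that the
/// approximate index also returned.
fn recall(ground_truth: &[u64], retrieved: &[u64]) -> f32 {
    let gt: HashSet<u64> = ground_truth.iter().copied().collect();
    let hits = retrieved.iter().filter(|id| gt.contains(id)).count();
    hits as f32 / ground_truth.len() as f32
}

fn main() {
    // Ground truth would come from an exact scan over non-null rows only;
    // `approx` would come from the IVF index under test.
    let exact = vec![1, 3, 5, 7];
    let approx = vec![1, 3, 5, 9];
    println!("{}", recall(&exact, &approx)); // 0.75
}
```

In the scenario discussed above, building the ground truth from a flat scan that filters nulls first would make a recall below 1.0 attributable to the index rather than to null rows leaking into the result set.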

#[tokio::test]
async fn test_create_ivf_pq_cosine() {
let test_dir = tempdir().unwrap();
1 change: 1 addition & 0 deletions rust/lance/src/index/vector/pq.rs
@@ -447,6 +447,7 @@ pub async fn build_pq_model(
"Finished loading training data in {:02} seconds",
start.elapsed().as_secs_f32()
);
debug_assert_eq!(training_data.logical_null_count(), 0);

info!(
"starting to compute partitions for PQ training, sample size: {}",
59 changes: 57 additions & 2 deletions rust/lance/src/index/vector/utils.rs
@@ -9,7 +9,7 @@ use log::info;
use snafu::{location, Location};
use tokio::sync::Mutex;

use crate::dataset::Dataset;
use crate::dataset::{Dataset, ProjectionRequest, TakeBuilder};
use crate::{Error, Result};

/// Get the vector dimension of the given column in the schema.
@@ -107,17 +107,72 @@ pub async fn maybe_sample_training_data(
sample_size_hint: usize,
) -> Result<FixedSizeListArray> {
let num_rows = dataset.count_rows(None).await?;
let batch = if num_rows > sample_size_hint {

let is_nullable = dataset
.schema()
.field(column)
.ok_or(Error::Index {
message: format!(
"Sample training data: column {} does not exist in schema",
column
),
location: location!(),
})?
.nullable;

let batch = if num_rows > sample_size_hint && !is_nullable {
let projection = dataset.schema().project(&[column])?;
let batch = dataset.sample(sample_size_hint, &projection).await?;
info!(
"Sample training data: retrieved {} rows by sampling",
batch.num_rows()
);
batch
} else if num_rows > sample_size_hint && is_nullable {
// Need to filter out null values
// Use a scan to collect row ids. Then sample from the row ids. Then do take.
let row_addrs = dataset
.scan()
.filter_expr(datafusion_expr::col(column).is_not_null())
.with_row_address()
.project::<&str>(&[])?
.try_into_batch()
.await?;
debug_assert_eq!(row_addrs.num_columns(), 1);
debug_assert_eq!(row_addrs["_rowaddr"].logical_null_count(), 0);
let row_addrs = row_addrs
.column(0)
.as_any()
.downcast_ref::<arrow::array::UInt64Array>()
.ok_or(Error::Index {
message: format!(
"Sample training data: column {} is not a UInt64Array",
column
),
location: location!(),
})?;

let batch = TakeBuilder::try_new_from_addresses(
Arc::new(dataset.clone()),
row_addrs.values().to_vec(),
Arc::new(
ProjectionRequest::from_columns([column], dataset.schema())
.into_projection_plan(dataset.schema())?,
),
)?
.execute()
.await?;
info!(
"Sample training data: retrieved {} rows by sampling after filtering out nulls",
batch.num_rows()
);
batch
} else {
let mut scanner = dataset.scan();
scanner.project(&[column])?;
if is_nullable {
scanner.filter_expr(datafusion_expr::col(column).is_not_null());
}
let batch = scanner.try_into_batch().await?;
info!(
"Sample training data: retrieved {} rows scanning full datasets",
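The scan-then-take strategy in the patched `maybe_sample_training_data` (filter for non-null rows, collect their addresses, sample, then take) can be illustrated with a std-only sketch. Here `Option<Vec<f32>>` stands in for an Arrow array with a null buffer, and `sample_non_null` is a hypothetical helper, not the Lance API:

```rust
// Rows are Option<Vec<f32>>; None models a null vector.
fn sample_non_null(rows: &[Option<Vec<f32>>], sample_size: usize) -> Vec<Vec<f32>> {
    // 1. Scan: collect the "row ids" of non-null rows (mirrors the
    //    `is_not_null` filter plus `with_row_address` scan in the patch).
    let non_null_ids: Vec<usize> = rows
        .iter()
        .enumerate()
        .filter_map(|(i, row)| row.as_ref().map(|_| i))
        .collect();
    // 2. Sample from the surviving ids. The real code samples without
    //    replacement; truncation keeps this sketch deterministic.
    let chosen = &non_null_ids[..sample_size.min(non_null_ids.len())];
    // 3. Take: materialize only the selected rows.
    chosen.iter().map(|&i| rows[i].clone().unwrap()).collect()
}

fn main() {
    let rows = vec![Some(vec![1.0]), None, Some(vec![2.0]), None];
    let sampled = sample_non_null(&rows, 2);
    println!("{}", sampled.len()); // 2
}
```

The design choice this mirrors: sampling first and filtering second could return fewer rows than requested (or none), so the patch filters to non-null row addresses before sampling, guaranteeing the training set is both null-free and as large as the data allows.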