Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: avoid divide-by-zero when training an index with a large dimension #3426

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions rust/lance-io/src/object_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ impl ObjectStore {
Self {
inner: Arc::new(InMemory::new()).traced(),
scheme: String::from("memory"),
block_size: 64 * 1024,
block_size: 4 * 1024,
use_constant_size_upload_parts: false,
list_is_lexically_ordered: true,
io_parallelism: get_num_compute_intensive_cpus(),
Expand Down Expand Up @@ -977,7 +977,7 @@ async fn configure_store(
"memory" => Ok(ObjectStore {
inner: Arc::new(InMemory::new()).traced(),
scheme: String::from("memory"),
block_size: cloud_block_size,
block_size: file_block_size,
use_constant_size_upload_parts: false,
list_is_lexically_ordered: true,
io_parallelism: get_num_compute_intensive_cpus(),
Expand Down
114 changes: 88 additions & 26 deletions rust/lance/src/index/vector/ivf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2223,42 +2223,102 @@ mod tests {
.await;
}

struct TestPqParams {
num_sub_vectors: usize,
num_bits: usize,
}

impl TestPqParams {
fn small() -> Self {
Self {
num_sub_vectors: 2,
num_bits: 8,
}
}
}

// Clippy doesn't like that all start with Ivf but we might have some in the future
// that _don't_ start with Ivf so I feel it is meaningful to keep the prefix
#[allow(clippy::enum_variant_names)]
enum TestIndexType {
IvfPq { pq: TestPqParams },
IvfHnswPq { pq: TestPqParams, num_edges: usize },
IvfHnswSq { num_edges: usize },
IvfFlat,
}

struct CreateIndexCase {
metric_type: MetricType,
num_partitions: usize,
dimension: usize,
index_type: TestIndexType,
}

// We test L2 and Dot, because L2 PQ uses residuals while Dot doesn't,
// so they have slightly different code paths.
#[tokio::test]
#[rstest]
#[case::ivf_pq_l2(VectorIndexParams::with_ivf_pq_params(
MetricType::L2,
IvfBuildParams::new(2),
PQBuildParams::new(2, 8),
))]
#[case::ivf_pq_dot(VectorIndexParams::with_ivf_pq_params(
MetricType::Dot,
IvfBuildParams::new(2),
PQBuildParams::new(2, 8),
))]
#[case::ivf_flat(VectorIndexParams::ivf_flat(1, MetricType::Dot))]
#[case::ivf_hnsw_pq(VectorIndexParams::with_ivf_hnsw_pq_params(
MetricType::Dot,
IvfBuildParams::new(2),
HnswBuildParams::default().num_edges(100),
PQBuildParams::new(2, 8)
))]
#[case::ivf_hnsw_sq(VectorIndexParams::with_ivf_hnsw_sq_params(
MetricType::Dot,
IvfBuildParams::new(2),
HnswBuildParams::default().num_edges(100),
SQBuildParams::default()
))]
#[case::ivf_pq_l2(CreateIndexCase {
metric_type: MetricType::L2,
num_partitions: 2,
dimension: 16,
index_type: TestIndexType::IvfPq { pq: TestPqParams::small() },
})]
#[case::ivf_pq_dot(CreateIndexCase {
metric_type: MetricType::Dot,
num_partitions: 2,
dimension: 2000,
index_type: TestIndexType::IvfPq { pq: TestPqParams::small() },
})]
#[case::ivf_flat(CreateIndexCase { num_partitions: 1, metric_type: MetricType::Dot, dimension: 16, index_type: TestIndexType::IvfFlat })]
#[case::ivf_hnsw_pq(CreateIndexCase {
num_partitions: 2,
metric_type: MetricType::Dot,
dimension: 16,
index_type: TestIndexType::IvfHnswPq { pq: TestPqParams::small(), num_edges: 100 },
})]
#[case::ivf_hnsw_sq(CreateIndexCase {
metric_type: MetricType::Dot,
num_partitions: 2,
dimension: 16,
index_type: TestIndexType::IvfHnswSq { num_edges: 100 },
})]
async fn test_create_index_nulls(
#[case] mut index_params: VectorIndexParams,
#[case] test_case: CreateIndexCase,
#[values(IndexFileVersion::Legacy, IndexFileVersion::V3)] index_version: IndexFileVersion,
) {
let mut index_params = match test_case.index_type {
TestIndexType::IvfPq { pq } => VectorIndexParams::with_ivf_pq_params(
test_case.metric_type,
IvfBuildParams::new(test_case.num_partitions),
PQBuildParams::new(pq.num_sub_vectors, pq.num_bits),
),
TestIndexType::IvfHnswPq { pq, num_edges } => {
VectorIndexParams::with_ivf_hnsw_pq_params(
test_case.metric_type,
IvfBuildParams::new(test_case.num_partitions),
HnswBuildParams::default().num_edges(num_edges),
PQBuildParams::new(pq.num_sub_vectors, pq.num_bits),
)
}
TestIndexType::IvfFlat => {
VectorIndexParams::ivf_flat(test_case.num_partitions, test_case.metric_type)
}
TestIndexType::IvfHnswSq { num_edges } => VectorIndexParams::with_ivf_hnsw_sq_params(
test_case.metric_type,
IvfBuildParams::new(test_case.num_partitions),
HnswBuildParams::default().num_edges(num_edges),
SQBuildParams::default(),
),
};
index_params.version(index_version);

let nrows = 2_000;
let data = gen()
.col("vec", array::rand_vec::<Float32Type>(Dimension::from(16)))
.col(
"vec",
array::rand_vec::<Float32Type>(Dimension::from(test_case.dimension as u32)),
)
.into_batch_rows(RowCount::from(nrows))
.unwrap();

Expand Down Expand Up @@ -2287,7 +2347,9 @@ mod tests {
.await
.unwrap();

let query = vec![0.0; 16].into_iter().collect::<Float32Array>();
let query = vec![0.0; test_case.dimension]
.into_iter()
.collect::<Float32Array>();
let results = dataset
.scan()
.nearest("vec", &query, 2_000)
Expand Down
2 changes: 1 addition & 1 deletion rust/lance/src/index/vector/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ fn random_ranges(
block_size: usize,
byte_width: usize,
) -> impl Iterator<Item = std::ops::Range<u64>> + Send {
let rows_per_batch = block_size / byte_width;
let rows_per_batch = 1.max(block_size / byte_width);
let mut rng = SmallRng::from_entropy();
let num_bins = num_rows.div_ceil(rows_per_batch);

Expand Down
Loading