feat: support to read IVF partitions (#3462)

BubbleCal authored Feb 21, 2025
1 parent cca98fc commit 59b414b
Showing 13 changed files with 360 additions and 15 deletions.
109 changes: 109 additions & 0 deletions python/python/lance/dataset.py
@@ -3841,3 +3841,112 @@ def _validate_metadata(metadata: dict):
)
elif isinstance(v, dict):
_validate_metadata(v)


class VectorIndexReader:
"""
This class allows you to initialize a reader for a specific vector index,
retrieve the number of partitions,
access the centroids of the index,
and read specific partitions of the index.
Parameters
----------
dataset: LanceDataset
The dataset containing the index.
index_name: str
The name of the vector index to read.
Examples
--------
.. code-block:: python
import lance
from lance.dataset import VectorIndexReader
import numpy as np
import pyarrow as pa
vectors = np.random.rand(256, 2)
data = pa.table({"vector": pa.array(vectors.tolist(),
type=pa.list_(pa.float32(), 2))})
dataset = lance.write_dataset(data, "/tmp/index_reader_demo")
dataset.create_index("vector", index_type="IVF_PQ",
num_partitions=4, num_sub_vectors=2)
reader = VectorIndexReader(dataset, "vector_idx")
assert reader.num_partitions() == 4
partition = reader.read_partition(0)
assert "_rowid" in partition.column_names
Exceptions
----------
ValueError
If the specified index is not a vector index.
"""

def __init__(self, dataset: LanceDataset, index_name: str):
stats = dataset.stats.index_stats(index_name)
self.dataset = dataset
self.index_name = index_name
self.stats = stats
try:
self.num_partitions()
except KeyError:
            raise ValueError(f"Index {index_name} is not a vector index")

def num_partitions(self) -> int:
"""
Returns the number of partitions in the dataset.
Returns
-------
int
The number of partitions.
"""

return self.stats["indices"][0]["num_partitions"]

def centroids(self) -> np.ndarray:
"""
Returns the centroids of the index
Returns
-------
np.ndarray
The centroids of IVF
with shape (num_partitions, dim)
"""
# when we have more delta indices,
# they are with the same centroids
return np.array(
self.dataset._ds.get_index_centroids(self.stats["indices"][0]["centroids"])
)

def read_partition(
self, partition_id: int, *, with_vector: bool = False
) -> pa.Table:
"""
Returns a pyarrow table for the given IVF partition
Parameters
----------
partition_id: int
The id of the partition to read
with_vector: bool, default False
Whether to include the vector column in the reader,
for IVF_PQ, the vector column is PQ codes
Returns
-------
pa.Table
A pyarrow table for the given partition,
containing the row IDs, and quantized vectors (if with_vector is True).
"""

if partition_id < 0 or partition_id >= self.num_partitions():
raise IndexError(
f"Partition id {partition_id} is out of range, "
f"expected 0 <= partition_id < {self.num_partitions()}"
)

return self.dataset._ds.read_index_partition(
self.index_name, partition_id, with_vector
).read_all()
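
The reader above is enough to build a manual coarse search from Python. Below is a minimal sketch (not part of this commit) that probes the single nearest IVF partition for a query vector; the "/tmp/index_reader_demo" dataset and "vector_idx" index are the ones from the docstring example, and plain L2 distance against the centroids is an illustrative choice:

    import lance
    import numpy as np
    from lance.dataset import VectorIndexReader

    dataset = lance.dataset("/tmp/index_reader_demo")
    reader = VectorIndexReader(dataset, "vector_idx")

    query = np.random.rand(2).astype(np.float32)
    centroids = reader.centroids()  # shape: (num_partitions, dim)

    # Probe the partition whose centroid is nearest to the query (L2).
    nearest = int(np.argmin(np.linalg.norm(centroids - query, axis=1)))
    candidates = reader.read_partition(nearest)
    print(candidates.num_rows, candidates["_rowid"].to_pylist()[:5])
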
31 changes: 31 additions & 0 deletions python/python/tests/test_vector_index.py
@@ -13,6 +13,7 @@
import pyarrow.compute as pc
import pytest
from lance import LanceFragment
from lance.dataset import VectorIndexReader

torch = pytest.importorskip("torch")
from lance.util import validate_vector_index # noqa: E402
@@ -1129,3 +1130,33 @@ def test_drop_indices(indexed_dataset):
)

assert len(results) == 15


def test_read_partition(indexed_dataset):
idx_name = indexed_dataset.list_indices()[0]["name"]
reader = VectorIndexReader(indexed_dataset, idx_name)

num_rows = indexed_dataset.count_rows()
row_sum = 0
for part_id in range(reader.num_partitions()):
res = reader.read_partition(part_id)
row_sum += res.num_rows
assert "_rowid" in res.column_names
assert row_sum == num_rows

row_sum = 0
for part_id in range(reader.num_partitions()):
res = reader.read_partition(part_id, with_vector=True)
row_sum += res.num_rows
pq_column = res["__pq_code"]
assert "_rowid" in res.column_names
assert pq_column.type == pa.list_(pa.uint8(), 16)
assert row_sum == num_rows

# error tests
with pytest.raises(IndexError, match="out of range"):
reader.read_partition(reader.num_partitions() + 1)

    with pytest.raises(ValueError, match="not a vector index"):
indexed_dataset.create_scalar_index("id", index_type="BTREE")
VectorIndexReader(indexed_dataset, "id_idx")
23 changes: 22 additions & 1 deletion python/src/dataset.rs
@@ -31,7 +31,7 @@ use arrow_array::Array;
use futures::{StreamExt, TryFutureExt};
use lance::dataset::builder::DatasetBuilder;
use lance::dataset::refs::{Ref, TagContents};
use lance::dataset::scanner::MaterializationStyle;
use lance::dataset::scanner::{DatasetRecordBatchStream, MaterializationStyle};
use lance::dataset::statistics::{DataStatistics, DatasetStatisticsExt};
use lance::dataset::{
fragment::FileFragment as LanceFileFragment,
@@ -1558,6 +1558,27 @@ impl Dataset {

Ok(())
}

    #[pyo3(signature = (index_name, partition_id, with_vector=false))]
fn read_index_partition(
&self,
index_name: String,
partition_id: usize,
with_vector: bool,
) -> PyResult<PyArrowType<Box<dyn RecordBatchReader + Send>>> {
let stream = RT
.block_on(
None,
self.ds
.read_index_partition(&index_name, partition_id, with_vector),
)?
.map_err(|err| PyValueError::new_err(err.to_string()))?;

let reader = Box::new(LanceReader::from_stream(DatasetRecordBatchStream::new(
stream,
)));
Ok(PyArrowType(reader))
}
}

impl Dataset {
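
The binding returns a RecordBatchReader, so a partition can also be consumed batch by batch instead of materialized with read_all() as VectorIndexReader.read_partition does. A sketch of streaming consumption — note that dataset._ds is the internal handle, so the Python wrapper above is the supported entry point:

    # Hypothetical direct use of the low-level binding.
    batch_reader = dataset._ds.read_index_partition("vector_idx", 0, True)
    for batch in batch_reader:  # pyarrow RecordBatchReaders are iterable
        print(batch.num_rows, batch.schema.names)
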
8 changes: 8 additions & 0 deletions rust/lance-index/src/traits.rs
@@ -4,6 +4,7 @@
use std::sync::Arc;

use async_trait::async_trait;
use datafusion::execution::SendableRecordBatchStream;
use lance_core::Result;

use crate::{optimize::OptimizeOptions, IndexParams, IndexType};
@@ -97,4 +98,11 @@ pub trait DatasetIndexExt {
column: &str,
index_id: Uuid,
) -> Result<()>;

    /// Read one partition of the named vector index as a stream of record
    /// batches, merging the partition streams of any delta indices.
    async fn read_index_partition(
&self,
index_name: &str,
partition_id: usize,
with_vector: bool,
) -> Result<SendableRecordBatchStream>;
}
13 changes: 13 additions & 0 deletions rust/lance-index/src/vector.rs
@@ -9,6 +9,7 @@ use std::{collections::HashMap, sync::Arc};
use arrow_array::{ArrayRef, RecordBatch, UInt32Array};
use arrow_schema::Field;
use async_trait::async_trait;
use datafusion::execution::SendableRecordBatchStream;
use ivf::storage::IvfModel;
use lance_core::{Result, ROW_ID_FIELD};
use lance_io::object_store::ObjectStore;
@@ -179,6 +180,18 @@ pub trait VectorIndex: Send + Sync + std::fmt::Debug + Index {
self.load(reader, offset, length).await
}

    // For IVF indices only: read a single partition as a record batch stream.
async fn partition_reader(
&self,
_partition_id: usize,
_with_vector: bool,
) -> Result<SendableRecordBatchStream> {
unimplemented!("only for IVF")
}

    // For sub-indices only: read the entire index as a record batch stream.
async fn to_batch_stream(&self, with_vector: bool) -> Result<SendableRecordBatchStream>;

/// Return the IDs of rows in the index.
fn row_ids(&self) -> Box<dyn Iterator<Item = &'_ u64> + '_>;

30 changes: 30 additions & 0 deletions rust/lance-index/src/vector/hnsw/index.rs
@@ -10,7 +10,11 @@ use std::{

use arrow_array::{RecordBatch, UInt32Array};
use async_trait::async_trait;
use datafusion::execution::SendableRecordBatchStream;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use deepsize::DeepSizeOf;
use lance_arrow::RecordBatchExt;
use lance_core::ROW_ID;
use lance_core::{datatypes::Schema, Error, Result};
use lance_file::reader::FileReader;
use lance_io::traits::Reader;
@@ -263,6 +267,32 @@ impl<Q: Quantization + Send + Sync + 'static> VectorIndex for HNSWIndex<Q> {
}))
}

async fn to_batch_stream(&self, with_vector: bool) -> Result<SendableRecordBatchStream> {
let store = self.storage.as_ref().ok_or(Error::Index {
message: "vector storage not loaded".to_string(),
location: location!(),
})?;

let schema = if with_vector {
store.schema().clone()
} else {
let schema = store.schema();
let row_id_idx = schema.index_of(ROW_ID)?;
Arc::new(schema.project(&[row_id_idx])?)
};

let batches = store
.to_batches()?
.map(|b| {
let batch = b.project_by_schema(&schema)?;
Ok(batch)
})
.collect::<Vec<_>>();
let stream = futures::stream::iter(batches);
let stream = RecordBatchStreamAdapter::new(schema, stream);
Ok(Box::pin(stream))
}

fn row_ids(&self) -> Box<dyn Iterator<Item = &'_ u64> + '_> {
Box::new(self.storage.as_ref().unwrap().row_ids())
}
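
This schema projection is what the Python-level with_vector flag toggles: without it the stream carries only the row-id column, with it the stored vectors come along (PQ codes for IVF_PQ). A sketch of the observable difference, reusing the reader from the docstring example:

    ids_only = reader.read_partition(0)
    assert "_rowid" in ids_only.column_names

    with_codes = reader.read_partition(0, with_vector=True)
    # For IVF_PQ the vector column holds PQ codes ("__pq_code" in the
    # tests above); the exact column name depends on the index type.
    assert with_codes.num_columns > ids_only.num_columns
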
56 changes: 47 additions & 9 deletions rust/lance/src/index.rs
@@ -7,8 +7,10 @@
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, OnceLock};

use arrow_schema::DataType;
use arrow_schema::{DataType, Schema};
use async_trait::async_trait;
use datafusion::execution::SendableRecordBatchStream;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use futures::{stream, StreamExt, TryStreamExt};
use itertools::Itertools;
use lance_core::utils::parse::str_is_truthy;
@@ -686,6 +688,48 @@ impl DatasetIndexExt for Dataset {
location: location!(),
})
}

async fn read_index_partition(
&self,
index_name: &str,
partition_id: usize,
with_vector: bool,
) -> Result<SendableRecordBatchStream> {
let indices = self.load_indices_by_name(index_name).await?;
if indices.is_empty() {
return Err(Error::IndexNotFound {
identity: format!("name={}", index_name),
location: location!(),
});
}
let column = self.schema().field_by_id(indices[0].fields[0]).unwrap();

let mut schema: Option<Arc<Schema>> = None;
let mut partition_streams = Vec::with_capacity(indices.len());
for index in indices {
let index = self
.open_vector_index(&column.name, &index.uuid.to_string())
.await?;

let stream = index.partition_reader(partition_id, with_vector).await?;
if schema.is_none() {
schema = Some(stream.schema());
}
partition_streams.push(stream);
}

match schema {
Some(schema) => {
let merged = stream::select_all(partition_streams);
let stream = RecordBatchStreamAdapter::new(schema, merged);
Ok(Box::pin(stream))
}
None => Ok(Box::pin(RecordBatchStreamAdapter::new(
Arc::new(Schema::empty()),
stream::empty(),
))),
}
}
}

/// A trait for internal dataset utilities
@@ -775,14 +819,8 @@ impl DatasetIndexInternalExt for Dataset {
match &proto.implementation {
Some(Implementation::VectorIndex(vector_index)) => {
let dataset = Arc::new(self.clone());
crate::index::vector::open_vector_index(
dataset,
column,
uuid,
vector_index,
reader,
)
.await
crate::index::vector::open_vector_index(dataset, uuid, vector_index, reader)
.await
}
None => Err(Error::Internal {
message: "Index proto was missing implementation field".into(),
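
Since read_index_partition opens every delta index on the column and merges their partition streams with stream::select_all, a single partition read covers all delta indices at once. Assuming every row is indexed (no unindexed fragments), the partitions therefore tile the dataset — a sketch of that invariant, reusing the reader from the docstring example:

    total = sum(
        reader.read_partition(p).num_rows
        for p in range(reader.num_partitions())
    )
    assert total == dataset.count_rows()  # holds when all rows are indexed
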
1 change: 0 additions & 1 deletion rust/lance/src/index/vector.rs
@@ -445,7 +445,6 @@ pub(crate) async fn remap_vector_index(
#[instrument(level = "debug", skip(dataset, vec_idx, reader))]
pub(crate) async fn open_vector_index(
dataset: Arc<Dataset>,
column: &str,
uuid: &str,
vec_idx: &lance_index::pb::VectorIndex,
reader: Arc<dyn Reader>,
5 changes: 5 additions & 0 deletions rust/lance/src/index/vector/fixture_test.rs
@@ -17,6 +17,7 @@ mod test {
use arrow_array::{FixedSizeListArray, Float32Array, RecordBatch, UInt32Array};
use arrow_schema::{DataType, Field, Schema};
use async_trait::async_trait;
use datafusion::execution::SendableRecordBatchStream;
use deepsize::{Context, DeepSizeOf};
use lance_arrow::FixedSizeListArrayExt;
use lance_index::vector::ivf::storage::IvfModel;
Expand Down Expand Up @@ -142,6 +143,10 @@ mod test {
Ok(())
}

async fn to_batch_stream(&self, _with_vector: bool) -> Result<SendableRecordBatchStream> {
unimplemented!("only for SubIndex")
}

fn ivf_model(&self) -> IvfModel {
unimplemented!("only for IVF")
}