Skip to content

Commit

Permalink
Add Python API
Browse files Browse the repository at this point in the history
Signed-off-by: BubbleCal <[email protected]>
  • Loading branch information
BubbleCal committed Feb 19, 2025
1 parent dc080cc commit e58e601
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 5 deletions.
7 changes: 4 additions & 3 deletions python/python/lance/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3849,6 +3849,7 @@ def __init__(self, dataset: LanceDataset, index_name: str):
stats = dataset.stats.index_stats(index_name)

self.dataset = dataset
self.index_name = index_name
self.stats = stats

def num_partitions(self) -> int:
Expand All @@ -3874,7 +3875,7 @@ def centroids(self) -> np.ndarray:
)

def partition_reader(
self, partition_id: int, with_vector: bool = False
self, partition_id: int, *, with_vector: bool = False
) -> pa.RecordBatchReader:
"""
Returns a reader for the given IVF partition
Expand All @@ -3893,6 +3894,6 @@ def partition_reader(
A record batch reader for the given partition
"""

return self.dataset._ds.partition_reader(
self.stats["indices"][0]["partitions"][partition_id]
return self.dataset._ds.read_index_partition(
self.index_name, partition_id, with_vector
)
14 changes: 14 additions & 0 deletions python/python/tests/test_vector_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import pyarrow.compute as pc
import pytest
from lance import LanceFragment
from lance.dataset import VectorIndexReader

torch = pytest.importorskip("torch")
from lance.util import validate_vector_index # noqa: E402
Expand Down Expand Up @@ -1129,3 +1130,16 @@ def test_drop_indices(indexed_dataset):
)

assert len(results) == 15


def test_read_partition(indexed_dataset):
    """Reading back every IVF partition should yield all rows of the dataset."""
    index_name = indexed_dataset.list_indices()[0]["name"]
    index_reader = VectorIndexReader(indexed_dataset, index_name)

    expected_rows = indexed_dataset.count_rows()
    # Sum batch sizes across all partitions; together they must cover the dataset.
    seen_rows = sum(
        batch.num_rows
        for part_id in range(index_reader.num_partitions())
        for batch in index_reader.partition_reader(part_id)
    )
    assert seen_rows == expected_rows
23 changes: 22 additions & 1 deletion python/src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use arrow_array::Array;
use futures::{StreamExt, TryFutureExt};
use lance::dataset::builder::DatasetBuilder;
use lance::dataset::refs::{Ref, TagContents};
use lance::dataset::scanner::MaterializationStyle;
use lance::dataset::scanner::{DatasetRecordBatchStream, MaterializationStyle};
use lance::dataset::statistics::{DataStatistics, DatasetStatisticsExt};
use lance::dataset::{
fragment::FileFragment as LanceFileFragment,
Expand Down Expand Up @@ -1558,6 +1558,27 @@ impl Dataset {

Ok(())
}

#[pyo3(signature = (index_name,partition_id, with_vector=false))]
fn read_index_partition(
&self,
index_name: String,
partition_id: usize,
with_vector: bool,
) -> PyResult<PyArrowType<Box<dyn RecordBatchReader + Send>>> {
let stream = RT
.block_on(
None,
self.ds
.read_index_partition(&index_name, partition_id, with_vector),
)?
.map_err(|err| PyValueError::new_err(err.to_string()))?;

let reader = Box::new(LanceReader::from_stream(DatasetRecordBatchStream::new(
stream,
)));
Ok(PyArrowType(reader))
}
}

impl Dataset {
Expand Down
2 changes: 1 addition & 1 deletion rust/lance/src/index/vector/fixture_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ mod test {
Ok(())
}

async fn to_batch_stream(&self, with_vector: bool) -> Result<SendableRecordBatchStream> {
async fn to_batch_stream(&self, _with_vector: bool) -> Result<SendableRecordBatchStream> {
unimplemented!("only for SubIndex")
}

Expand Down

0 comments on commit e58e601

Please sign in to comment.