diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py
index 0dd3e73876..8dc27ee8ea 100644
--- a/python/python/lance/dataset.py
+++ b/python/python/lance/dataset.py
@@ -3846,15 +3846,24 @@ def _validate_metadata(metadata: dict):
 class VectorIndexReader:
     def __init__(self, dataset: LanceDataset, index_name: str):
         stats = dataset.stats.index_stats(index_name)
-
         self.dataset = dataset
         self.index_name = index_name
         self.stats = stats
+        try:
+            self.num_partitions()
+        except KeyError:
+            raise ValueError(f"Index {index_name} is not vector index")
 
     def num_partitions(self) -> int:
         """
-        Returns the number of partitions in the index
+        Returns the number of partitions in the dataset.
+
+        Returns
+        -------
+        int
+            The number of partitions.
         """
+
         return self.stats["indices"][0]["num_partitions"]
 
     def centroids(self) -> np.ndarray:
@@ -3873,11 +3882,11 @@ def centroids(self) -> np.ndarray:
             self.dataset._ds.get_index_centroids(self.stats["indices"][0]["centroids"])
         )
 
-    def partition_reader(
+    def read_partition(
         self, partition_id: int, *, with_vector: bool = False
-    ) -> pa.RecordBatchReader:
+    ) -> pa.Table:
         """
-        Returns a reader for the given IVF partition
+        Returns a pyarrow table for the given IVF partition
 
         Parameters
         ----------
@@ -3889,10 +3898,17 @@ def partition_reader(
 
         Returns
         -------
-        pa.RecordBatchReader
-            A record batch reader for the given partition
+        pa.Table
+            A pyarrow table for the given partition,
+            containing the row IDs, and quantized vectors (if with_vector is True).
         """
+        if partition_id < 0 or partition_id >= self.num_partitions():
+            raise IndexError(
+                f"Partition id {partition_id} is out of range, "
+                f"expected 0 <= partition_id < {self.num_partitions()}"
+            )
+
         return self.dataset._ds.read_index_partition(
             self.index_name, partition_id, with_vector
-        )
+        ).read_all()
 
diff --git a/python/python/tests/test_vector_index.py b/python/python/tests/test_vector_index.py
index d6e337de23..12df65da3d 100644
--- a/python/python/tests/test_vector_index.py
+++ b/python/python/tests/test_vector_index.py
@@ -1139,18 +1139,24 @@ def test_read_partition(indexed_dataset):
     num_rows = indexed_dataset.count_rows()
     row_sum = 0
     for part_id in range(reader.num_partitions()):
-        part_reader = reader.partition_reader(part_id)
-        for batch in part_reader:
-            row_sum += batch.num_rows
-            assert "_rowid" in batch.column_names
+        res = reader.read_partition(part_id)
+        row_sum += res.num_rows
+        assert "_rowid" in res.column_names
 
     assert row_sum == num_rows
     row_sum = 0
     for part_id in range(reader.num_partitions()):
-        part_reader = reader.partition_reader(part_id, with_vector=True)
-        for batch in part_reader:
-            row_sum += batch.num_rows
-            pq_column = batch["__pq_code"]
-            assert "_rowid" in batch.column_names
-            assert pq_column.type == pa.list_(pa.uint8(), 16)
+        res = reader.read_partition(part_id, with_vector=True)
+        row_sum += res.num_rows
+        pq_column = res["__pq_code"]
+        assert "_rowid" in res.column_names
+        assert pq_column.type == pa.list_(pa.uint8(), 16)
     assert row_sum == num_rows
+
+    # error tests
+    with pytest.raises(IndexError, match="out of range"):
+        reader.read_partition(reader.num_partitions() + 1)
+
+    with pytest.raises(ValueError, match="not vector index"):
+        indexed_dataset.create_scalar_index("id", index_type="BTREE")
+        VectorIndexReader(indexed_dataset, "id_idx")