Commit

fix
Signed-off-by: BubbleCal <[email protected]>
BubbleCal committed Feb 20, 2025
1 parent 501f8a1 commit b6e2765
Showing 2 changed files with 40 additions and 18 deletions.
python/python/lance/dataset.py — 32 changes: 24 additions & 8 deletions
@@ -3846,15 +3846,24 @@ def _validate_metadata(metadata: dict):
 class VectorIndexReader:
     def __init__(self, dataset: LanceDataset, index_name: str):
         stats = dataset.stats.index_stats(index_name)

         self.dataset = dataset
         self.index_name = index_name
         self.stats = stats
+        try:
+            self.num_partitions()
+        except KeyError:
+            raise ValueError(f"Index {index_name} is not vector index")

     def num_partitions(self) -> int:
         """
-        Returns the number of partitions in the index
+        Returns the number of partitions in the dataset.
+
+        Returns
+        -------
+        int
+            The number of partitions.
         """
+
         return self.stats["indices"][0]["num_partitions"]

     def centroids(self) -> np.ndarray:
@@ -3873,11 +3882,11 @@ def centroids(self) -> np.ndarray:
             self.dataset._ds.get_index_centroids(self.stats["indices"][0]["centroids"])
         )

-    def partition_reader(
+    def read_partition(
         self, partition_id: int, *, with_vector: bool = False
-    ) -> pa.RecordBatchReader:
+    ) -> pa.Table:
         """
-        Returns a reader for the given IVF partition
+        Returns a pyarrow table for the given IVF partition

         Parameters
         ----------
@@ -3889,10 +3898,17 @@ def partition_reader(

         Returns
         -------
-        pa.RecordBatchReader
-            A record batch reader for the given partition
+        pa.Table
+            A pyarrow table for the given partition,
+            containing the row IDs, and quantized vectors (if with_vector is True).
         """

+        if partition_id < 0 or partition_id >= self.num_partitions():
+            raise IndexError(
+                f"Partition id {partition_id} is out of range, "
+                f"expected 0 <= partition_id < {self.num_partitions()}"
+            )
+
         return self.dataset._ds.read_index_partition(
             self.index_name, partition_id, with_vector
-        )
+        ).read_all()
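For orientation, a minimal usage sketch of the reader API as it looks after this commit. It is not part of the change itself: the dataset path "my_vectors.lance" and the index name "vector_idx" are illustrative assumptions, and an IVF_PQ index is assumed so that the quantized column is "__pq_code".

# Hedged usage sketch (not from this commit); assumes a dataset at
# "my_vectors.lance" with an existing IVF_PQ vector index named "vector_idx".
import lance
from lance.dataset import VectorIndexReader

ds = lance.dataset("my_vectors.lance")        # assumed path
reader = VectorIndexReader(ds, "vector_idx")  # assumed index name

n = reader.num_partitions()      # number of IVF partitions in the index
centroids = reader.centroids()   # np.ndarray of shape (num_partitions, dim)

# read_partition now returns a pa.Table (previously a pa.RecordBatchReader)
tbl = reader.read_partition(0, with_vector=True)
print(n, centroids.shape, tbl.column_names)   # columns include "_rowid" and "__pq_code"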
python/python/tests/test_vector_index.py — 26 changes: 16 additions & 10 deletions
@@ -1139,18 +1139,24 @@ def test_read_partition(indexed_dataset):
     num_rows = indexed_dataset.count_rows()
     row_sum = 0
     for part_id in range(reader.num_partitions()):
-        part_reader = reader.partition_reader(part_id)
-        for batch in part_reader:
-            row_sum += batch.num_rows
-            assert "_rowid" in batch.column_names
+        res = reader.read_partition(part_id)
+        row_sum += res.num_rows
+        assert "_rowid" in res.column_names
     assert row_sum == num_rows

     row_sum = 0
     for part_id in range(reader.num_partitions()):
-        part_reader = reader.partition_reader(part_id, with_vector=True)
-        for batch in part_reader:
-            row_sum += batch.num_rows
-            pq_column = batch["__pq_code"]
-            assert "_rowid" in batch.column_names
-            assert pq_column.type == pa.list_(pa.uint8(), 16)
+        res = reader.read_partition(part_id, with_vector=True)
+        row_sum += res.num_rows
+        pq_column = res["__pq_code"]
+        assert "_rowid" in res.column_names
+        assert pq_column.type == pa.list_(pa.uint8(), 16)
     assert row_sum == num_rows
+
+    # error tests
+    with pytest.raises(IndexError, match="out of range"):
+        reader.read_partition(reader.num_partitions() + 1)
+
+    with pytest.raises(ValueError, match="not vector index"):
+        indexed_dataset.create_scalar_index("id", index_type="BTREE")
+        VectorIndexReader(indexed_dataset, "id_idx")
