diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index e069a288e2..6e61f18640 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -167,8 +167,9 @@ jobs: rustup update ${{ matrix.toolchain }} && rustup default ${{ matrix.toolchain }} rustup component add rustfmt - name: Run tests - run: | - cargo build --all-features + # Check all benches, even though we aren't going to run them. + run: | + cargo build --tests --benches --all-features --workspace cargo test --all-features windows-build: runs-on: windows-latest @@ -190,6 +191,7 @@ jobs: Add-Content $env:GITHUB_PATH "C:\protoc\bin" shell: powershell - name: Run tests + # Check all benches, even though we aren't going to run them. run: | - cargo build + cargo build --tests --benches --all-features --workspace cargo test diff --git a/Cargo.toml b/Cargo.toml index 31410916f1..49670ff3bd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,7 @@ exclude = ["python"] resolver = "2" [workspace.package] -version = "0.14.0" +version = "0.14.1" edition = "2021" authors = ["Lance Devs "] license = "Apache-2.0" @@ -44,20 +44,20 @@ categories = [ rust-version = "1.78" [workspace.dependencies] -lance = { version = "=0.14.0", path = "./rust/lance" } -lance-arrow = { version = "=0.14.0", path = "./rust/lance-arrow" } -lance-core = { version = "=0.14.0", path = "./rust/lance-core" } -lance-datafusion = { version = "=0.14.0", path = "./rust/lance-datafusion" } -lance-datagen = { version = "=0.14.0", path = "./rust/lance-datagen" } -lance-encoding = { version = "=0.14.0", path = "./rust/lance-encoding" } -lance-encoding-datafusion = { version = "=0.14.0", path = "./rust/lance-encoding-datafusion" } -lance-file = { version = "=0.14.0", path = "./rust/lance-file" } -lance-index = { version = "=0.14.0", path = "./rust/lance-index" } -lance-io = { version = "=0.14.0", path = "./rust/lance-io" } -lance-linalg = { version = "=0.14.0", path = "./rust/lance-linalg" } -lance-table = { version = 
"=0.14.0", path = "./rust/lance-table" } -lance-test-macros = { version = "=0.14.0", path = "./rust/lance-test-macros" } -lance-testing = { version = "=0.14.0", path = "./rust/lance-testing" } +lance = { version = "=0.14.1", path = "./rust/lance" } +lance-arrow = { version = "=0.14.1", path = "./rust/lance-arrow" } +lance-core = { version = "=0.14.1", path = "./rust/lance-core" } +lance-datafusion = { version = "=0.14.1", path = "./rust/lance-datafusion" } +lance-datagen = { version = "=0.14.1", path = "./rust/lance-datagen" } +lance-encoding = { version = "=0.14.1", path = "./rust/lance-encoding" } +lance-encoding-datafusion = { version = "=0.14.1", path = "./rust/lance-encoding-datafusion" } +lance-file = { version = "=0.14.1", path = "./rust/lance-file" } +lance-index = { version = "=0.14.1", path = "./rust/lance-index" } +lance-io = { version = "=0.14.1", path = "./rust/lance-io" } +lance-linalg = { version = "=0.14.1", path = "./rust/lance-linalg" } +lance-table = { version = "=0.14.1", path = "./rust/lance-table" } +lance-test-macros = { version = "=0.14.1", path = "./rust/lance-test-macros" } +lance-testing = { version = "=0.14.1", path = "./rust/lance-testing" } fsst = { version = "=0.1.0", path = "./rust/lance-encoding/compression-algo/fsst" } approx = "0.5.1" # Note that this one does not include pyarrow @@ -112,6 +112,7 @@ deepsize = "0.2.0" either = "1.0" futures = "0.3" http = "0.2.9" +hyperloglogplus = { version = "0.4.1", features = ["const-loop"] } itertools = "0.12" lazy_static = "1" log = "0.4" @@ -138,6 +139,7 @@ serde = { version = "^1" } serde_json = { version = "1" } shellexpand = "3.0" snafu = "0.7.5" +tantivy = "0.22.0" tempfile = "3" test-log = { version = "0.2.15" } tokio = { version = "1.23", features = [ diff --git a/java/pom.xml b/java/pom.xml index a4dd4f3dd3..3a57cb9043 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -10,6 +10,27 @@ pom Lance Parent + Lance Format Java API + http://lancedb.com/ + + + + Lance DB Dev Group + 
dev@lancedb.com + + + + + The Apache Software License, Version 2.0 + http://www.apache.org/licenses/LICENSE-2.0.txt + + + + + scm:git:git@github.com:lancedb/lance.git + HEAD + scm:git:git@github.com:lancedb/lance.git + UTF-8 @@ -193,6 +214,16 @@ deploy-to-ossrh + + org.sonatype.central + central-publishing-maven-plugin + 0.4.0 + true + + ossrh + true + + org.sonatype.plugins nexus-staging-maven-plugin @@ -222,4 +253,4 @@ - \ No newline at end of file + diff --git a/protos/encodings.proto b/protos/encodings.proto index 41aac5ee98..23170bd71c 100644 --- a/protos/encodings.proto +++ b/protos/encodings.proto @@ -191,6 +191,13 @@ message Fsst { bytes symbol_table = 2; } +// An array encoding for dictionary-encoded fields +message Dictionary { + ArrayEncoding indices = 1; + ArrayEncoding items = 2; + uint32 num_dictionary_items = 3; +} + // Encodings that decode into an Arrow array message ArrayEncoding { oneof array_encoding { @@ -201,6 +208,7 @@ message ArrayEncoding { SimpleStruct struct = 5; Binary binary = 6; Fsst fsst = 7; + Dictionary dictionary = 8; } } diff --git a/python/Cargo.toml b/python/Cargo.toml index 1d262e32d3..283b85df4f 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pylance" -version = "0.14.0" +version = "0.14.1" edition = "2021" authors = ["Lance Devs "] rust-version = "1.65" diff --git a/python/python/benchmarks/test_index.py b/python/python/benchmarks/test_index.py index f95e95cfe0..3a0cadd994 100644 --- a/python/python/benchmarks/test_index.py +++ b/python/python/benchmarks/test_index.py @@ -6,6 +6,7 @@ import pyarrow as pa import pyarrow.compute as pc import pytest +from lance.indices import IndicesBuilder N_DIMS = 512 @@ -95,3 +96,16 @@ def test_optimize_index( lance.write_dataset(small_table, test_large_dataset.uri, mode="append") benchmark(test_large_dataset.optimize.optimize_indices) + + +@pytest.mark.benchmark(group="optimize_index") +@pytest.mark.parametrize("num_partitions", [100, 300]) +def 
test_train_ivf(test_large_dataset, benchmark, num_partitions): + builder = IndicesBuilder(test_large_dataset) + benchmark.pedantic( + builder.train_ivf, + args=["vector"], + kwargs={"num_partitions": num_partitions}, + iterations=1, + rounds=1, + ) diff --git a/python/python/benchmarks/test_search.py b/python/python/benchmarks/test_search.py index b82c87146a..be9cd0d014 100644 --- a/python/python/benchmarks/test_search.py +++ b/python/python/benchmarks/test_search.py @@ -5,6 +5,7 @@ from typing import NamedTuple, Union import lance +import numpy as np import pyarrow as pa import pyarrow.compute as pc import pytest @@ -34,7 +35,14 @@ def create_table(num_rows, offset) -> pa.Table: values = pc.random(num_rows * N_DIMS).cast(pa.float32()) vectors = pa.FixedSizeListArray.from_arrays(values, N_DIMS) filterable = pa.array(range(offset, offset + num_rows)) - return pa.table({"vector": vectors, "filterable": filterable}) + categories = pa.array(np.random.randint(0, 100, num_rows)) + return pa.table( + { + "vector": vectors, + "filterable": filterable, + "category": categories, + } + ) def create_base_dataset(data_dir: Path) -> lance.LanceDataset: @@ -66,8 +74,9 @@ def create_base_dataset(data_dir: Path) -> lance.LanceDataset: ) dataset.create_scalar_index("filterable", "BTREE") + dataset.create_scalar_index("category", "BITMAP") - return dataset + return lance.dataset(tmp_path, index_cache_size=64 * 1024) def create_delete_dataset(data_dir): @@ -82,7 +91,7 @@ def create_delete_dataset(data_dir): dataset = lance.dataset(tmp_path) dataset.delete("filterable % 2 != 0") - return dataset + return lance.dataset(tmp_path, index_cache_size=64 * 1024) def create_new_rows_dataset(data_dir): @@ -98,7 +107,7 @@ def create_new_rows_dataset(data_dir): table = create_table(NEW_ROWS, offset=NUM_ROWS) dataset = lance.write_dataset(table, tmp_path, mode="append") - return dataset + return lance.dataset(tmp_path, index_cache_size=64 * 1024) class Datasets(NamedTuple): @@ -129,6 +138,8 @@ def 
test_knn_search(test_dataset, benchmark): q = pc.random(N_DIMS).cast(pa.float32()) result = benchmark( test_dataset.to_table, + columns=[], + with_row_id=True, nearest=dict( column="vector", q=q, @@ -141,10 +152,12 @@ def test_knn_search(test_dataset, benchmark): @pytest.mark.benchmark(group="query_ann") -def test_flat_index_search(test_dataset, benchmark): +def test_ann_no_refine(test_dataset, benchmark): q = pc.random(N_DIMS).cast(pa.float32()) result = benchmark( test_dataset.to_table, + columns=[], + with_row_id=True, nearest=dict( column="vector", q=q, @@ -156,10 +169,12 @@ def test_flat_index_search(test_dataset, benchmark): @pytest.mark.benchmark(group="query_ann") -def test_ivf_pq_index_search(test_dataset, benchmark): +def test_ann_with_refine(test_dataset, benchmark): q = pc.random(N_DIMS).cast(pa.float32()) result = benchmark( test_dataset.to_table, + columns=[], + with_row_id=True, nearest=dict( column="vector", q=q, @@ -180,6 +195,8 @@ def test_filtered_search(test_dataset, benchmark, selectivity, prefilter, use_in threshold = int(round(selectivity * NUM_ROWS)) result = benchmark( test_dataset.to_table, + columns=[], + with_row_id=True, nearest=dict( column="vector", q=q, @@ -221,11 +238,13 @@ def test_filtered_search(test_dataset, benchmark, selectivity, prefilter, use_in "greater_than_not_selective", ], ) -def test_scalar_index_prefilter(test_dataset, benchmark, filter: str): +def test_btree_index_prefilter(test_dataset, benchmark, filter: str): q = pc.random(N_DIMS).cast(pa.float32()) if filter is None: benchmark( test_dataset.to_table, + columns=[], + with_row_id=True, nearest=dict( column="vector", q=q, @@ -236,6 +255,8 @@ def test_scalar_index_prefilter(test_dataset, benchmark, filter: str): else: benchmark( test_dataset.to_table, + columns=[], + with_row_id=True, nearest=dict( column="vector", q=q, @@ -275,14 +296,119 @@ def test_scalar_index_prefilter(test_dataset, benchmark, filter: str): "greater_than_not_selective", ], ) -def 
test_scalar_index_search(test_dataset, benchmark, filter: str): +def test_btree_index_search(test_dataset, benchmark, filter: str): + if filter is None: + benchmark( + test_dataset.to_table, + columns=[], + with_row_id=True, + ) + else: + benchmark( + test_dataset.to_table, + columns=[], + with_row_id=True, + prefilter=True, + filter=filter, + ) + + +@pytest.mark.benchmark(group="query_ann") +@pytest.mark.parametrize( + "filter", + ( + None, + "category = 0", + "category != 0", + "category IN (0)", + "category NOT IN (0)", + "category != 0 AND category != 3 AND category != 7", + "category NOT IN (0, 3, 7)", + "category < 5", + "category > 5", + ), + ids=[ + "none", + "equality", + "not_equality", + "in_list_one", + "not_in_list_one", + "not_equality_and_chain", + "not_in_list_three", + "less_than_selective", + "greater_than_not_selective", + ], +) +def test_bitmap_index_prefilter(test_dataset, benchmark, filter: str): + q = pc.random(N_DIMS).cast(pa.float32()) + if filter is None: + benchmark( + test_dataset.to_table, + columns=[], + with_row_id=True, + nearest=dict( + column="vector", + q=q, + k=100, + nprobes=10, + ), + ) + else: + benchmark( + test_dataset.to_table, + columns=[], + with_row_id=True, + nearest=dict( + column="vector", + q=q, + k=100, + nprobes=10, + ), + prefilter=True, + filter=filter, + ) + + +@pytest.mark.benchmark(group="query_no_vec") +@pytest.mark.parametrize( + "filter", + ( + None, + "category = 0", + "category != 0", + "category IN (0)", + "category IN (0, 3, 7)", + "category NOT IN (0)", + "category != 0 AND category != 3 AND category != 7", + "category NOT IN (0, 3, 7)", + "category < 5", + "category > 5", + ), + ids=[ + "none", + "equality", + "not_equality", + "in_list_one", + "in_list_three", + "not_in_list_one", + "not_equality_and_chain", + "not_in_list_three", + "less_than_selective", + "greater_than_not_selective", + ], +) +def test_bitmap_index_search(test_dataset, benchmark, filter: str): if filter is None: benchmark( 
test_dataset.to_table, + columns=[], + with_row_id=True, ) else: benchmark( test_dataset.to_table, + columns=[], + with_row_id=True, prefilter=True, filter=filter, ) diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index 6fb011d136..a94f2bddec 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -243,6 +243,7 @@ def scanner( prefilter: bool = False, with_row_id: bool = False, use_stats: bool = True, + fast_search: bool = False, ) -> LanceScanner: """Return a Scanner that can support various pushdowns. @@ -295,6 +296,9 @@ def scanner( number of rows (or be empty) if the rows closest to the query do not match the filter. It's generally good when the filter is not very selective. + fast_search: bool, default False + If True, then the search will only be performed on the indexed data, which + yields faster search time. Notes ----- @@ -331,6 +335,7 @@ def scanner( .with_fragments(fragments) .with_row_id(with_row_id) .use_stats(use_stats) + .fast_search(fast_search) ) if nearest is not None: builder = builder.nearest(**nearest) @@ -1098,7 +1103,7 @@ def cleanup_old_versions( def create_scalar_index( self, column: str, - index_type: Literal["BTREE"], + index_type: Union[Literal["BTREE"], Literal["BITMAP"]], name: Optional[str] = None, *, replace: bool = True, @@ -1154,9 +1159,15 @@ def create_scalar_index( that use scalar indices will either have a ``ScalarIndexQuery`` relation or a ``MaterializeIndex`` operator. - Currently, the only type of scalar index available is ``BTREE``. This index - combines is inspired by the btree data structure although only the first few - layers of the btree are cached in memory. + There are two types of scalar indices available today. The most common + type is ``BTREE``. This index is inspired by the btree data structure + although only the first few layers of the btree are cached in memory. 
It will + perform well on columns with a large number of unique values and few rows per + value. + + The other index type is ``BITMAP``. This index stores a bitmap for each unique + value in the column. This index is useful for columns with a small number of + unique values and many rows per value. Note that the ``LANCE_BYPASS_SPILLING`` environment variable can be used to bypass spilling to disk. Setting this to true can avoid memory exhaustion @@ -1170,7 +1181,7 @@ def create_scalar_index( The column to be indexed. Must be a boolean, integer, float, or string column. index_type : str - The type of the index. Only ``"BTREE"`` is supported now. + The type of the index. One of ``"BTREE"`` or ``"BITMAP"``. name : str, optional The index name. If not provided, it will be generated from the column name. @@ -1221,12 +1232,10 @@ def create_scalar_index( ) index_type = index_type.upper() - if index_type != "BTREE": + if index_type not in ["BTREE", "BITMAP"]: raise NotImplementedError( - ( - 'Only "BTREE" is supported for ', - f"index_type. Received {index_type}", - ) + 'Only "BTREE" or "BITMAP" are supported for ' + f"scalar columns. Received {index_type}", ) self._ds.create_index([column], index_type, name, replace) @@ -1253,6 +1262,7 @@ def create_index( # experimental parameters ivf_centroids_file: Optional[str] = None, precomputed_partiton_dataset: Optional[str] = None, + storage_options: Optional[Dict[str, str]] = None, **kwargs, ) -> LanceDataset: """Create index on column. @@ -1306,6 +1316,9 @@ def create_index( By making this value smaller, this shuffle will consume less memory but will take longer to complete, and vice versa. + storage_options : optional, dict + Extra options that make sense for a particular storage connection. This is + used to store connection parameters like credentials, endpoint, etc. kwargs : Parameters passed to the index building process. 
@@ -1463,7 +1476,9 @@ def create_index( " precomputed_partiton_dataset is provided" ) if precomputed_partiton_dataset is not None: - precomputed_ds = LanceDataset(precomputed_partiton_dataset) + precomputed_ds = LanceDataset( + precomputed_partiton_dataset, storage_options=storage_options + ) if len(precomputed_ds.get_fragments()) != 1: raise ValueError( "precomputed_partiton_dataset must have only one fragment" @@ -1477,15 +1492,21 @@ def create_index( if accelerator is not None and ivf_centroids is None: # Use accelerator to train ivf centroids - from .vector import train_ivf_centroids_on_accelerator + from .vector import ( + compute_partitions, + train_ivf_centroids_on_accelerator, + ) - ivf_centroids, partitions_file = train_ivf_centroids_on_accelerator( + ivf_centroids, kmeans = train_ivf_centroids_on_accelerator( self, column[0], num_partitions, metric, accelerator, ) + partitions_file = compute_partitions( + self, column[0], kmeans, batch_size=20480 + ) kwargs["precomputed_partitions_file"] = partitions_file if (ivf_centroids is None) and (pq_codebook is not None): @@ -1560,7 +1581,9 @@ def create_index( if shuffle_partition_concurrency is not None: kwargs["shuffle_partition_concurrency"] = shuffle_partition_concurrency - self._ds.create_index(column, index_type, name, replace, kwargs) + self._ds.create_index( + column, index_type, name, replace, storage_options, kwargs + ) return self def session(self) -> Session: @@ -1589,6 +1612,7 @@ def commit( operation: LanceOperation.BaseOperation, read_version: Optional[int] = None, commit_lock: Optional[CommitLock] = None, + storage_options: Optional[Dict[str, str]] = None, ) -> LanceDataset: """Create a new version of dataset @@ -1623,6 +1647,9 @@ def commit( commit_lock : CommitLock, optional A custom commit lock. Only needed if your object store does not support atomic commits. See the user guide for more details. 
+ storage_options : optional, dict + Extra options that make sense for a particular storage connection. This is + used to store connection parameters like credentials, endpoint, etc. Returns ------- @@ -1660,8 +1687,14 @@ def commit( f"commit_lock must be a function, got {type(commit_lock)}" ) - _Dataset.commit(base_uri, operation._to_inner(), read_version, commit_lock) - return LanceDataset(base_uri) + _Dataset.commit( + base_uri, + operation._to_inner(), + read_version, + commit_lock, + storage_options=storage_options, + ) + return LanceDataset(base_uri, storage_options=storage_options) def validate(self): """ @@ -1994,6 +2027,7 @@ def __init__(self, ds: LanceDataset): self._fragments = None self._with_row_id = False self._use_stats = True + self._fast_search = None def batch_size(self, batch_size: int) -> ScannerBuilder: """Set batch size for Scanner""" @@ -2132,6 +2166,7 @@ def nearest( nprobes: Optional[int] = None, refine_factor: Optional[int] = None, use_index: bool = True, + ef: Optional[int] = None, ) -> ScannerBuilder: q = _coerce_query_vector(q) @@ -2158,6 +2193,10 @@ def nearest( raise ValueError(f"Nprobes must be > 0 but got {nprobes}") if refine_factor is not None and int(refine_factor) < 1: raise ValueError(f"Refine factor must be 1 or more got {refine_factor}") + if ef is not None and int(ef) <= 0: + # `ef` should be >= `k`, but `k` could be None so we can't check it here + # the rust code will check it + raise ValueError(f"ef must be > 0 but got {ef}") self._nearest = { "column": column, "q": q, @@ -2166,9 +2205,19 @@ def nearest( "nprobes": nprobes, "refine_factor": refine_factor, "use_index": use_index, + "ef": ef, } return self + def fast_search(self, flag: bool) -> ScannerBuilder: + """Enable fast search, which only perform search on the indexed data. + + Users can use `Table::optimize()` or `create_index()` to include the new data + into index, thus make new data searchable. 
+ """ + self._fast_search = flag + return self + def to_scanner(self) -> LanceScanner: scanner = self.ds._ds.scanner( self._columns, @@ -2186,6 +2235,7 @@ def to_scanner(self) -> LanceScanner: self._with_row_id, self._use_stats, self._substrait_filter, + self._fast_search, ) return LanceScanner(scanner, self.ds) diff --git a/python/python/lance/file.py b/python/python/lance/file.py index 2ea0daa266..fb01d84f8c 100644 --- a/python/python/lance/file.py +++ b/python/python/lance/file.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright The Lance Authors -from typing import Union +from typing import Optional, Union import pyarrow as pa @@ -146,9 +146,9 @@ class LanceFileWriter: def __init__( self, path: str, - schema: pa.Schema, + schema: Optional[pa.Schema] = None, *, - data_cache_bytes: int = None, + data_cache_bytes: Optional[int] = None, **kwargs, ): """ @@ -160,7 +160,9 @@ def __init__( The path to write to. Can be a pathname for local storage or a URI for remote storage. schema: pa.Schema - The schema of data that will be written + The schema of data that will be written. If not specified then + the schema will be inferred from the first batch. If the schema + is not specified and no data is written then the write will fail. data_cache_bytes: int How many bytes (per column) to cache before writing a page. The default is an appropriate value based on the filesystem. diff --git a/python/python/lance/fragment.py b/python/python/lance/fragment.py index fb78cd4d8a..fea94b4869 100644 --- a/python/python/lance/fragment.py +++ b/python/python/lance/fragment.py @@ -147,6 +147,7 @@ def create( mode: str = "append", *, use_legacy_format=True, + storage_options: Optional[Dict[str, str]] = None, ) -> FragmentMetadata: """Create a :class:`FragmentMetadata` from the given data. @@ -180,6 +181,9 @@ def create( use_legacy_format: bool, default True Use the legacy format to write Lance files. 
The default is True while the v2 format is still in beta. + storage_options : optional, dict + Extra options that make sense for a particular storage connection. This is + used to store connection parameters like credentials, endpoint, etc. See Also -------- @@ -219,6 +223,7 @@ def create( progress=progress, mode=mode, use_legacy_format=use_legacy_format, + storage_options=storage_options, ) return FragmentMetadata(inner_meta.json()) diff --git a/python/python/lance/indices.py b/python/python/lance/indices.py new file mode 100644 index 0000000000..d77649447c --- /dev/null +++ b/python/python/lance/indices.py @@ -0,0 +1,421 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The Lance Authors +import math +import warnings +from typing import TYPE_CHECKING, Optional, Union + +import pyarrow as pa + +from lance.file import LanceFileReader, LanceFileWriter + +from .lance import indices + +if TYPE_CHECKING: + from .dependencies import torch + + +class PqModel: + """ + A class that represents a trained PQ model + + Can be saved / loaded to checkpoint progress. + """ + + def __init__(self, num_subvectors: int, codebook: pa.FixedSizeListArray): + self.num_subvectors = num_subvectors + """The number of subvectors to divide source vectors into""" + self.codebook = codebook + """The centroids of the PQ clusters""" + + @property + def dimension(self): + """The dimension of the vectors this model was trained on""" + return self.codebook.type.list_size + + def save(self, uri: str): + """ + Save the PQ model to a lance file. + + Parameters + ---------- + + uri: str + The URI to save the model to. The URI can be a local file path or a + cloud storage path. 
+ """ + with LanceFileWriter( + uri, + pa.schema( + [pa.field("codebook", self.codebook.type)], + metadata={b"num_subvectors": str(self.num_subvectors).encode()}, + ), + ) as writer: + batch = pa.table([self.codebook], names=["codebook"]) + writer.write_batch(batch) + + @classmethod + def load(cls, uri: str): + """ + Load a PQ model from a lance file. + + Parameters + ---------- + + uri: str + The URI to load the model from. The URI can be a local file path or a + cloud storage path. + """ + reader = LanceFileReader(uri) + num_rows = reader.metadata().num_rows + metadata = reader.metadata().schema.metadata + num_subvectors = int(metadata[b"num_subvectors"].decode()) + codebook = ( + reader.read_all(batch_size=num_rows).to_table().column("codebook").chunk(0) + ) + return cls(num_subvectors, codebook) + + +class IvfModel: + """ + A class that represents a trained IVF model. + """ + + def __init__(self, centroids: pa.Array, distance_type: str): + self.centroids = centroids + """The centroids of the IVF clusters""" + self.distance_type = distance_type + """The distance type used to train the IVF model""" + + @property + def num_partitions(self) -> int: + """ + The number of partitions / centroids in the IVF model + """ + return len(self.centroids) + + def save(self, uri: str): + """ + Save the IVF model to a lance file. + + Parameters + ---------- + + uri: str + The URI to save the model to. The URI can be a local file path or a + cloud storage path. + """ + with LanceFileWriter( + uri, + pa.schema( + [pa.field("centroids", self.centroids.type)], + metadata={b"distance_type": self.distance_type.encode()}, + ), + ) as writer: + batch = pa.table([self.centroids], names=["centroids"]) + writer.write_batch(batch) + + @classmethod + def load(cls, uri: str): + """ + Load an IVF model from a lance file. + + Parameters + ---------- + + uri: str + The URI to load the model from. The URI can be a local file path or a + cloud storage path. 
+ """ + reader = LanceFileReader(uri) + num_rows = reader.metadata().num_rows + metadata = reader.metadata().schema.metadata + distance_type = metadata[b"distance_type"].decode() + centroids = ( + reader.read_all(batch_size=num_rows).to_table().column("centroids").chunk(0) + ) + return cls(centroids, distance_type) + + +class IndicesBuilder: + """ + A class with helper functions for building indices on a dataset. + + The methods in this class can break down the process of building indices into + smaller steps. This can be useful for debugging and checkpointing when building + indices for extremely large datasets. + + This class is intended for advanced users that need to create vector indices at + large scales. + + The methods in this class are **experimental** and may change in future versions. + + For datasets with 10s of millions or fewer rows it will likely be simpler to just + use the `create_index` method on the dataset object. + """ + + def __init__(self, dataset): + self.dataset = dataset + + def train_ivf( + self, + column, + num_partitions=None, + *, + distance_type="l2", + accelerator: Optional[Union[str, "torch.Device"]] = None, + sample_rate: int = 256, + max_iters: int = 50, + ) -> IvfModel: + """ + Train IVF centroids for the given vector column. + + This will run k-means clustering on the given vector column to train the IVF + centroids. This is the first step in several vector indices. The centroids + will be used to partition the vectors into different clusters. + + IVF centroids are trained from a sample of the data (determined by the + sample_rate). While this sample is not huge it might still be quite large. + + K-means is an iterative algorithm that can be computationally expensive. The + accelerator argument can be used to offload the computation to a hardware + accelerator such as a GPU or TPU. 
+ + Parameters + ---------- + + column: str + The vector column to partition, must be a fixed size list of floats + or 1-dimensional fixed-shape tensor column. + num_partitions: int + The number of partitions to train. Large values are more expensive to + train and can lead to longer search times. Smaller values could lead to + overtraining, reduced recall, and require large nprobes values. If not + specified the default will be the integer nearest the square root of the + number of rows. + distance_type: "l2" | "dot" | "cosine" + The distance type to use. This is defined in more detail in the LanceDB + documentation on creating indices. + accelerator: str | torch.Device + An optional accelerator to use to offload computation to specialized + hardware. Currently supported values are "cuda" and "mps". + sample_rate: int + IVF is trained on a random sample of the dataset. The sample_rate + determines the size of this sample. There will be sample_rate rows loaded + for each partition for a total of sample_rate * num_partitions rows. If + the dataset does not contain enough rows an error will be raised. + max_iters: int + K-means is an iterative algorithm that is run until it converges. In + some cases, k-means will not converge but will cycle between various + possible minima. In these cases we must terminate or run forever. The + max_iters parameter defines a cutoff at which we terminate training. 
+ """ + column = self._normalize_column(column) + num_rows = self.dataset.count_rows() + num_partitions = self._determine_num_partitions(num_partitions, num_rows) + self._verify_ivf_sample_rate(sample_rate, num_partitions, num_rows) + distance_type = self._normalize_distance_type(distance_type) + self._verify_ivf_params(num_partitions) + + if accelerator is None: + dimension = self.dataset.schema.field(column[0]).type.list_size + ivf_centroids = indices.train_ivf_model( + self.dataset._ds, + column[0], + dimension, + num_partitions, + distance_type, + sample_rate, + max_iters, + ) + return IvfModel(ivf_centroids, distance_type) + else: + # Use accelerator to train ivf centroids + from .vector import train_ivf_centroids_on_accelerator + + ivf_centroids, _ = train_ivf_centroids_on_accelerator( + self.dataset, + column[0], + num_partitions, + distance_type, + accelerator, + sample_rate=sample_rate, + max_iters=max_iters, + ) + num_dims = ivf_centroids.shape[1] + ivf_centroids.shape = -1 + flat_centroids_array = pa.array(ivf_centroids) + centroids_array = pa.FixedSizeListArray.from_arrays( + flat_centroids_array, num_dims + ) + return IvfModel(centroids_array, distance_type) + + def train_pq( + self, + column, + ivf_model: IvfModel, + num_subvectors=None, + *, + sample_rate: int = 256, + max_iters: int = 50, + ) -> PqModel: + """ + Train a PQ model for a given column. + + This will run k-means clustering on each subvector to determine the centroids + that will be used to quantize the subvectors. This step runs against a + randomly chosen sample of the data. The sample size is typically quite small + and PQ training is relatively fast regardless of dataset scale. As a result, + accelerators are not needed here. + + Parameters + ---------- + + column: str + The vector column to quantize, must be a fixed size list of floats + or 1-dimensional fixed-shape tensor column. + ivf_model: IvfModel + The IVF model to use to partition the vectors into clusters. 
This is + needed because PQ is trained on residuals from the IVF model. + num_subvectors: int + The number of subvectors to divide the source vectors into. This must be + a divisor of the vector dimension. If not specified the default will be + the vector dimension divided by 16 if the dimension is divisible by 16, + otherwise the vector dimension divided by 8 if the dimension is divisible + by 8. + + Automatic calculation of num_subvectors will fail if the vector dimension + is not divisible by 16 or 8. In this case you must specify num_subvectors + manually (though any value you choose is likely to lead to poor performance) + sample_rate: int + This parameter is used in the same way as in the IVF model. + max_iters: int + This parameter is used in the same way as in the IVF model. + """ + column = self._normalize_column(column) + num_rows = self.dataset.count_rows() + dimension = self.dataset.schema.field(column[0]).type.list_size + self.dataset.schema.field(column[0]).type.list_size + num_subvectors = self._normalize_pq_params(num_subvectors, dimension) + self._verify_pq_sample_rate(num_rows, sample_rate) + distance_type = ivf_model.distance_type + pq_codebook = indices.train_pq_model( + self.dataset._ds, + column[0], + dimension, + num_subvectors, + distance_type, + sample_rate, + max_iters, + ivf_model.centroids, + ) + return PqModel(num_subvectors, pq_codebook) + + def _determine_num_partitions(self, num_partitions: Optional[int], num_rows: int): + if num_partitions is None: + return round(math.sqrt(num_rows)) + return num_partitions + + def _normalize_pq_params(self, num_subvectors: int, dimension: int): + if num_subvectors is None: + if dimension % 16 == 0: + return dimension // 16 + elif dimension % 8 == 0: + return dimension // 8 + else: + raise ValueError( + f"vector dimension {dimension} is not divisible by 16 or 8." + " PQ performance will be poor. Cowardly refusing to create" + " PQ model. Please specify num_subvectors manually." 
+ ) + if not isinstance(num_subvectors, int): + raise ValueError("num_subvectors must be an int") + if num_subvectors < 1: + raise ValueError("num_subvectors must be greater than 0") + if num_subvectors > dimension: + raise ValueError( + "num_subvectors must be less than or equal to the dimension of" + " the vectors" + ) + if dimension % num_subvectors != 0: + raise ValueError( + "dimension ({dimension}) must be divisible by num_subvectors" + " ({num_subvectors}) without remainder" + ) + return num_subvectors + + def _verify_base_sample_rate(self, sample_rate: int): + if not isinstance(sample_rate, int) or sample_rate < 2: + raise ValueError( + f"The sample_rate must be an int greater than 1, got {sample_rate}" + ) + + def _verify_pq_sample_rate(self, num_rows: int, sample_rate: int): + self._verify_base_sample_rate(sample_rate) + if 256 * sample_rate > num_rows: + raise ValueError( + "There are not enough rows in the dataset to create PQ" + f" codebook with a sample rate of {sample_rate}. {sample_rate * 256}" + f" rows needed and there are {num_rows}" + ) + + def _verify_ivf_sample_rate( + self, sample_rate: int, num_partitions: int, num_rows: int + ): + self._verify_base_sample_rate(sample_rate) + if num_partitions * sample_rate > num_rows: + raise ValueError( + "There are not enough rows in the dataset to create IVF centroids with" + f" {num_partitions} partitions and a sample rate of {sample_rate}." 
+ f" {sample_rate * num_partitions} rows needed and there are {num_rows}" + ) + + def _verify_ivf_params(self, num_partitions): + if num_partitions is None: + raise ValueError( + "num_partitions and num_sub_vectors are required for IVF_PQ" + ) + if isinstance(num_partitions, float): + warnings.warn("num_partitions is float, converting to int") + num_partitions = int(num_partitions) + elif not isinstance(num_partitions, int): + raise TypeError(f"num_partitions must be int, got {type(num_partitions)}") + + def _normalize_distance_type(self, distance_type): + if not isinstance(distance_type, str) or distance_type.lower() not in [ + "l2", + "cosine", + "euclidean", + "dot", + ]: + raise ValueError(f"Distance type {distance_type} not supported.") + return distance_type.lower() + + def _normalize_column(self, column): + # Only support building index for 1 column from the API aspect, however + # the internal implementation might support building multi-column index later. + if isinstance(column, str): + column = [column] + + # validate args + for c in column: + if c not in self.dataset.schema.names: + raise KeyError(f"{c} not found in schema") + field = self.dataset.schema.field(c) + if not ( + pa.types.is_fixed_size_list(field.type) + or ( + isinstance(field.type, pa.FixedShapeTensorType) + and len(field.type.shape) == 1 + ) + ): + raise TypeError( + f"Vector column {c} must be FixedSizeListArray " + f"1-dimensional FixedShapeTensorArray, got {field.type}" + ) + if not pa.types.is_floating(field.type.value_type): + raise TypeError( + f"Vector column {c} must have floating value type, " + f"got {field.type.value_type}" + ) + + return column diff --git a/python/python/lance/lance/__init__.pyi b/python/python/lance/lance/__init__.pyi index 85bb69003d..f66b6e7cc2 100644 --- a/python/python/lance/lance/__init__.pyi +++ b/python/python/lance/lance/__init__.pyi @@ -37,9 +37,9 @@ class LanceFileWriter: def __init__( self, path: str, - schema: pa.Schema, - data_cache_bytes: int, - 
keep_original_array: bool, + schema: Optional[pa.Schema], + data_cache_bytes: Optional[int], + keep_original_array: Optional[bool], ): ... def write_batch(self, batch: pa.RecordBatch) -> None: ... def finish(self) -> int: ... diff --git a/python/python/lance/lance/indices/__init__.pyi b/python/python/lance/lance/indices/__init__.pyi new file mode 100644 index 0000000000..d09f24dd75 --- /dev/null +++ b/python/python/lance/lance/indices/__init__.pyi @@ -0,0 +1,35 @@ +# Copyright (c) 2024. Lance Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pyarrow as pa + +def train_ivf_model( + dataset, + column: str, + dimension: int, + num_partitions: int, + distance_type: str, + sample_rate: int, + max_iters: int, +) -> pa.Array: ... +def train_pq_model( + dataset, + column: str, + dimension: int, + num_subvectors: int, + distance_type: str, + sample_rate: int, + max_iters: int, + ivf_model: pa.Array, +) -> pa.Array: ... 
diff --git a/python/python/lance/vector.py b/python/python/lance/vector.py index 4d994fcc95..dc7de7d09d 100644 --- a/python/python/lance/vector.py +++ b/python/python/lance/vector.py @@ -183,7 +183,7 @@ def train_ivf_centroids_on_accelerator( np.save(f, centroids) logging.info("Saved centroids to %s", f.name) - return centroids, compute_partitions(dataset, column, kmeans, batch_size=20480) + return centroids, kmeans def compute_partitions( diff --git a/python/python/tests/test_file.py b/python/python/tests/test_file.py index cbc0e149e7..e5efa09b1d 100644 --- a/python/python/tests/test_file.py +++ b/python/python/tests/test_file.py @@ -16,6 +16,32 @@ def test_file_writer(tmp_path): assert metadata.num_rows == 3 +def test_write_no_schema(tmp_path): + path = tmp_path / "foo.lance" + with LanceFileWriter(str(path)) as writer: + writer.write_batch(pa.table({"a": [1, 2, 3]})) + reader = LanceFileReader(str(path)) + assert reader.read_all().to_table() == pa.table({"a": [1, 2, 3]}) + + +def test_no_schema_no_data(tmp_path): + path = tmp_path / "foo.lance" + with pytest.raises( + ValueError, match="Schema is unknown and file cannot be created" + ): + with LanceFileWriter(str(path)) as _: + pass + + +def test_schema_only(tmp_path): + path = tmp_path / "foo.lance" + schema = pa.schema([pa.field("a", pa.int64())]) + with LanceFileWriter(str(path), schema=schema) as _: + pass + reader = LanceFileReader(str(path)) + assert reader.metadata().schema == schema + + def test_aborted_write(tmp_path): path = tmp_path / "foo.lance" schema = pa.schema([pa.field("a", pa.int64())]) diff --git a/python/python/tests/test_indices.py b/python/python/tests/test_indices.py new file mode 100644 index 0000000000..e04a83f79a --- /dev/null +++ b/python/python/tests/test_indices.py @@ -0,0 +1,98 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The Lance Authors +import lance +import numpy as np +import pyarrow as pa +import pytest +from lance.indices import 
IndicesBuilder, IvfModel, PqModel + + +def gen_dataset(tmpdir, datatype=np.float32): + vectors = np.random.randn(10000, 128).astype(datatype) + vectors.shape = -1 + vectors = pa.FixedSizeListArray.from_arrays(vectors, 128) + table = pa.Table.from_arrays([vectors], names=["vectors"]) + ds = lance.write_dataset(table, str(tmpdir / "dataset")) + + return ds + + +def test_ivf_centroids(tmpdir): + ds = gen_dataset(tmpdir) + + ivf = IndicesBuilder(ds).train_ivf("vectors", sample_rate=16) + + assert ivf.distance_type == "l2" + assert len(ivf.centroids) == 100 + + ivf.save(str(tmpdir / "ivf")) + reloaded = IvfModel.load(str(tmpdir / "ivf")) + assert reloaded.distance_type == "l2" + assert ivf.centroids == reloaded.centroids + + +@pytest.mark.cuda +def test_ivf_centroids_cuda(tmpdir): + ds = gen_dataset(tmpdir) + ivf = IndicesBuilder(ds).train_ivf("vectors", sample_rate=16, accelerator="cuda") + + assert ivf.distance_type == "l2" + assert len(ivf.centroids) == 100 + + +def test_ivf_centroids_column_type(tmpdir): + def check(column_type, typename): + ds = gen_dataset(tmpdir / typename, column_type) + ivf = IndicesBuilder(ds).train_ivf("vectors", sample_rate=16) + assert len(ivf.centroids) == 100 + ivf.save(str(tmpdir / f"ivf_{typename}")) + reloaded = IvfModel.load(str(tmpdir / f"ivf_{typename}")) + assert ivf.centroids == reloaded.centroids + + check(np.float16, "f16") + check(np.float32, "f32") + check(np.float64, "f64") + + +def test_ivf_centroids_distance_type(tmpdir): + ds = gen_dataset(tmpdir) + + def check(distance_type): + ivf = IndicesBuilder(ds).train_ivf( + "vectors", sample_rate=16, distance_type=distance_type + ) + assert ivf.distance_type == distance_type + ivf.save(str(tmpdir / "ivf")) + reloaded = IvfModel.load(str(tmpdir / "ivf")) + assert reloaded.distance_type == distance_type + + check("l2") + check("cosine") + check("dot") + + +def test_num_partitions(tmpdir): + ds = gen_dataset(tmpdir) + + ivf = IndicesBuilder(ds).train_ivf("vectors", sample_rate=16, 
num_partitions=10) + assert ivf.num_partitions == 10 + + +@pytest.fixture +def ds_with_ivf(tmpdir): + ds = gen_dataset(tmpdir) + ivf = IndicesBuilder(ds).train_ivf("vectors", sample_rate=16) + return ds, ivf + + +def test_gen_pq(tmpdir, ds_with_ivf): + ds, ivf = ds_with_ivf + + pq = IndicesBuilder(ds).train_pq("vectors", ivf, sample_rate=16) + assert pq.dimension == 128 + assert pq.num_subvectors == 8 + + pq.save(str(tmpdir / "pq")) + reloaded = PqModel.load(str(tmpdir / "pq")) + assert pq.dimension == reloaded.dimension + assert pq.codebook == reloaded.codebook diff --git a/python/python/tests/test_tf.py b/python/python/tests/test_tf.py index ae86702b33..878a8b2942 100644 --- a/python/python/tests/test_tf.py +++ b/python/python/tests/test_tf.py @@ -4,12 +4,16 @@ import os import warnings +import lance import ml_dtypes import numpy as np import pandas as pd import pyarrow as pa import pytest from lance.arrow import BFloat16Type, ImageArray, bfloat16_array +from lance.fragment import LanceFragment + +pytest.skip("Skip tensorflow tests", allow_module_level=True) try: with warnings.catch_warnings(): @@ -22,15 +26,13 @@ allow_module_level=True, ) -import lance -from lance.fragment import LanceFragment -from lance.tf.data import ( +from lance.tf.data import ( # noqa: E402 from_lance, from_lance_batches, lance_fragments, lance_take_batches, ) -from lance.tf.tfrecord import infer_tfrecord_schema, read_tfrecord +from lance.tf.tfrecord import infer_tfrecord_schema, read_tfrecord # noqa: E402 @pytest.fixture diff --git a/python/python/tests/test_vector_index.py b/python/python/tests/test_vector_index.py index 26e0a4ab2d..8715bbc2fc 100644 --- a/python/python/tests/test_vector_index.py +++ b/python/python/tests/test_vector_index.py @@ -605,7 +605,7 @@ def query_index(ds, ntimes, q=None): indexed_dataset = lance.dataset(tmp_path / "test", index_cache_size=1) # query using the same vector, we should get a very high hit rate - query_index(indexed_dataset, 100, 
q=rng.standard_normal(16)) + query_index(indexed_dataset, 200, q=rng.standard_normal(16)) assert indexed_dataset._ds.index_cache_hit_rate() > 0.99 last_hit_rate = indexed_dataset._ds.index_cache_hit_rate() diff --git a/python/src/dataset.rs b/python/src/dataset.rs index ab3b0dca9d..e2157126f9 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -37,6 +37,7 @@ use lance::dataset::{ WriteParams, }; use lance::dataset::{BatchInfo, BatchUDF, NewColumnTransform, UDFCheckpointStore}; +use lance::index::scalar::ScalarIndexType; use lance::index::{scalar::ScalarIndexParams, vector::VectorIndexParams}; use lance_arrow::as_fixed_size_list_array; use lance_core::datatypes::Schema; @@ -432,6 +433,7 @@ impl Dataset { with_row_id: Option, use_stats: Option, substrait_filter: Option>, + fast_search: Option, ) -> PyResult { let mut scanner: LanceScanner = self_.ds.scan(); match (columns, columns_with_transform) { @@ -496,6 +498,10 @@ impl Dataset { scanner.use_stats(use_stats); } + if let Some(true) = fast_search { + scanner.fast_search(); + } + if let Some(fragments) = fragments { let fragments = fragments .into_iter() @@ -573,6 +579,16 @@ impl Dataset { true }; + let ef: Option = if let Some(ef) = nearest.get_item("ef")? { + if ef.is_none() { + None + } else { + PyAny::downcast::(ef)?.extract()? 
+ } + } else { + None + }; + scanner .nearest(column.as_str(), &q, k) .map(|s| { @@ -583,6 +599,9 @@ impl Dataset { if let Some(m) = metric_type { s = s.distance_metric(m); } + if let Some(ef) = ef { + s = s.ef(ef); + } s.use_index(use_index); s }) @@ -910,11 +929,12 @@ impl Dataset { index_type: &str, name: Option, replace: Option, + storage_options: Option>, kwargs: Option<&PyDict>, ) -> PyResult<()> { let index_type = index_type.to_uppercase(); let idx_type = match index_type.as_str() { - "BTREE" => IndexType::Scalar, + "BTREE" | "BITMAP" => IndexType::Scalar, "IVF_PQ" | "IVF_HNSW_PQ" | "IVF_HNSW_SQ" => IndexType::Vector, _ => { return Err(PyValueError::new_err(format!( @@ -926,12 +946,17 @@ impl Dataset { // Only VectorParams are supported. let params: Box = if index_type == "BTREE" { Box::::default() + } else if index_type == "BITMAP" { + Box::new(ScalarIndexParams { + // Temporary workaround until we add support for auto-detection of scalar index type + force_index_type: Some(ScalarIndexType::Bitmap), + }) } else { let column_type = match self.ds.schema().field(columns[0]) { Some(f) => f.data_type().clone(), None => return Err(PyValueError::new_err("Column not found in dataset schema.")), }; - prepare_vector_index_params(&index_type, &column_type, kwargs)? + prepare_vector_index_params(&index_type, &column_type, storage_options, kwargs)? 
}; let replace = replace.unwrap_or(true); @@ -994,7 +1019,13 @@ impl Dataset { operation: Operation, read_version: Option, commit_lock: Option<&PyAny>, + storage_options: Option>, ) -> PyResult { + let object_store_params = storage_options.map(|storage_options| ObjectStoreParams { + storage_options: Some(storage_options), + ..Default::default() + }); + let commit_handler = commit_lock.map(|commit_lock| { Arc::new(PyCommitLock::new(commit_lock.to_object(commit_lock.py()))) as Arc @@ -1008,8 +1039,14 @@ impl Dataset { }; let manifest = dataset.as_ref().map(|ds| ds.manifest()); validate_operation(manifest, &operation.0)?; - LanceDataset::commit(dataset_uri, operation.0, read_version, None, commit_handler) - .await + LanceDataset::commit( + dataset_uri, + operation.0, + read_version, + object_store_params, + commit_handler, + ) + .await })? .map_err(|e| PyIOError::new_err(e.to_string()))?; Ok(Self { @@ -1208,6 +1245,7 @@ pub fn get_write_params(options: &PyDict) -> PyResult> { fn prepare_vector_index_params( index_type: &str, column_type: &DataType, + storage_options: Option>, kwargs: Option<&PyDict>, ) -> PyResult> { let mut m_type = MetricType::L2; @@ -1236,6 +1274,10 @@ fn prepare_vector_index_params( ivf_params.num_partitions = PyAny::downcast::(n)?.extract()? }; + if let Some(n) = kwargs.get_item("shuffle_partition_concurrency")? { + ivf_params.shuffle_partition_concurrency = PyAny::downcast::(n)?.extract()? + }; + if let Some(c) = kwargs.get_item("ivf_centroids")? { let batch = RecordBatch::from_pyarrow(c)?; if "_ivf_centroids" != batch.schema().field(0).name() { @@ -1266,6 +1308,10 @@ fn prepare_vector_index_params( ivf_params.precomputed_partitons_file = Some(f.to_string()); }; + if let Some(storage_options) = storage_options { + ivf_params.storage_options = Some(storage_options); + } + match ( kwargs.get_item("precomputed_shuffle_buffers")?, kwargs.get_item("precomputed_shuffle_buffers_path")? 
diff --git a/python/src/file.rs b/python/src/file.rs index b7a91a346a..71cbc98e4e 100644 --- a/python/src/file.rs +++ b/python/src/file.rs @@ -172,24 +172,23 @@ pub struct LanceFileWriter { impl LanceFileWriter { async fn open( uri_or_path: String, - schema: PyArrowType, + schema: Option>, data_cache_bytes: Option, keep_original_array: Option, ) -> PyResult { let (object_store, path) = object_store_from_uri_or_path(uri_or_path).await?; let object_writer = object_store.create(&path).await.infer_error()?; - let lance_schema = lance_core::datatypes::Schema::try_from(&schema.0).infer_error()?; - let inner = FileWriter::try_new( - object_writer, - path.to_string(), - lance_schema, - FileWriterOptions { - data_cache_bytes, - keep_original_array, - ..Default::default() - }, - ) - .infer_error()?; + let options = FileWriterOptions { + data_cache_bytes, + keep_original_array, + ..Default::default() + }; + let inner = if let Some(schema) = schema { + let lance_schema = lance_core::datatypes::Schema::try_from(&schema.0).infer_error()?; + FileWriter::try_new(object_writer, lance_schema, options).infer_error() + } else { + Ok(FileWriter::new_lazy(object_writer, options)) + }?; Ok(Self { inner: Box::new(inner), }) @@ -201,7 +200,7 @@ impl LanceFileWriter { #[new] pub fn new( path: String, - schema: PyArrowType, + schema: Option>, data_cache_bytes: Option, keep_original_array: Option, ) -> PyResult { diff --git a/python/src/indices.rs b/python/src/indices.rs new file mode 100644 index 0000000000..c6047d8e4c --- /dev/null +++ b/python/src/indices.rs @@ -0,0 +1,148 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use arrow::pyarrow::{PyArrowType, ToPyArrow}; +use arrow_array::{Array, FixedSizeListArray}; +use arrow_data::ArrayData; +use lance_index::vector::{ + ivf::{storage::IvfModel, IvfBuildParams}, + pq::PQBuildParams, +}; +use lance_linalg::distance::DistanceType; +use pyo3::{pyfunction, types::PyModule, wrap_pyfunction, 
PyObject, PyResult, Python}; + +use crate::{dataset::Dataset, error::PythonErrorExt, RT}; + +async fn do_train_ivf_model( + dataset: &Dataset, + column: &str, + dimension: usize, + num_partitions: u32, + distance_type: &str, + sample_rate: u32, + max_iters: u32, +) -> PyResult { + // We verify distance_type earlier so can unwrap here + let distance_type = DistanceType::try_from(distance_type).unwrap(); + let params = IvfBuildParams { + max_iters: max_iters as usize, + sample_rate: sample_rate as usize, + num_partitions: num_partitions as usize, + ..Default::default() + }; + let ivf_model = lance::index::vector::ivf::build_ivf_model( + dataset.ds.as_ref(), + column, + dimension, + distance_type, + ¶ms, + ) + .await + .infer_error()?; + let centroids = ivf_model.centroids.unwrap(); + Ok(centroids.into_data()) +} + +#[pyfunction] +#[allow(clippy::too_many_arguments)] +fn train_ivf_model( + py: Python<'_>, + dataset: &Dataset, + column: &str, + dimension: usize, + num_partitions: u32, + distance_type: &str, + sample_rate: u32, + max_iters: u32, +) -> PyResult { + let centroids = RT.block_on( + Some(py), + do_train_ivf_model( + dataset, + column, + dimension, + num_partitions, + distance_type, + sample_rate, + max_iters, + ), + )??; + centroids.to_pyarrow(py) +} + +#[allow(clippy::too_many_arguments)] +async fn do_train_pq_model( + dataset: &Dataset, + column: &str, + dimension: usize, + num_subvectors: u32, + distance_type: &str, + sample_rate: u32, + max_iters: u32, + ivf_model: IvfModel, +) -> PyResult { + // We verify distance_type earlier so can unwrap here + let distance_type = DistanceType::try_from(distance_type).unwrap(); + let params = PQBuildParams { + num_sub_vectors: num_subvectors as usize, + num_bits: 8, + max_iters: max_iters as usize, + sample_rate: sample_rate as usize, + ..Default::default() + }; + let pq_model = lance::index::vector::pq::build_pq_model( + dataset.ds.as_ref(), + column, + dimension, + distance_type, + ¶ms, + Some(&ivf_model), + ) + 
.await + .infer_error()?; + Ok(pq_model.codebook.into_data()) +} + +#[pyfunction] +#[allow(clippy::too_many_arguments)] +fn train_pq_model( + py: Python<'_>, + dataset: &Dataset, + column: &str, + dimension: usize, + num_subvectors: u32, + distance_type: &str, + sample_rate: u32, + max_iters: u32, + ivf_centroids: PyArrowType, +) -> PyResult { + let ivf_centroids = ivf_centroids.0; + let ivf_centroids = FixedSizeListArray::from(ivf_centroids); + let ivf_model = IvfModel { + centroids: Some(ivf_centroids), + offsets: vec![], + lengths: vec![], + }; + let codebook = RT.block_on( + Some(py), + do_train_pq_model( + dataset, + column, + dimension, + num_subvectors, + distance_type, + sample_rate, + max_iters, + ivf_model, + ), + )??; + codebook.to_pyarrow(py) +} + +pub fn register_indices(py: Python, m: &PyModule) -> PyResult<()> { + let indices = PyModule::new(py, "indices")?; + indices.add_wrapped(wrap_pyfunction!(train_ivf_model))?; + indices.add_wrapped(wrap_pyfunction!(train_pq_model))?; + m.add_submodule(indices)?; + Ok(()) +} diff --git a/python/src/lib.rs b/python/src/lib.rs index b90740cf5f..f3589d880a 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -57,6 +57,7 @@ pub(crate) mod error; pub(crate) mod executor; pub(crate) mod file; pub(crate) mod fragment; +pub(crate) mod indices; pub(crate) mod reader; pub(crate) mod scanner; pub(crate) mod schema; @@ -74,6 +75,7 @@ pub use dataset::write_dataset; pub use dataset::{Dataset, Operation}; pub use fragment::FragmentMetadata; use fragment::{DataFile, FileFragment}; +pub use indices::register_indices; pub use reader::LanceReader; pub use scanner::Scanner; @@ -148,6 +150,7 @@ fn lance(py: Python, m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(debug::list_transactions))?; m.add("__version__", env!("CARGO_PKG_VERSION"))?; register_datagen(py, m)?; + register_indices(py, m)?; Ok(()) } diff --git a/rust/lance-core/Cargo.toml b/rust/lance-core/Cargo.toml index 2f53a0f480..0e0f219bdb 100644 --- 
a/rust/lance-core/Cargo.toml +++ b/rust/lance-core/Cargo.toml @@ -40,6 +40,7 @@ tokio-stream.workspace = true tokio-util.workspace = true tracing.workspace = true url.workspace = true +log.workspace = true # This is used to detect CPU features at runtime. # See src/utils/cpu.rs diff --git a/rust/lance-core/src/utils/tokio.rs b/rust/lance-core/src/utils/tokio.rs index a7a1af1ac2..f2d46ec281 100644 --- a/rust/lance-core/src/utils/tokio.rs +++ b/rust/lance-core/src/utils/tokio.rs @@ -1,16 +1,40 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors +use std::time::Duration; + use crate::Result; use futures::{Future, FutureExt}; use tokio::runtime::{Builder, Runtime}; use tracing::Span; +fn get_num_compute_intensive_cpus() -> usize { + let cpus = num_cpus::get(); + + if cpus <= *IO_CORE_RESERVATION { + // on systems with only 1 CPU there is no point in warning + if cpus > 1 { + log::warn!( + "Number of CPUs is less than or equal to the number of IO core reservations. \ + This is not a supported configuration. using 1 CPU for compute intensive tasks." + ); + } + return 1; + } + + num_cpus::get() - *IO_CORE_RESERVATION +} + lazy_static::lazy_static! 
{ + pub static ref IO_CORE_RESERVATION: usize = std::env::var("LANCE_IO_CORE_RESERVATION").unwrap_or("2".to_string()).parse().unwrap(); + pub static ref CPU_RUNTIME: Runtime = Builder::new_multi_thread() .thread_name("lance-cpu") - .max_blocking_threads(num_cpus::get()) + .max_blocking_threads(get_num_compute_intensive_cpus()) + .worker_threads(1) + // keep the thread alive "forever" + .thread_keep_alive(Duration::from_secs(u64::MAX)) .build() .unwrap(); } diff --git a/rust/lance-encoding/Cargo.toml b/rust/lance-encoding/Cargo.toml index d657782877..ec6de63ba0 100644 --- a/rust/lance-encoding/Cargo.toml +++ b/rust/lance-encoding/Cargo.toml @@ -34,6 +34,7 @@ tracing.workspace = true zstd.workspace = true rand = "0.8.4" rand_xoshiro = "0.6.0" +hyperloglogplus.workspace = true [dev-dependencies] rand.workspace = true diff --git a/rust/lance-encoding/benches/decoder.rs b/rust/lance-encoding/benches/decoder.rs index 0920c5a60e..32cd0e0f9b 100644 --- a/rust/lance-encoding/benches/decoder.rs +++ b/rust/lance-encoding/benches/decoder.rs @@ -2,13 +2,17 @@ // SPDX-FileCopyrightText: Copyright The Lance Authors use std::sync::Arc; -use arrow_schema::{DataType, Field, TimeUnit}; +use arrow_array::{RecordBatch, UInt32Array}; +use arrow_schema::{DataType, Field, Schema, TimeUnit}; +use arrow_select::take::take; use criterion::{criterion_group, criterion_main, Criterion}; use lance_encoding::{ decoder::{DecoderMiddlewareChain, FilterExpression}, encoder::{encode_batch, CoreFieldEncodingStrategy}, }; +use rand::Rng; + const PRIMITIVE_TYPES: &[DataType] = &[ DataType::Date32, DataType::Date64, @@ -126,12 +130,67 @@ fn bench_decode_fsl(c: &mut Criterion) { } } +fn bench_decode_str_with_dict_encoding(c: &mut Criterion) { + let rt = tokio::runtime::Runtime::new().unwrap(); + let mut group = c.benchmark_group("decode_primitive"); + let data_type = DataType::Utf8; + // generate string column with 20 rows + let string_data = lance_datagen::gen() + 
.anon_col(lance_datagen::array::rand_type(&DataType::Utf8)) + .into_batch_rows(lance_datagen::RowCount::from(20)) + .unwrap(); + + let string_array = string_data.column(0); + + // generate random int column with 100000 rows + let mut rng = rand::thread_rng(); + let integer_arr: Vec = (0..100_000).map(|_| rng.gen_range(0..20)).collect(); + let integer_array = UInt32Array::from(integer_arr); + + let mapped_strings = take(string_array, &integer_array, None).unwrap(); + + let schema = Arc::new(Schema::new(vec![Field::new( + "string", + DataType::Utf8, + false, + )])); + + let data = RecordBatch::try_new(schema, vec![Arc::new(mapped_strings)]).unwrap(); + + let lance_schema = + Arc::new(lance_core::datatypes::Schema::try_from(data.schema().as_ref()).unwrap()); + let input_bytes = data.get_array_memory_size(); + group.throughput(criterion::Throughput::Bytes(input_bytes as u64)); + let encoding_strategy = CoreFieldEncodingStrategy::default(); + let encoded = rt + .block_on(encode_batch( + &data, + lance_schema, + &encoding_strategy, + 1024 * 1024, + )) + .unwrap(); + let func_name = format!("{:?}", data_type).to_lowercase(); + group.bench_function(func_name, |b| { + b.iter(|| { + let batch = rt + .block_on(lance_encoding::decoder::decode_batch( + &encoded, + &FilterExpression::no_filter(), + &DecoderMiddlewareChain::default(), + )) + .unwrap(); + assert_eq!(data.num_rows(), batch.num_rows()); + }) + }); +} + #[cfg(target_os = "linux")] criterion_group!( name=benches; config = Criterion::default().significance_level(0.1).sample_size(10) .with_profiler(pprof::criterion::PProfProfiler::new(100, pprof::criterion::Output::Flamegraph(None))); - targets = bench_decode, bench_decode_fsl); + targets = bench_decode, bench_decode_fsl, bench_decode_str_with_dict_encoding); // Non-linux version does not support pprof. 
#[cfg(not(target_os = "linux"))] diff --git a/rust/lance-encoding/src/encoder.rs b/rust/lance-encoding/src/encoder.rs index 70b919bd9c..26f051aef0 100644 --- a/rust/lance-encoding/src/encoder.rs +++ b/rust/lance-encoding/src/encoder.rs @@ -1,8 +1,8 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors -use std::{collections::HashMap, sync::Arc}; +use std::{collections::HashMap, env, sync::Arc}; -use arrow_array::{ArrayRef, RecordBatch}; +use arrow_array::{Array, ArrayRef, RecordBatch}; use arrow_buffer::Buffer; use arrow_schema::DataType; use bytes::{Bytes, BytesMut}; @@ -19,13 +19,16 @@ use crate::{ list::ListFieldEncoder, primitive::PrimitiveFieldEncoder, r#struct::StructFieldEncoder, }, physical::{ - basic::BasicEncoder, binary::BinaryEncoder, fixed_size_list::FslEncoder, - value::ValueEncoder, + basic::BasicEncoder, binary::BinaryEncoder, dictionary::DictionaryEncoder, + fixed_size_list::FslEncoder, value::ValueEncoder, }, }, format::pb, }; +use hyperloglogplus::{HyperLogLog, HyperLogLogPlus}; +use std::collections::hash_map::RandomState; + /// An encoded buffer pub struct EncodedBuffer { /// Buffers that make up the encoded buffer @@ -232,27 +235,34 @@ impl CoreArrayEncodingStrategy { fn array_encoder_from_type( data_type: &DataType, data_size: u64, + use_dict_encoding: bool, ) -> Result> { match data_type { DataType::FixedSizeList(inner, dimension) => { Ok(Box::new(BasicEncoder::new(Box::new(FslEncoder::new( - Self::array_encoder_from_type(inner.data_type(), data_size)?, + Self::array_encoder_from_type(inner.data_type(), data_size, use_dict_encoding)?, *dimension as u32, ))))) } DataType::Utf8 | DataType::LargeUtf8 | DataType::Binary | DataType::LargeBinary => { - let bin_indices_encoder = - Self::array_encoder_from_type(&DataType::UInt64, data_size)?; - let bin_bytes_encoder = Self::array_encoder_from_type(&DataType::UInt8, data_size)?; - - let bin_encoder = - Box::new(BinaryEncoder::new(bin_indices_encoder, 
bin_bytes_encoder)); - - if matches!(data_type, DataType::Utf8 | DataType::Binary) && data_size > 1024 * 1024 - { - Ok(Box::new(FsstArrayEncoder::new(bin_encoder))) + if use_dict_encoding { + let dict_indices_encoder = + Self::array_encoder_from_type(&DataType::UInt8, false)?; + let dict_items_encoder = Self::array_encoder_from_type(&DataType::Utf8, data_size, false)?; + + Ok(Box::new(DictionaryEncoder::new( + dict_indices_encoder, + dict_items_encoder, + ))) } else { - Ok(bin_encoder) + let bin_indices_encoder = + Self::array_encoder_from_type(&DataType::UInt64, data_size, false)?; + let bin_bytes_encoder = Self::array_encoder_from_type(&DataType::UInt8, data_size, false)?; + + Ok(Box::new(BinaryEncoder::new( + bin_indices_encoder, + bin_bytes_encoder, + ))) } } _ => Ok(Box::new(BasicEncoder::new(Box::new( @@ -262,13 +272,54 @@ impl CoreArrayEncodingStrategy { } } +fn get_dict_encoding_threshold() -> u64 { + env::var("LANCE_DICT_ENCODING_THRESHOLD") + .ok() + .and_then(|val| val.parse().ok()) + .unwrap_or(100) +} + +// check whether we want to use dictionary encoding or not +// by applying a threshold on cardinality +// returns true if cardinality < threshold but false if the total number of rows is less than the threshold +// The choice to use 100 is just a heuristic for now +// hyperloglog is used for cardinality estimation +// error rate = 1.04 / sqrt(2^p), where p is the precision +// and error rate is 1.04 / sqrt(2^12) = 1.56% +fn check_dict_encoding(arrays: &[ArrayRef], threshold: u64) -> bool { + let num_total_rows = arrays.iter().map(|arr| arr.len()).sum::(); + if num_total_rows < threshold as usize { + return false; + } + const PRECISION: u8 = 12; + + let mut hll: HyperLogLogPlus = + HyperLogLogPlus::new(PRECISION, RandomState::new()).unwrap(); + + for arr in arrays { + let string_array = arrow_array::cast::as_string_array(arr); + for value in string_array.iter().flatten() { + hll.insert(value); + let estimated_cardinality = hll.count() as u64; + if 
estimated_cardinality >= threshold { + return false; + } + } + } + + true +} + impl ArrayEncodingStrategy for CoreArrayEncodingStrategy { fn create_array_encoder(&self, arrays: &[ArrayRef]) -> Result> { let data_size = arrays .iter() .map(|arr| arr.get_buffer_memory_size() as u64) .sum::(); - Self::array_encoder_from_type(arrays[0].data_type(), data_size) + let data_type = arrays[0].data_type(); + let use_dict_encoding = data_type == &DataType::Utf8 + && check_dict_encoding(arrays, get_dict_encoding_threshold()); + Self::array_encoder_from_type(data_type, data_size, use_dict_encoding) } } @@ -568,3 +619,51 @@ pub async fn encode_batch( num_rows: batch.num_rows() as u64, }) } + +#[cfg(test)] +pub mod tests { + use arrow_array::{ArrayRef, StringArray}; + use std::sync::Arc; + + use super::check_dict_encoding; + + fn is_dict_encoding_applicable(arr: Vec>, threshold: u64) -> bool { + let arr = StringArray::from(arr); + let arr = Arc::new(arr) as ArrayRef; + check_dict_encoding(&[arr], threshold) + } + + #[test] + fn test_dict_encoding_should_be_applied_if_cardinality_less_than_threshold() { + assert!(is_dict_encoding_applicable( + vec![Some("a"), Some("b"), Some("a"), Some("b")], + 3, + )); + } + + #[test] + fn test_dict_encoding_should_not_be_applied_if_cardinality_larger_than_threshold() { + assert!(!is_dict_encoding_applicable( + vec![Some("a"), Some("b"), Some("c"), Some("d")], + 3, + )); + } + + #[test] + fn test_dict_encoding_should_not_be_applied_if_cardinality_equal_to_threshold() { + assert!(!is_dict_encoding_applicable( + vec![Some("a"), Some("b"), Some("c"), Some("a")], + 3, + )); + } + + #[test] + fn test_dict_encoding_should_not_be_applied_for_empty_arrays() { + assert!(!is_dict_encoding_applicable(vec![], 3)); + } + + #[test] + fn test_dict_encoding_should_not_be_applied_for_smaller_than_threshold_arrays() { + assert!(!is_dict_encoding_applicable(vec![Some("a"), Some("a")], 3)); + } +} diff --git a/rust/lance-encoding/src/encodings.rs 
b/rust/lance-encoding/src/encodings.rs index 22fc5ac875..808cf421b8 100644 --- a/rust/lance-encoding/src/encodings.rs +++ b/rust/lance-encoding/src/encodings.rs @@ -3,3 +3,4 @@ pub mod logical; pub mod physical; +pub mod utils; diff --git a/rust/lance-encoding/src/encodings/logical/primitive.rs b/rust/lance-encoding/src/encodings/logical/primitive.rs index 44dab86ffd..2c81908651 100644 --- a/rust/lance-encoding/src/encodings/logical/primitive.rs +++ b/rust/lance-encoding/src/encodings/logical/primitive.rs @@ -3,30 +3,13 @@ use std::{fmt::Debug, ops::Range, sync::Arc}; -use arrow_array::{ - new_null_array, - types::{ - ArrowPrimitiveType, ByteArrayType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, - DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType, - DurationSecondType, Float16Type, Float32Type, Float64Type, GenericBinaryType, - GenericStringType, Int16Type, Int32Type, Int64Type, Int8Type, IntervalDayTimeType, - IntervalMonthDayNanoType, IntervalYearMonthType, Time32MillisecondType, Time32SecondType, - Time64MicrosecondType, Time64NanosecondType, TimestampMicrosecondType, - TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, - UInt32Type, UInt64Type, UInt8Type, - }, - ArrayRef, BooleanArray, FixedSizeBinaryArray, FixedSizeListArray, GenericByteArray, - PrimitiveArray, -}; -use arrow_buffer::{BooleanBuffer, Buffer, NullBuffer, OffsetBuffer, ScalarBuffer}; -use arrow_schema::{DataType, IntervalUnit, TimeUnit}; -use bytes::BytesMut; +use arrow_array::{new_null_array, ArrayRef}; +use arrow_schema::DataType; use futures::{future::BoxFuture, FutureExt}; use lance_arrow::deepcopy::deep_copy_array; use log::{debug, trace}; -use snafu::{location, Location}; -use lance_core::{Error, Result}; +use lance_core::Result; use crate::{ decoder::{ @@ -38,6 +21,8 @@ use crate::{ encodings::physical::{decoder_from_array_encoding, ColumnBuffers, PageBuffers}, }; +use crate::encodings::utils::primitive_array_from_buffers; 
+ #[derive(Debug)] struct PrimitivePage { scheduler: Box, @@ -268,276 +253,8 @@ impl DecodeArrayTask for PrimitiveFieldDecodeTask { return Ok(new_null_array(&self.data_type, self.rows_to_take as usize)); } - // Convert the two buffers into an Arrow array - Self::primitive_array_from_buffers(&self.data_type, bufs, self.rows_to_take) - } -} - -impl PrimitiveFieldDecodeTask { - // TODO: Does this capability exist upstream somewhere? I couldn't find - // it from a simple scan but it seems the ability to convert two buffers - // into a primitive array is pretty fundamental. - fn new_primitive_array( - buffers: Vec, - num_rows: u64, - data_type: &DataType, - ) -> ArrayRef { - let mut buffer_iter = buffers.into_iter(); - let null_buffer = buffer_iter.next().unwrap(); - let null_buffer = if null_buffer.is_empty() { - None - } else { - let null_buffer = null_buffer.freeze().into(); - Some(NullBuffer::new(BooleanBuffer::new( - Buffer::from_bytes(null_buffer), - 0, - num_rows as usize, - ))) - }; - - let data_buffer = buffer_iter.next().unwrap().freeze(); - let data_buffer = Buffer::from_bytes(data_buffer.into()); - let data_buffer = ScalarBuffer::::new(data_buffer, 0, num_rows as usize); - - // The with_data_type is needed here to recover the parameters for types like Decimal/Timestamp - Arc::new( - PrimitiveArray::::new(data_buffer, null_buffer).with_data_type(data_type.clone()), - ) - } - - fn new_generic_byte_array(buffers: Vec, num_rows: u64) -> ArrayRef { - // iterate over buffers to get offsets and then bytes - let mut buffer_iter = buffers.into_iter(); - - let null_buffer = buffer_iter.next().unwrap(); - let null_buffer = if null_buffer.is_empty() { - None - } else { - let null_buffer = null_buffer.freeze().into(); - Some(NullBuffer::new(BooleanBuffer::new( - Buffer::from_bytes(null_buffer), - 0, - num_rows as usize, - ))) - }; - - let indices_bytes = buffer_iter.next().unwrap().freeze(); - let indices_buffer = Buffer::from_bytes(indices_bytes.into()); - let 
indices_buffer = - ScalarBuffer::::new(indices_buffer, 0, num_rows as usize + 1); - - let offsets = OffsetBuffer::new(indices_buffer.clone()); - - // TODO - add NULL support - // Decoding the bytes creates 2 buffers, the first one is empty due to nulls. - buffer_iter.next().unwrap(); - - let bytes_buffer = buffer_iter.next().unwrap().freeze(); - let bytes_buffer = Buffer::from_bytes(bytes_buffer.into()); - let bytes_buffer_len = bytes_buffer.len(); - let bytes_buffer = ScalarBuffer::::new(bytes_buffer, 0, bytes_buffer_len); - - let bytes_array = Arc::new( - PrimitiveArray::::new(bytes_buffer, None).with_data_type(DataType::UInt8), - ); - - Arc::new(GenericByteArray::::new( - offsets, - bytes_array.values().into(), - null_buffer, - )) - } - - fn bytes_to_validity(bytes: BytesMut, num_rows: u64) -> Option { - if bytes.is_empty() { - None - } else { - let null_buffer = bytes.freeze().into(); - Some(NullBuffer::new(BooleanBuffer::new( - Buffer::from_bytes(null_buffer), - 0, - num_rows as usize, - ))) - } - } - - fn primitive_array_from_buffers( - data_type: &DataType, - buffers: Vec, - num_rows: u64, - ) -> Result { - match data_type { - DataType::Boolean => { - let mut buffer_iter = buffers.into_iter(); - let null_buffer = buffer_iter.next().unwrap(); - let null_buffer = Self::bytes_to_validity(null_buffer, num_rows); - - let data_buffer = buffer_iter.next().unwrap().freeze(); - let data_buffer = Buffer::from(data_buffer); - let data_buffer = BooleanBuffer::new(data_buffer, 0, num_rows as usize); - - Ok(Arc::new(BooleanArray::new(data_buffer, null_buffer))) - } - DataType::Date32 => Ok(Self::new_primitive_array::( - buffers, num_rows, data_type, - )), - DataType::Date64 => Ok(Self::new_primitive_array::( - buffers, num_rows, data_type, - )), - DataType::Decimal128(_, _) => Ok(Self::new_primitive_array::( - buffers, num_rows, data_type, - )), - DataType::Decimal256(_, _) => Ok(Self::new_primitive_array::( - buffers, num_rows, data_type, - )), - 
DataType::Duration(units) => Ok(match units { - TimeUnit::Second => { - Self::new_primitive_array::(buffers, num_rows, data_type) - } - TimeUnit::Microsecond => Self::new_primitive_array::( - buffers, num_rows, data_type, - ), - TimeUnit::Millisecond => Self::new_primitive_array::( - buffers, num_rows, data_type, - ), - TimeUnit::Nanosecond => Self::new_primitive_array::( - buffers, num_rows, data_type, - ), - }), - DataType::Float16 => Ok(Self::new_primitive_array::( - buffers, num_rows, data_type, - )), - DataType::Float32 => Ok(Self::new_primitive_array::( - buffers, num_rows, data_type, - )), - DataType::Float64 => Ok(Self::new_primitive_array::( - buffers, num_rows, data_type, - )), - DataType::Int16 => Ok(Self::new_primitive_array::( - buffers, num_rows, data_type, - )), - DataType::Int32 => Ok(Self::new_primitive_array::( - buffers, num_rows, data_type, - )), - DataType::Int64 => Ok(Self::new_primitive_array::( - buffers, num_rows, data_type, - )), - DataType::Int8 => Ok(Self::new_primitive_array::( - buffers, num_rows, data_type, - )), - DataType::Interval(unit) => Ok(match unit { - IntervalUnit::DayTime => { - Self::new_primitive_array::(buffers, num_rows, data_type) - } - IntervalUnit::MonthDayNano => { - Self::new_primitive_array::( - buffers, num_rows, data_type, - ) - } - IntervalUnit::YearMonth => { - Self::new_primitive_array::(buffers, num_rows, data_type) - } - }), - DataType::Null => Ok(new_null_array(data_type, num_rows as usize)), - DataType::Time32(unit) => match unit { - TimeUnit::Millisecond => Ok(Self::new_primitive_array::( - buffers, num_rows, data_type, - )), - TimeUnit::Second => Ok(Self::new_primitive_array::( - buffers, num_rows, data_type, - )), - _ => Err(Error::io( - format!("invalid time unit {:?} for 32-bit time type", unit), - location!(), - )), - }, - DataType::Time64(unit) => match unit { - TimeUnit::Microsecond => Ok(Self::new_primitive_array::( - buffers, num_rows, data_type, - )), - TimeUnit::Nanosecond => 
Ok(Self::new_primitive_array::( - buffers, num_rows, data_type, - )), - _ => Err(Error::io( - format!("invalid time unit {:?} for 64-bit time type", unit), - location!(), - )), - }, - DataType::Timestamp(unit, _) => Ok(match unit { - TimeUnit::Microsecond => Self::new_primitive_array::( - buffers, num_rows, data_type, - ), - TimeUnit::Millisecond => Self::new_primitive_array::( - buffers, num_rows, data_type, - ), - TimeUnit::Nanosecond => Self::new_primitive_array::( - buffers, num_rows, data_type, - ), - TimeUnit::Second => { - Self::new_primitive_array::(buffers, num_rows, data_type) - } - }), - DataType::UInt16 => Ok(Self::new_primitive_array::( - buffers, num_rows, data_type, - )), - DataType::UInt32 => Ok(Self::new_primitive_array::( - buffers, num_rows, data_type, - )), - DataType::UInt64 => Ok(Self::new_primitive_array::( - buffers, num_rows, data_type, - )), - DataType::UInt8 => Ok(Self::new_primitive_array::( - buffers, num_rows, data_type, - )), - DataType::FixedSizeBinary(dimension) => { - let mut buffers_iter = buffers.into_iter(); - let fsb_validity = buffers_iter.next().unwrap(); - let fsb_nulls = Self::bytes_to_validity(fsb_validity, num_rows); - - let fsb_values = buffers_iter.next().unwrap(); - let fsb_values = Buffer::from_bytes(fsb_values.freeze().into()); - Ok(Arc::new(FixedSizeBinaryArray::new( - *dimension, fsb_values, fsb_nulls, - ))) - } - DataType::FixedSizeList(items, dimension) => { - let mut buffers_iter = buffers.into_iter(); - let fsl_validity = buffers_iter.next().unwrap(); - let fsl_nulls = Self::bytes_to_validity(fsl_validity, num_rows); - - let remaining_buffers = buffers_iter.collect::>(); - let items_array = Self::primitive_array_from_buffers( - items.data_type(), - remaining_buffers, - num_rows * (*dimension as u64), - )?; - Ok(Arc::new(FixedSizeListArray::new( - items.clone(), - *dimension, - items_array, - fsl_nulls, - ))) - } - DataType::Utf8 => Ok(Self::new_generic_byte_array::>( - buffers, num_rows, - )), - 
DataType::LargeUtf8 => Ok(Self::new_generic_byte_array::>( - buffers, num_rows, - )), - DataType::Binary => Ok(Self::new_generic_byte_array::>( - buffers, num_rows, - )), - DataType::LargeBinary => Ok(Self::new_generic_byte_array::>( - buffers, num_rows, - )), - _ => Err(Error::io( - format!( - "The data type {} cannot be decoded from a primitive encoding", - data_type - ), - location!(), - )), - } + // Convert the buffers into an Arrow array + primitive_array_from_buffers(&self.data_type, bufs, self.rows_to_take) } } diff --git a/rust/lance-encoding/src/encodings/physical.rs b/rust/lance-encoding/src/encodings/physical.rs index f380f71637..03d8066e22 100644 --- a/rust/lance-encoding/src/encodings/physical.rs +++ b/rust/lance-encoding/src/encodings/physical.rs @@ -10,13 +10,15 @@ use crate::{decoder::PageScheduler, format::pb}; use self::value::parse_compression_scheme; use self::{ basic::BasicPageScheduler, binary::BinaryPageScheduler, bitmap::DenseBitmapScheduler, - fixed_size_list::FixedListScheduler, value::ValuePageScheduler, + dictionary::DictionaryPageScheduler, fixed_size_list::FixedListScheduler, + value::ValuePageScheduler, }; pub mod basic; pub mod binary; pub mod bitmap; pub mod buffers; +pub mod dictionary; pub mod fixed_size_list; pub mod fsst; pub mod value; @@ -153,7 +155,22 @@ pub fn decoder_from_array_encoding( let inner = decoder_from_array_encoding(fsst.binary.as_ref().unwrap(), buffers, data_type); - Box::new(FsstPageScheduler::new(inner, fsst.symbol_table.clone())) + Box::new(FsstPageScheduler::new(inner, fsst.symbol_table.clone()) + } + pb::array_encoding::ArrayEncoding::Dictionary(dictionary) => { + let indices_encoding = dictionary.indices.as_ref().unwrap(); + let items_encoding = dictionary.items.as_ref().unwrap(); + let num_dictionary_items = dictionary.num_dictionary_items; + + let indices_scheduler = + decoder_from_array_encoding(indices_encoding, buffers, data_type); + let items_scheduler = decoder_from_array_encoding(items_encoding, 
buffers, data_type); + + Box::new(DictionaryPageScheduler::new( + indices_scheduler.into(), + items_scheduler.into(), + num_dictionary_items, + )) } // Currently there is no way to encode struct nullability and structs are encoded with a "header" column // (that has no data). We never actually decode that column and so this branch is never actually encountered. diff --git a/rust/lance-encoding/src/encodings/physical/binary.rs b/rust/lance-encoding/src/encodings/physical/binary.rs index f593b35386..70e3744321 100644 --- a/rust/lance-encoding/src/encodings/physical/binary.rs +++ b/rust/lance-encoding/src/encodings/physical/binary.rs @@ -303,7 +303,7 @@ impl PrimitivePageDecoder for BinaryPageDecoder { let mut output_buffers = vec![validity_buffer, offsets_buf]; - // Copy decoded bytes into dest_buffers[2..] + // Add decoded bytes into output_buffers[2..] // Currently an empty null buffer is the first one // The actual bytes are in the second buffer // Including the indices this results in 4 buffers in total @@ -383,6 +383,7 @@ fn get_indices_from_string_arrays(arrays: &[ArrayRef]) -> (ArrayRef, u64) { panic!("Array is not a string array"); } } + let last_offset = *indices.last().expect("Indices array is empty"); // 8 exabytes in a single array seems unlikely but...just in case assert!( diff --git a/rust/lance-encoding/src/encodings/physical/dictionary.rs b/rust/lance-encoding/src/encodings/physical/dictionary.rs new file mode 100644 index 0000000000..8aae5dcd1d --- /dev/null +++ b/rust/lance-encoding/src/encodings/physical/dictionary.rs @@ -0,0 +1,404 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use std::sync::Arc; + +use arrow_array::builder::{ArrayBuilder, StringBuilder}; +use arrow_array::types::UInt8Type; +use arrow_array::{Array, ArrayRef, DictionaryArray, StringArray, UInt8Array}; +use futures::{future::BoxFuture, FutureExt}; + +use crate::{ + decoder::{PageScheduler, PrimitivePageDecoder}, + 
encoder::{ArrayEncoder, EncodedArray}, + format::pb, + EncodingsIo, +}; + +use crate::decoder::LogicalPageDecoder; +use crate::encodings::logical::primitive::PrimitiveFieldDecoder; + +use arrow_schema::DataType; +use bytes::BytesMut; +use lance_core::Result; +use std::collections::HashMap; + +use crate::encodings::utils::new_primitive_array; +use arrow_array::cast::AsArray; + +#[derive(Debug)] +pub struct DictionaryPageScheduler { + indices_scheduler: Arc, + items_scheduler: Arc, + num_dictionary_items: u32, +} + +impl DictionaryPageScheduler { + pub fn new( + indices_scheduler: Arc, + items_scheduler: Arc, + num_dictionary_items: u32, + ) -> Self { + Self { + indices_scheduler, + items_scheduler, + num_dictionary_items, + } + } +} + +impl PageScheduler for DictionaryPageScheduler { + fn schedule_ranges( + &self, + ranges: &[std::ops::Range], + scheduler: &Arc, + top_level_row: u64, + ) -> BoxFuture<'static, Result>> { + // We want to decode indices and items + // e.g. indices [0, 1, 2, 0, 1, 0] + // items (dictionary) ["abcd", "hello", "apple"] + // This will map to ["abcd", "hello", "apple", "abcd", "hello", "abcd"] + // We decode all the items during scheduling itself + // These are used to rebuild the string later + + // Schedule indices for decoding + let indices_page_decoder = + self.indices_scheduler + .schedule_ranges(ranges, scheduler, top_level_row); + + // Schedule items for decoding + let items_range = 0..(self.num_dictionary_items as u64); + let items_page_decoder = self.items_scheduler.schedule_ranges( + std::slice::from_ref(&items_range), + scheduler, + top_level_row, + ); + + let copy_size = self.num_dictionary_items as u64; + + tokio::spawn(async move { + let items_decoder: Arc = Arc::from(items_page_decoder.await?); + + let mut primitive_wrapper = PrimitiveFieldDecoder::new_from_data( + items_decoder.clone(), + DataType::Utf8, + copy_size, + ); + + // Decode all items + let drained_task = primitive_wrapper.drain(copy_size)?; + let 
items_decode_task = drained_task.task; + let decoded_dict = items_decode_task.decode()?; + + let indices_decoder: Box = indices_page_decoder.await?; + + Ok(Box::new(DictionaryPageDecoder { + decoded_dict, + indices_decoder, + items_decoder, + }) as Box) + }) + .map(|join_handle| join_handle.unwrap()) + .boxed() + } +} + +struct DictionaryPageDecoder { + decoded_dict: Arc, + indices_decoder: Box, + items_decoder: Arc, +} + +impl PrimitivePageDecoder for DictionaryPageDecoder { + fn decode( + &self, + rows_to_skip: u64, + num_rows: u64, + all_null: &mut bool, + ) -> Result> { + // Decode the indices + let indices_buffers = self + .indices_decoder + .decode(rows_to_skip, num_rows, all_null)?; + + let indices_array = + new_primitive_array::(indices_buffers.clone(), num_rows, &DataType::UInt8); + + let indices_array = indices_array.as_primitive::().clone(); + let dictionary = self.decoded_dict.clone(); + + let adjusted_indices: UInt8Array = indices_array + .iter() + .map(|x| match x { + Some(0) => None, + Some(x) => Some(x - 1), + None => None, + }) + .collect(); + + // Build dictionary array using indices and items + let dict_array = + DictionaryArray::::try_new(adjusted_indices, dictionary).unwrap(); + let string_array = arrow_cast::cast(&dict_array, &DataType::Utf8).unwrap(); + let string_array = string_array.as_any().downcast_ref::().unwrap(); + + // This workflow is not ideal, since we go from DictionaryArray -> StringArray -> nulls, offsets, and bytes buffers (BytesMut) + // and later in primitive_array_from_buffers() we will go from nulls, offsets, and bytes buffers -> StringArray again. + // Creating the BytesMut is an unnecessary copy. 
But it is the best we can do in the current structure + let null_buffer = string_array + .nulls() + .map(|n| BytesMut::from(n.buffer().as_slice())) + .unwrap_or_else(BytesMut::new); + + let offsets_buffer = BytesMut::from(string_array.offsets().inner().inner().as_slice()); + + // Empty buffer for nulls of bytes + let empty_buffer = BytesMut::new(); + + let bytes_buffer = BytesMut::from_iter(string_array.values().iter().copied()); + + Ok(vec![ + null_buffer, + offsets_buffer, + empty_buffer, + bytes_buffer, + ]) + } + + fn num_buffers(&self) -> u32 { + self.items_decoder.num_buffers() + 2 + } +} + +#[derive(Debug)] +pub struct DictionaryEncoder { + indices_encoder: Box, + items_encoder: Box, +} + +impl DictionaryEncoder { + pub fn new( + indices_encoder: Box, + items_encoder: Box, + ) -> Self { + Self { + indices_encoder, + items_encoder, + } + } +} + +fn encode_dict_indices_and_items(arrays: &[ArrayRef]) -> (ArrayRef, ArrayRef) { + let mut arr_hashmap: HashMap<&str, u8> = HashMap::new(); + // We start with a dict index of 1 because the value 0 is reserved for nulls + // The dict indices are adjusted by subtracting 1 later during decode + let mut curr_dict_index = 1; + let total_capacity = arrays.iter().map(|arr| arr.len()).sum(); + + let mut dict_indices = Vec::with_capacity(total_capacity); + let mut dict_builder = StringBuilder::new(); + + for arr in arrays.iter() { + let string_array = arrow_array::cast::as_string_array(arr); + + for i in 0..string_array.len() { + if !string_array.is_valid(i) { + // null value + dict_indices.push(0); + continue; + } + + let st = string_array.value(i); + + let hashmap_entry = *arr_hashmap.entry(st).or_insert(curr_dict_index); + dict_indices.push(hashmap_entry); + + // if item didn't exist in the hashmap, add it to the dictionary + // and increment the dictionary index + if hashmap_entry == curr_dict_index { + dict_builder.append_value(st); + curr_dict_index += 1; + } + } + } + + let array_dict_indices = 
Arc::new(UInt8Array::from(dict_indices)) as ArrayRef; + + // If there is an empty dictionary: + // Either there is an array of nulls or an empty array altogether + // In this case create the dictionary with a single null element + // Because decoding [] is not currently supported by the binary decoder + if dict_builder.is_empty() { + dict_builder.append_option(Option::<&str>::None); + } + + let dict_elements = dict_builder.finish(); + let array_dict_elements = arrow_cast::cast(&dict_elements, &DataType::Utf8).unwrap(); + + (array_dict_indices, array_dict_elements) +} + +impl ArrayEncoder for DictionaryEncoder { + fn encode(&self, arrays: &[ArrayRef], buffer_index: &mut u32) -> Result { + let (index_array, items_array) = encode_dict_indices_and_items(arrays); + + let encoded_indices = self + .indices_encoder + .encode(&[index_array.clone()], buffer_index)?; + + let encoded_items = self + .items_encoder + .encode(&[items_array.clone()], buffer_index)?; + + let mut encoded_buffers = encoded_indices.buffers; + encoded_buffers.extend(encoded_items.buffers); + + let dict_size = items_array.len() as u32; + + Ok(EncodedArray { + buffers: encoded_buffers, + encoding: pb::ArrayEncoding { + array_encoding: Some(pb::array_encoding::ArrayEncoding::Dictionary(Box::new( + pb::Dictionary { + indices: Some(Box::new(encoded_indices.encoding)), + items: Some(Box::new(encoded_items.encoding)), + num_dictionary_items: dict_size, + }, + ))), + }, + }) + } +} + +#[cfg(test)] +pub mod tests { + + use arrow_array::{ + builder::{LargeStringBuilder, StringBuilder}, + ArrayRef, StringArray, UInt8Array, + }; + use arrow_schema::{DataType, Field}; + use std::{sync::Arc, vec}; + + use crate::testing::{ + check_round_trip_encoding_of_data, check_round_trip_encoding_random, TestCases, + }; + + use super::encode_dict_indices_and_items; + + #[test] + fn test_encode_dict_nulls() { + // Null entries in string arrays should be adjusted + let string_array1 = Arc::new(StringArray::from(vec![None, 
Some("foo"), Some("bar")])); + let string_array2 = Arc::new(StringArray::from(vec![Some("bar"), None, Some("foo")])); + let string_array3 = Arc::new(StringArray::from(vec![None as Option<&str>, None])); + let (dict_indices, dict_items) = + encode_dict_indices_and_items(&[string_array1, string_array2, string_array3]); + + let expected_indices = Arc::new(UInt8Array::from(vec![0, 1, 2, 2, 0, 1, 0, 0])) as ArrayRef; + let expected_items = Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef; + assert_eq!(&dict_indices, &expected_indices); + assert_eq!(&dict_items, &expected_items); + } + + #[test_log::test(tokio::test)] + async fn test_utf8() { + let field = Field::new("", DataType::Utf8, false); + check_round_trip_encoding_random(field).await; + } + + #[test_log::test(tokio::test)] + async fn test_binary() { + let field = Field::new("", DataType::Binary, false); + check_round_trip_encoding_random(field).await; + } + + #[test_log::test(tokio::test)] + async fn test_large_binary() { + let field = Field::new("", DataType::LargeBinary, true); + check_round_trip_encoding_random(field).await; + } + + #[test_log::test(tokio::test)] + async fn test_large_utf8() { + let field = Field::new("", DataType::LargeUtf8, true); + check_round_trip_encoding_random(field).await; + } + + #[test_log::test(tokio::test)] + async fn test_simple_utf8() { + let string_array = StringArray::from(vec![Some("abc"), Some("de"), None, Some("fgh")]); + + let test_cases = TestCases::default() + .with_range(0..2) + .with_range(0..3) + .with_range(1..3) + .with_indices(vec![1, 3]); + check_round_trip_encoding_of_data(vec![Arc::new(string_array)], &test_cases).await; + } + + #[test_log::test(tokio::test)] + async fn test_sliced_utf8() { + let string_array = StringArray::from(vec![Some("abc"), Some("de"), None, Some("fgh")]); + let string_array = string_array.slice(1, 3); + + let test_cases = TestCases::default() + .with_range(0..1) + .with_range(0..2) + .with_range(1..2); + 
check_round_trip_encoding_of_data(vec![Arc::new(string_array)], &test_cases).await; + } + + #[test_log::test(tokio::test)] + async fn test_empty_strings() { + // Scenario 1: Some strings are empty + + let values = [Some("abc"), Some(""), None]; + // Test empty list at beginning, middle, and end + for order in [[0, 1, 2], [1, 0, 2], [2, 0, 1]] { + let mut string_builder = StringBuilder::new(); + for idx in order { + string_builder.append_option(values[idx]); + } + let string_array = Arc::new(string_builder.finish()); + let test_cases = TestCases::default() + .with_indices(vec![1]) + .with_indices(vec![0]) + .with_indices(vec![2]); + check_round_trip_encoding_of_data(vec![string_array.clone()], &test_cases).await; + let test_cases = test_cases.with_batch_size(1); + check_round_trip_encoding_of_data(vec![string_array], &test_cases).await; + } + + // Scenario 2: All strings are empty + + // When encoding an array of empty strings there are no bytes to encode + // which is strange and we want to ensure we handle it + let string_array = Arc::new(StringArray::from(vec![Some(""), None, Some("")])); + + let test_cases = TestCases::default().with_range(0..2).with_indices(vec![1]); + check_round_trip_encoding_of_data(vec![string_array.clone()], &test_cases).await; + let test_cases = test_cases.with_batch_size(1); + check_round_trip_encoding_of_data(vec![string_array], &test_cases).await; + } + + #[test_log::test(tokio::test)] + #[ignore] // This test is quite slow in debug mode + async fn test_jumbo_string() { + // This is an overflow test. We have a list of lists where each list + // has 1Mi items. 
We encode 5000 of these lists and so we have over 4Gi in the + // offsets range + let mut string_builder = LargeStringBuilder::new(); + // a 1 MiB string + let giant_string = String::from_iter((0..(1024 * 1024)).map(|_| '0')); + for _ in 0..5000 { + string_builder.append_option(Some(&giant_string)); + } + let giant_array = Arc::new(string_builder.finish()) as ArrayRef; + let arrs = vec![giant_array]; + + // // We can't validate because our validation relies on concatenating all input arrays + let test_cases = TestCases::default().without_validation(); + check_round_trip_encoding_of_data(arrs, &test_cases).await; + } +} diff --git a/rust/lance-encoding/src/encodings/utils.rs b/rust/lance-encoding/src/encodings/utils.rs new file mode 100644 index 0000000000..23d8b9f0d5 --- /dev/null +++ b/rust/lance-encoding/src/encodings/utils.rs @@ -0,0 +1,284 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use std::sync::Arc; + +use arrow_array::{ + new_null_array, + types::{ + ArrowPrimitiveType, ByteArrayType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, + DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType, + DurationSecondType, Float16Type, Float32Type, Float64Type, GenericBinaryType, + GenericStringType, Int16Type, Int32Type, Int64Type, Int8Type, IntervalDayTimeType, + IntervalMonthDayNanoType, IntervalYearMonthType, Time32MillisecondType, Time32SecondType, + Time64MicrosecondType, Time64NanosecondType, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, + UInt32Type, UInt64Type, UInt8Type, + }, + ArrayRef, BooleanArray, FixedSizeBinaryArray, FixedSizeListArray, GenericByteArray, + PrimitiveArray, +}; +use arrow_buffer::{BooleanBuffer, Buffer, NullBuffer, OffsetBuffer, ScalarBuffer}; +use arrow_schema::{DataType, IntervalUnit, TimeUnit}; +use bytes::BytesMut; +use snafu::{location, Location}; + +use lance_core::{Error, Result}; + 
+pub fn new_primitive_array( + buffers: Vec, + num_rows: u64, + data_type: &DataType, +) -> ArrayRef { + let mut buffer_iter = buffers.into_iter(); + let null_buffer = buffer_iter.next().unwrap(); + let null_buffer = if null_buffer.is_empty() { + None + } else { + let null_buffer = null_buffer.freeze().into(); + Some(NullBuffer::new(BooleanBuffer::new( + Buffer::from_bytes(null_buffer), + 0, + num_rows as usize, + ))) + }; + + let data_buffer = buffer_iter.next().unwrap().freeze(); + let data_buffer = Buffer::from_bytes(data_buffer.into()); + let data_buffer = ScalarBuffer::::new(data_buffer, 0, num_rows as usize); + + // The with_data_type is needed here to recover the parameters for types like Decimal/Timestamp + Arc::new(PrimitiveArray::::new(data_buffer, null_buffer).with_data_type(data_type.clone())) +} + +pub fn new_generic_byte_array(buffers: Vec, num_rows: u64) -> ArrayRef { + // iterate over buffers to get offsets and then bytes + let mut buffer_iter = buffers.into_iter(); + + let null_buffer = buffer_iter.next().unwrap(); + let null_buffer = if null_buffer.is_empty() { + None + } else { + let null_buffer = null_buffer.freeze().into(); + Some(NullBuffer::new(BooleanBuffer::new( + Buffer::from_bytes(null_buffer), + 0, + num_rows as usize, + ))) + }; + + let indices_bytes = buffer_iter.next().unwrap().freeze(); + let indices_buffer = Buffer::from_bytes(indices_bytes.into()); + let indices_buffer = ScalarBuffer::::new(indices_buffer, 0, num_rows as usize + 1); + + let offsets = OffsetBuffer::new(indices_buffer.clone()); + + // Decoding the bytes creates 2 buffers, the first one is empty since + // validity is stored in an earlier buffer + buffer_iter.next().unwrap(); + + let bytes_buffer = buffer_iter.next().unwrap().freeze(); + let bytes_buffer = Buffer::from_bytes(bytes_buffer.into()); + let bytes_buffer_len = bytes_buffer.len(); + let bytes_buffer = ScalarBuffer::::new(bytes_buffer, 0, bytes_buffer_len); + + let bytes_array = Arc::new( + 
PrimitiveArray::::new(bytes_buffer, None).with_data_type(DataType::UInt8), + ); + + Arc::new(GenericByteArray::::new( + offsets, + bytes_array.values().into(), + null_buffer, + )) +} + +pub fn bytes_to_validity(bytes: BytesMut, num_rows: u64) -> Option { + if bytes.is_empty() { + None + } else { + let null_buffer = bytes.freeze().into(); + Some(NullBuffer::new(BooleanBuffer::new( + Buffer::from_bytes(null_buffer), + 0, + num_rows as usize, + ))) + } +} + +pub fn primitive_array_from_buffers( + data_type: &DataType, + buffers: Vec, + num_rows: u64, +) -> Result { + match data_type { + DataType::Boolean => { + let mut buffer_iter = buffers.into_iter(); + let null_buffer = buffer_iter.next().unwrap(); + let null_buffer = bytes_to_validity(null_buffer, num_rows); + + let data_buffer = buffer_iter.next().unwrap().freeze(); + let data_buffer = Buffer::from(data_buffer); + let data_buffer = BooleanBuffer::new(data_buffer, 0, num_rows as usize); + + Ok(Arc::new(BooleanArray::new(data_buffer, null_buffer))) + } + DataType::Date32 => Ok(new_primitive_array::( + buffers, num_rows, data_type, + )), + DataType::Date64 => Ok(new_primitive_array::( + buffers, num_rows, data_type, + )), + DataType::Decimal128(_, _) => Ok(new_primitive_array::( + buffers, num_rows, data_type, + )), + DataType::Decimal256(_, _) => Ok(new_primitive_array::( + buffers, num_rows, data_type, + )), + DataType::Duration(units) => Ok(match units { + TimeUnit::Second => { + new_primitive_array::(buffers, num_rows, data_type) + } + TimeUnit::Microsecond => { + new_primitive_array::(buffers, num_rows, data_type) + } + TimeUnit::Millisecond => { + new_primitive_array::(buffers, num_rows, data_type) + } + TimeUnit::Nanosecond => { + new_primitive_array::(buffers, num_rows, data_type) + } + }), + DataType::Float16 => Ok(new_primitive_array::( + buffers, num_rows, data_type, + )), + DataType::Float32 => Ok(new_primitive_array::( + buffers, num_rows, data_type, + )), + DataType::Float64 => 
Ok(new_primitive_array::( + buffers, num_rows, data_type, + )), + DataType::Int16 => Ok(new_primitive_array::( + buffers, num_rows, data_type, + )), + DataType::Int32 => Ok(new_primitive_array::( + buffers, num_rows, data_type, + )), + DataType::Int64 => Ok(new_primitive_array::( + buffers, num_rows, data_type, + )), + DataType::Int8 => Ok(new_primitive_array::( + buffers, num_rows, data_type, + )), + DataType::Interval(unit) => Ok(match unit { + IntervalUnit::DayTime => { + new_primitive_array::(buffers, num_rows, data_type) + } + IntervalUnit::MonthDayNano => { + new_primitive_array::(buffers, num_rows, data_type) + } + IntervalUnit::YearMonth => { + new_primitive_array::(buffers, num_rows, data_type) + } + }), + DataType::Null => Ok(new_null_array(data_type, num_rows as usize)), + DataType::Time32(unit) => match unit { + TimeUnit::Millisecond => Ok(new_primitive_array::( + buffers, num_rows, data_type, + )), + TimeUnit::Second => Ok(new_primitive_array::( + buffers, num_rows, data_type, + )), + _ => Err(Error::io( + format!("invalid time unit {:?} for 32-bit time type", unit), + location!(), + )), + }, + DataType::Time64(unit) => match unit { + TimeUnit::Microsecond => Ok(new_primitive_array::( + buffers, num_rows, data_type, + )), + TimeUnit::Nanosecond => Ok(new_primitive_array::( + buffers, num_rows, data_type, + )), + _ => Err(Error::io( + format!("invalid time unit {:?} for 64-bit time type", unit), + location!(), + )), + }, + DataType::Timestamp(unit, _) => Ok(match unit { + TimeUnit::Microsecond => { + new_primitive_array::(buffers, num_rows, data_type) + } + TimeUnit::Millisecond => { + new_primitive_array::(buffers, num_rows, data_type) + } + TimeUnit::Nanosecond => { + new_primitive_array::(buffers, num_rows, data_type) + } + TimeUnit::Second => { + new_primitive_array::(buffers, num_rows, data_type) + } + }), + DataType::UInt16 => Ok(new_primitive_array::( + buffers, num_rows, data_type, + )), + DataType::UInt32 => Ok(new_primitive_array::( + buffers, 
num_rows, data_type, + )), + DataType::UInt64 => Ok(new_primitive_array::( + buffers, num_rows, data_type, + )), + DataType::UInt8 => Ok(new_primitive_array::( + buffers, num_rows, data_type, + )), + DataType::FixedSizeBinary(dimension) => { + let mut buffers_iter = buffers.into_iter(); + let fsb_validity = buffers_iter.next().unwrap(); + let fsb_nulls = bytes_to_validity(fsb_validity, num_rows); + + let fsb_values = buffers_iter.next().unwrap(); + let fsb_values = Buffer::from_bytes(fsb_values.freeze().into()); + Ok(Arc::new(FixedSizeBinaryArray::new( + *dimension, fsb_values, fsb_nulls, + ))) + } + DataType::FixedSizeList(items, dimension) => { + let mut buffers_iter = buffers.into_iter(); + let fsl_validity = buffers_iter.next().unwrap(); + let fsl_nulls = bytes_to_validity(fsl_validity, num_rows); + + let remaining_buffers = buffers_iter.collect::>(); + let items_array = primitive_array_from_buffers( + items.data_type(), + remaining_buffers, + num_rows * (*dimension as u64), + )?; + Ok(Arc::new(FixedSizeListArray::new( + items.clone(), + *dimension, + items_array, + fsl_nulls, + ))) + } + DataType::Utf8 => Ok(new_generic_byte_array::>( + buffers, num_rows, + )), + DataType::LargeUtf8 => Ok(new_generic_byte_array::>( + buffers, num_rows, + )), + DataType::Binary => Ok(new_generic_byte_array::>( + buffers, num_rows, + )), + DataType::LargeBinary => Ok(new_generic_byte_array::>( + buffers, num_rows, + )), + _ => Err(Error::io( + format!( + "The data type {} cannot be decoded from a primitive encoding", + data_type + ), + location!(), + )), + } +} diff --git a/rust/lance-file/benches/reader.rs b/rust/lance-file/benches/reader.rs index 23351d96de..7c87fda4c5 100644 --- a/rust/lance-file/benches/reader.rs +++ b/rust/lance-file/benches/reader.rs @@ -29,7 +29,6 @@ fn bench_reader(c: &mut Criterion) { let mut writer = FileWriter::try_new( object_writer, - file_path.to_string(), data.schema().as_ref().try_into().unwrap(), FileWriterOptions::default(), ) diff --git 
a/rust/lance-file/src/v2/reader.rs b/rust/lance-file/src/v2/reader.rs index 3fe182a0d1..68ead50160 100644 --- a/rust/lance-file/src/v2/reader.rs +++ b/rust/lance-file/src/v2/reader.rs @@ -1388,7 +1388,6 @@ pub mod tests { let mut file_writer = FileWriter::try_new( fs.object_store.create(&fs.tmp_path).await.unwrap(), - fs.tmp_path.to_string(), lance_schema.clone(), FileWriterOptions::default(), ) diff --git a/rust/lance-file/src/v2/testing.rs b/rust/lance-file/src/v2/testing.rs index aa807a47a5..18b00e4160 100644 --- a/rust/lance-file/src/v2/testing.rs +++ b/rust/lance-file/src/v2/testing.rs @@ -49,13 +49,7 @@ pub async fn write_lance_file( let lance_schema = lance_core::datatypes::Schema::try_from(data.schema().as_ref()).unwrap(); - let mut file_writer = FileWriter::try_new( - writer, - fs.tmp_path.to_string(), - lance_schema.clone(), - options, - ) - .unwrap(); + let mut file_writer = FileWriter::try_new(writer, lance_schema.clone(), options).unwrap(); let data = data .collect::, ArrowError>>() diff --git a/rust/lance-file/src/v2/writer.rs b/rust/lance-file/src/v2/writer.rs index 4ea3971a6b..3a585a497f 100644 --- a/rust/lance-file/src/v2/writer.rs +++ b/rust/lance-file/src/v2/writer.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors +use std::collections::HashMap; use std::sync::Arc; use arrow_array::RecordBatch; @@ -70,14 +71,15 @@ pub struct FileWriterOptions { pub struct FileWriter { writer: ObjectWriter, - path: String, - schema: LanceSchema, + schema: Option, column_writers: Vec>, column_metadata: Vec, field_id_to_column_indices: Vec<(i32, i32)>, num_columns: u32, rows_written: u64, global_buffers: Vec<(u64, u64)>, + schema_metadata: HashMap, + options: FileWriterOptions, } fn initial_column_metadata() -> pbfile::ColumnMetadata { @@ -90,48 +92,34 @@ fn initial_column_metadata() -> pbfile::ColumnMetadata { } impl FileWriter { - /// Create a new FileWriter + /// Create a new FileWriter with a desired 
output schema pub fn try_new( object_writer: ObjectWriter, - path: String, schema: LanceSchema, options: FileWriterOptions, ) -> Result { - let cache_bytes_per_column = if let Some(data_cache_bytes) = options.data_cache_bytes { - data_cache_bytes / schema.fields.len() as u64 - } else { - 8 * 1024 * 1024 - }; - - schema.validate()?; - - let keep_original_array = options.keep_original_array.unwrap_or(false); - let encoding_strategy = options - .encoding_strategy - .unwrap_or_else(|| Arc::new(CoreFieldEncodingStrategy::default())); - - let encoder = BatchEncoder::try_new( - &schema, - encoding_strategy.as_ref(), - cache_bytes_per_column, - keep_original_array, - )?; - let num_columns = encoder.num_columns(); - - let column_writers = encoder.field_encoders; - let column_metadata = vec![initial_column_metadata(); num_columns as usize]; + let mut writer = Self::new_lazy(object_writer, options); + writer.initialize(schema)?; + Ok(writer) + } - Ok(Self { + /// Create a new FileWriter without a desired output schema + /// + /// The output schema will be set based on the first batch of data to arrive. + /// If no data arrives and the writer is finished then the write will fail. 
+ pub fn new_lazy(object_writer: ObjectWriter, options: FileWriterOptions) -> Self { + Self { writer: object_writer, - path, - schema, - column_writers, - column_metadata, - num_columns, + schema: None, + column_writers: Vec::new(), + column_metadata: Vec::new(), + num_columns: 0, rows_written: 0, - field_id_to_column_indices: encoder.field_id_to_column_index, + field_id_to_column_indices: Vec::new(), global_buffers: Vec::new(), - }) + schema_metadata: HashMap::new(), + options, + } } async fn write_page(&mut self, encoded_page: EncodedPage) -> Result<()> { @@ -208,6 +196,47 @@ impl FileWriter { Ok(()) } + fn initialize(&mut self, mut schema: LanceSchema) -> Result<()> { + let cache_bytes_per_column = if let Some(data_cache_bytes) = self.options.data_cache_bytes { + data_cache_bytes / schema.fields.len() as u64 + } else { + 8 * 1024 * 1024 + }; + + schema.validate()?; + + let keep_original_array = self.options.keep_original_array.unwrap_or(false); + let encoding_strategy = self + .options + .encoding_strategy + .clone() + .unwrap_or_else(|| Arc::new(CoreFieldEncodingStrategy::default())); + + let encoder = BatchEncoder::try_new( + &schema, + encoding_strategy.as_ref(), + cache_bytes_per_column, + keep_original_array, + )?; + self.num_columns = encoder.num_columns(); + + self.column_writers = encoder.field_encoders; + self.column_metadata = vec![initial_column_metadata(); self.num_columns as usize]; + self.field_id_to_column_indices = encoder.field_id_to_column_index; + self.schema_metadata + .extend(std::mem::take(&mut schema.metadata)); + self.schema = Some(schema); + Ok(()) + } + + fn ensure_initialized(&mut self, batch: &RecordBatch) -> Result<&LanceSchema> { + if self.schema.is_none() { + let schema = LanceSchema::try_from(batch.schema().as_ref())?; + self.initialize(schema)?; + } + Ok(self.schema.as_ref().unwrap()) + } + /// Schedule a batch of data to be written to the file /// /// Note: the future returned by this method may complete before the data has been 
fully @@ -217,6 +246,8 @@ impl FileWriter { "write_batch called with {} bytes of data", batch.get_array_memory_size() ); + self.ensure_initialized(batch)?; + let schema = self.schema.as_ref().unwrap(); let num_rows = batch.num_rows() as u64; if num_rows == 0 { return Ok(()); @@ -235,8 +266,7 @@ impl FileWriter { }; // First we push each array into its column writer. This may or may not generate enough // data to trigger an encoding task. We collect any encoding tasks into a queue. - let encoding_tasks = self - .schema + let encoding_tasks = schema .fields .iter() .zip(self.column_writers.iter_mut()) @@ -300,7 +330,9 @@ impl FileWriter { } async fn write_global_buffers(&mut self) -> Result> { - let file_descriptor = Self::make_file_descriptor(&self.schema, self.rows_written)?; + let schema = self.schema.as_mut().ok_or(Error::invalid_input("No schema provided on writer open and no data provided. Schema is unknown and file cannot be created", location!()))?; + schema.metadata = std::mem::take(&mut self.schema_metadata); + let file_descriptor = Self::make_file_descriptor(schema, self.rows_written)?; let file_descriptor_bytes = file_descriptor.encode_to_vec(); let file_descriptor_len = file_descriptor_bytes.len() as u64; let file_descriptor_position = self.writer.tell().await? as u64; @@ -317,7 +349,7 @@ impl FileWriter { /// data has been written. This method allows you to alter the schema metadata. It /// must be called before `finish` is called. 
pub fn add_schema_metadata(&mut self, key: impl Into, value: impl Into) { - self.schema.metadata.insert(key.into(), value.into()); + self.schema_metadata.insert(key.into(), value.into()); } /// Adds a global buffer to the file @@ -450,10 +482,6 @@ impl FileWriter { pub fn field_id_to_column_indices(&self) -> &[(i32, i32)] { &self.field_id_to_column_indices } - - pub fn path(&self) -> &str { - &self.path - } } /// Utility trait for converting EncodedBatch to Bytes using the @@ -599,13 +627,8 @@ mod tests { let lance_schema = lance_core::datatypes::Schema::try_from(reader.schema().as_ref()).unwrap(); - let mut file_writer = FileWriter::try_new( - writer, - tmp_path.to_string(), - lance_schema, - FileWriterOptions::default(), - ) - .unwrap(); + let mut file_writer = + FileWriter::try_new(writer, lance_schema, FileWriterOptions::default()).unwrap(); for batch in reader { file_writer.write_batch(&batch.unwrap()).await.unwrap(); @@ -632,13 +655,8 @@ mod tests { let lance_schema = lance_core::datatypes::Schema::try_from(reader.schema().as_ref()).unwrap(); - let mut file_writer = FileWriter::try_new( - writer, - tmp_path.to_string(), - lance_schema, - FileWriterOptions::default(), - ) - .unwrap(); + let mut file_writer = + FileWriter::try_new(writer, lance_schema, FileWriterOptions::default()).unwrap(); for batch in reader { file_writer.write_batch(&batch.unwrap()).await.unwrap(); diff --git a/rust/lance-index/Cargo.toml b/rust/lance-index/Cargo.toml index 0991ae6e00..acafed05e3 100644 --- a/rust/lance-index/Cargo.toml +++ b/rust/lance-index/Cargo.toml @@ -49,10 +49,12 @@ rayon.workspace = true serde_json.workspace = true serde.workspace = true snafu.workspace = true +tantivy.workspace = true tokio.workspace = true tracing.workspace = true tempfile.workspace = true crossbeam-queue.workspace = true +bytes.workspace = true [dev-dependencies] approx.workspace = true @@ -60,7 +62,6 @@ clap = { workspace = true, features = ["derive"] } criterion.workspace = true 
lance-datagen.workspace = true lance-testing.workspace = true -pprof.workspace = true tempfile.workspace = true datafusion-sql.workspace = true @@ -68,6 +69,9 @@ datafusion-sql.workspace = true prost-build.workspace = true rustc_version.workspace = true +[target.'cfg(target_os = "linux")'.dev-dependencies] +pprof.workspace = true + [[bench]] name = "find_partitions" harness = false diff --git a/rust/lance-index/benches/pq_assignment.rs b/rust/lance-index/benches/pq_assignment.rs index bf448633f7..312358c091 100644 --- a/rust/lance-index/benches/pq_assignment.rs +++ b/rust/lance-index/benches/pq_assignment.rs @@ -3,12 +3,11 @@ //! Benchmark of Building PQ code from Dense Vectors. -use std::sync::Arc; - use arrow_array::{types::Float32Type, FixedSizeListArray}; use criterion::{criterion_group, criterion_main, Criterion}; use lance_arrow::FixedSizeListArrayExt; -use lance_index::vector::pq::{ProductQuantizer, ProductQuantizerImpl}; +use lance_index::vector::pq::ProductQuantizer; +use lance_index::vector::quantizer::Quantization; use lance_linalg::distance::DistanceType; use lance_testing::datagen::generate_random_array_with_seed; @@ -17,20 +16,23 @@ const DIM: usize = 1536; const TOTAL: usize = 32 * 1024; fn pq_transform(c: &mut Criterion) { - let codebook = Arc::new(generate_random_array_with_seed::( - 256 * DIM, - [88; 32], - )); + let codebook = generate_random_array_with_seed::(256 * DIM, [88; 32]); let vectors = generate_random_array_with_seed::(DIM * TOTAL, [3; 32]); let fsl = FixedSizeListArray::try_new_from_values(vectors, DIM as i32).unwrap(); for dt in [DistanceType::L2, DistanceType::Dot].iter() { - let pq = ProductQuantizerImpl::::new(PQ, 8, DIM, codebook.clone(), *dt); + let pq = ProductQuantizer::new( + PQ, + 8, + DIM, + FixedSizeListArray::try_new_from_values(codebook.clone(), DIM as i32).unwrap(), + *dt, + ); c.bench_function(format!("{},{}", dt, TOTAL).as_str(), |b| { b.iter(|| { - let _ = pq.transform(&fsl).unwrap(); + let _ = 
pq.quantize(&fsl).unwrap(); }) }); } diff --git a/rust/lance-index/benches/pq_dist_table.rs b/rust/lance-index/benches/pq_dist_table.rs index d46d15f129..d236c59f67 100644 --- a/rust/lance-index/benches/pq_dist_table.rs +++ b/rust/lance-index/benches/pq_dist_table.rs @@ -4,13 +4,13 @@ //! Benchmark of building PQ distance table. use std::iter::repeat; -use std::sync::Arc; use arrow_array::types::Float32Type; -use arrow_array::UInt8Array; +use arrow_array::{FixedSizeListArray, UInt8Array}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use lance_index::vector::pq::{ProductQuantizer, ProductQuantizerImpl}; -use lance_linalg::distance::MetricType; +use lance_arrow::FixedSizeListArrayExt; +use lance_index::vector::pq::ProductQuantizer; +use lance_linalg::distance::DistanceType; use lance_testing::datagen::generate_random_array_with_seed; use rand::{prelude::StdRng, Rng, SeedableRng}; @@ -22,17 +22,19 @@ const DIM: usize = 1536; const TOTAL: usize = 5 * 1024 * 1024; fn dist_table(c: &mut Criterion) { - let codebook = Arc::new(generate_random_array_with_seed::( - 256 * DIM, - [88; 32], - )); + let codebook = generate_random_array_with_seed::(256 * DIM, [88; 32]); let query = generate_random_array_with_seed::(DIM, [32; 32]); let mut rnd = StdRng::from_seed([32; 32]); let code = UInt8Array::from_iter_values(repeat(rnd.gen::()).take(TOTAL * PQ)); - let l2_pq = - ProductQuantizerImpl::::new(PQ, 8, DIM, codebook.clone(), MetricType::L2); + let l2_pq = ProductQuantizer::new( + PQ, + 8, + DIM, + FixedSizeListArray::try_new_from_values(codebook.clone(), DIM as i32).unwrap(), + DistanceType::L2, + ); c.bench_function( format!("{},L2,PQ={},DIM={}", TOTAL, PQ, DIM).as_str(), @@ -43,8 +45,13 @@ fn dist_table(c: &mut Criterion) { }, ); - let cosine_pq = - ProductQuantizerImpl::::new(PQ, 8, DIM, codebook.clone(), MetricType::Cosine); + let cosine_pq = ProductQuantizer::new( + PQ, + 8, + DIM, + FixedSizeListArray::try_new_from_values(codebook.clone(), DIM as 
i32).unwrap(), + DistanceType::Cosine, + ); c.bench_function( format!("{},Cosine,PQ={},DIM={}", TOTAL, PQ, DIM).as_str(), diff --git a/rust/lance-index/src/scalar.rs b/rust/lance-index/src/scalar.rs index 11fdda90ea..6809cdedf0 100644 --- a/rust/lance-index/src/scalar.rs +++ b/rust/lance-index/src/scalar.rs @@ -18,9 +18,11 @@ use lance_core::Result; use crate::Index; +pub mod bitmap; pub mod btree; pub mod expression; pub mod flat; +pub mod inverted; pub mod lance_format; /// Trait for storing an index (or parts of an index) into storage @@ -82,6 +84,8 @@ pub enum ScalarQuery { IsIn(Vec), /// Retrieve all row ids where the value is exactly the given value Equals(ScalarValue), + /// Retrieve all row ids where the value matches the given full text search query + FullTextSearch(Vec), /// Retrieve all row ids where the value is null IsNull(), } @@ -125,6 +129,9 @@ impl ScalarQuery { .collect::>(), false, ), + Self::FullTextSearch(values) => { + col_expr.like(Expr::Literal(ScalarValue::Utf8(Some(values.join("|"))))) + } Self::IsNull() => col_expr.is_null(), Self::Equals(value) => col_expr.eq(Expr::Literal(value.clone())), } @@ -162,6 +169,9 @@ impl ScalarQuery { .join(",") ) } + Self::FullTextSearch(values) => { + format!("{} LIKE '{}'", col, values.join("|")) + } Self::IsNull() => { format!("{} IS NULL", col) } diff --git a/rust/lance-index/src/scalar/bitmap.rs b/rust/lance-index/src/scalar/bitmap.rs new file mode 100644 index 0000000000..cc88aabad2 --- /dev/null +++ b/rust/lance-index/src/scalar/bitmap.rs @@ -0,0 +1,356 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use std::{ + any::Any, + collections::{BTreeMap, HashMap}, + fmt::Debug, + ops::Bound, + sync::Arc, +}; + +use arrow::array::{BinaryBuilder, UInt64Builder}; +use arrow_array::{Array, BinaryArray, RecordBatch, UInt64Array}; +use arrow_schema::{DataType, Field, Schema}; +use async_trait::async_trait; +use 
datafusion::physical_plan::SendableRecordBatchStream; +use datafusion_common::ScalarValue; +use deepsize::DeepSizeOf; +use futures::TryStreamExt; +use lance_core::{Error, Result}; +use roaring::treemap::RoaringTreemap; +use roaring::RoaringBitmap; +use serde::Serialize; +use snafu::{location, Location}; + +use crate::{Index, IndexType}; + +use super::btree::OrderableScalarValue; +use super::{btree::BtreeTrainingSource, IndexStore, ScalarIndex, ScalarQuery}; + +pub const BITMAP_LOOKUP_NAME: &str = "bitmap_page_lookup.lance"; + +/// A scalar index that stores a bitmap for each possible value +/// +/// This index works best for low-cardinality columns, where the number of unique values is small. +/// The bitmap stores a list of row ids where the value is present. +#[derive(Clone, Debug)] +pub struct BitmapIndex { + index_map: BTreeMap, + // Memoized index_map size for DeepSizeOf + index_map_size_bytes: usize, + store: Arc, +} + +impl BitmapIndex { + fn new( + index_map: BTreeMap, + index_map_size_bytes: usize, + store: Arc, + ) -> Self { + Self { + index_map, + index_map_size_bytes, + store, + } + } + + // creates a new BitmapIndex from a serialized RecordBatch + fn try_from_serialized(data: RecordBatch, store: Arc) -> Result { + if data.num_rows() == 0 { + return Err(Error::Internal { + message: "attempt to load bitmap index from empty record batch".into(), + location: location!(), + }); + } + + let dict_keys = data.column(0); + let binary_bitmaps = data.column(1); + let bitmap_binary_array = binary_bitmaps + .as_any() + .downcast_ref::() + .unwrap(); + + let mut index_map: BTreeMap = BTreeMap::new(); + + let mut index_map_size_bytes = 0; + for idx in 0..data.num_rows() { + let key = OrderableScalarValue(ScalarValue::try_from_array(dict_keys, idx)?); + let bitmap_bytes = bitmap_binary_array.value(idx); + let bitmap = RoaringTreemap::deserialize_from(bitmap_bytes).unwrap(); + let bitmap_vec: Vec = bitmap.into_iter().collect(); + let bitmap_array = 
UInt64Array::from(bitmap_vec); + + index_map_size_bytes += key.deep_size_of(); + index_map_size_bytes += bitmap_array.get_array_memory_size(); + index_map.insert(key, bitmap_array); + } + + Ok(Self::new(index_map, index_map_size_bytes, store)) + } +} + +impl DeepSizeOf for BitmapIndex { + fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize { + let mut total_size = 0; + + // Size of BTreeMap values + total_size += self.index_map_size_bytes; + + // Size of Arc contents + total_size += self.store.deep_size_of_children(context); + + total_size + } +} + +#[derive(Serialize)] +struct BitmapStatistics { + num_bitmaps: usize, +} + +#[async_trait] +impl Index for BitmapIndex { + fn as_any(&self) -> &dyn Any { + self + } + + fn as_index(self: Arc) -> Arc { + self + } + + fn as_vector_index(self: Arc) -> Result> { + Err(Error::NotSupported { + source: "BitmapIndex is not a vector index".into(), + location: location!(), + }) + } + + fn index_type(&self) -> IndexType { + IndexType::Scalar + } + + fn statistics(&self) -> Result { + let stats = BitmapStatistics { + num_bitmaps: self.index_map.len(), + }; + serde_json::to_value(stats).map_err(|e| Error::Internal { + message: format!("failed to serialize bitmap index statistics: {}", e), + location: location!(), + }) + } + + async fn calculate_included_frags(&self) -> Result { + unimplemented!() + } +} + +#[async_trait] +impl ScalarIndex for BitmapIndex { + async fn search(&self, query: &ScalarQuery) -> Result { + let empty_vec: Vec = Vec::new(); + let empty_array = UInt64Array::from(empty_vec); + + let row_ids = match query { + ScalarQuery::Equals(val) => { + let key = OrderableScalarValue(val.clone()); + self.index_map.get(&key).unwrap_or(&empty_array).clone() + } + ScalarQuery::Range(start, end) => { + let range_start = match start { + Bound::Included(val) => Bound::Included(OrderableScalarValue(val.clone())), + Bound::Excluded(val) => Bound::Excluded(OrderableScalarValue(val.clone())), + Bound::Unbounded 
=> Bound::Unbounded, + }; + + let range_end = match end { + Bound::Included(val) => Bound::Included(OrderableScalarValue(val.clone())), + Bound::Excluded(val) => Bound::Excluded(OrderableScalarValue(val.clone())), + Bound::Unbounded => Bound::Unbounded, + }; + + let range_iter = self.index_map.range((range_start, range_end)); + let total_len: usize = range_iter.clone().map(|(_, arr)| arr.len()).sum(); + let mut builder = UInt64Builder::with_capacity(total_len); + + for (_, array) in range_iter { + builder.append_slice(array.values()); + } + + builder.finish() + } + ScalarQuery::IsIn(values) => { + let mut builder = UInt64Builder::new(); + for val in values { + let key = OrderableScalarValue(val.clone()); + if let Some(array) = self.index_map.get(&key) { + builder.append_slice(array.values()); + } + } + + builder.finish() + } + ScalarQuery::IsNull() => { + if let Some(array) = self + .index_map + .iter() + .find(|(key, _)| key.0.is_null()) + .map(|(_, value)| value) + { + array.clone() + } else { + empty_array + } + } + ScalarQuery::FullTextSearch(_) => { + return Err(Error::NotSupported { + source: "full text search is not supported for bitmap indexes".into(), + location: location!(), + }); + } + }; + + Ok(row_ids) + } + + async fn load(store: Arc) -> Result> { + let page_lookup_file = store.open_index_file(BITMAP_LOOKUP_NAME).await?; + let serialized_lookup = page_lookup_file.read_record_batch(0).await?; + + Ok(Arc::new(Self::try_from_serialized( + serialized_lookup, + store, + )?)) + } + + /// Remap the row ids, creating a new remapped version of this index in `dest_store` + async fn remap( + &self, + mapping: &HashMap>, + dest_store: &dyn IndexStore, + ) -> Result<()> { + let state = self + .index_map + .iter() + .map(|(key, bitmap)| { + let bitmap = bitmap + .values() + .iter() + .filter_map(|row_id| *mapping.get(row_id)?) 
+ .collect::>(); + (key.0.clone(), bitmap) + }) + .collect::>(); + write_bitmap_index(state, dest_store).await + } + + /// Add the new data into the index, creating an updated version of the index in `dest_store` + async fn update( + &self, + new_data: SendableRecordBatchStream, + dest_store: &dyn IndexStore, + ) -> Result<()> { + let state = self + .index_map + .iter() + .map(|(key, bitmap)| { + ( + key.0.clone(), + Vec::from_iter(bitmap.values().iter().copied()), + ) + }) + .collect::>(); + do_train_bitmap_index(new_data, state, dest_store).await + } +} + +fn get_batch_from_arrays( + keys: Arc, + binary_bitmaps: Arc, +) -> Result { + let schema = Arc::new(Schema::new(vec![ + Field::new("keys", keys.data_type().clone(), true), + Field::new("bitmaps", binary_bitmaps.data_type().clone(), true), + ])); + + let columns = vec![keys, binary_bitmaps]; + + Ok(RecordBatch::try_new(schema, columns)?) +} + +// Takes an iterator of Vec and processes each vector +// to turn it into a RoaringTreemap. Each RoaringTreeMap is +// serialized to bytes. 
The entire collection is converted to a BinaryArray +fn get_bitmaps_from_iter(iter: I) -> Arc +where + I: Iterator>, +{ + let mut builder = BinaryBuilder::new(); + iter.for_each(|vec| { + let mut bitmap = RoaringTreemap::new(); + bitmap.extend(vec); + let mut bytes = Vec::new(); + bitmap.serialize_into(&mut bytes).unwrap(); + builder.append_value(&bytes); + }); + + Arc::new(builder.finish()) +} + +async fn write_bitmap_index( + state: HashMap>, + index_store: &dyn IndexStore, +) -> Result<()> { + let keys_iter = state.keys().cloned(); + let keys_array = ScalarValue::iter_to_array(keys_iter)?; + + let values_iter = state.into_values(); + let binary_bitmap_array = get_bitmaps_from_iter(values_iter); + + let record_batch = get_batch_from_arrays(keys_array, binary_bitmap_array)?; + + let mut bitmap_index_file = index_store + .new_index_file(BITMAP_LOOKUP_NAME, record_batch.schema()) + .await?; + bitmap_index_file.write_record_batch(record_batch).await?; + bitmap_index_file.finish().await?; + Ok(()) +} + +async fn do_train_bitmap_index( + mut data_source: SendableRecordBatchStream, + mut state: HashMap>, + index_store: &dyn IndexStore, +) -> Result<()> { + while let Some(batch) = data_source.try_next().await? 
{ + debug_assert_eq!(batch.num_columns(), 2); + debug_assert_eq!(*batch.column(1).data_type(), DataType::UInt64); + + let key_column = batch.column(0); + let row_id_column = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + + for i in 0..key_column.len() { + let row_id = row_id_column.value(i); + let key = ScalarValue::try_from_array(key_column.as_ref(), i)?; + state.entry(key.clone()).or_default().push(row_id); + } + } + + write_bitmap_index(state, index_store).await +} + +pub async fn train_bitmap_index( + data_source: Box, + index_store: &dyn IndexStore, +) -> Result<()> { + let batches_source = data_source.scan_ordered_chunks(4096).await?; + + // mapping from item to list of the row ids where it is present + let dictionary: HashMap> = HashMap::new(); + + do_train_bitmap_index(batches_source, dictionary, index_store).await +} diff --git a/rust/lance-index/src/scalar/btree.rs b/rust/lance-index/src/scalar/btree.rs index 752d56cc6d..887d360900 100644 --- a/rust/lance-index/src/scalar/btree.rs +++ b/rust/lance-index/src/scalar/btree.rs @@ -49,7 +49,7 @@ const BTREE_PAGES_NAME: &str = "page_data.lance"; /// Wraps a ScalarValue and implements Ord (ScalarValue only implements PartialOrd) #[derive(Clone, Debug)] -struct OrderableScalarValue(ScalarValue); +pub struct OrderableScalarValue(pub ScalarValue); impl DeepSizeOf for OrderableScalarValue { fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize { @@ -846,6 +846,10 @@ impl ScalarIndex for BTreeIndex { ScalarQuery::IsIn(values) => self .page_lookup .pages_in(values.iter().map(|val| OrderableScalarValue(val.clone()))), + ScalarQuery::FullTextSearch(_) => return Err(Error::invalid_input( + "full text search is not supported for BTree index, build a inverted index for it", + location!(), + )), ScalarQuery::IsNull() => self.page_lookup.pages_null(), }; let sub_index_reader = self.store.open_index_file(BTREE_PAGES_NAME).await?; diff --git a/rust/lance-index/src/scalar/flat.rs 
b/rust/lance-index/src/scalar/flat.rs index 842773929a..4948f69973 100644 --- a/rust/lance-index/src/scalar/flat.rs +++ b/rust/lance-index/src/scalar/flat.rs @@ -248,6 +248,10 @@ impl ScalarIndex for FlatIndex { &arrow_ord::cmp::lt(self.values(), &upper.to_scalar()?)?, )?, }, + ScalarQuery::FullTextSearch(_) => return Err(Error::invalid_input( + "full text search is not supported for flat index, build a inverted index for it", + location!(), + )), }; Ok(arrow_select::filter::filter(self.ids(), &predicate)? .as_any() diff --git a/rust/lance-index/src/scalar/inverted.rs b/rust/lance-index/src/scalar/inverted.rs new file mode 100644 index 0000000000..53ce44f77a --- /dev/null +++ b/rust/lance-index/src/scalar/inverted.rs @@ -0,0 +1,553 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use std::collections::HashMap; +use std::sync::Arc; + +use arrow::array::{AsArray, ListBuilder, UInt64Builder}; +use arrow::datatypes; +use arrow_array::{ArrayRef, RecordBatch, StringArray, UInt32Array, UInt64Array}; +use arrow_schema::{DataType, Field}; +use async_trait::async_trait; +use datafusion::execution::SendableRecordBatchStream; +use deepsize::DeepSizeOf; +use futures::{StreamExt, TryStreamExt}; +use itertools::Itertools; +use lance_core::{Error, Result, ROW_ID}; +use roaring::RoaringBitmap; +use snafu::{location, Location}; + +use crate::vector::graph::OrderedFloat; +use crate::Index; + +use super::{IndexReader, IndexStore, ScalarIndex, ScalarQuery}; + +const TOKENS_FILE: &str = "tokens.lance"; +const INVERT_LIST_FILE: &str = "invert.lance"; +const DOCS_FILE: &str = "docs.lance"; + +const TOKEN_COL: &str = "_token"; +const TOKEN_ID_COL: &str = "_token_id"; +const FREQUENCY_COL: &str = "_frequency"; +const NUM_TOKEN_COL: &str = "_num_tokens"; + +#[derive(Debug, Clone, Default, DeepSizeOf)] +pub struct InvertedIndex { + tokens: TokenSet, + invert_list: InvertedList, + docs: DocSet, +} + +impl InvertedIndex { + // map tokens to 
token ids + // ignore tokens that are not in the index cause they won't contribute to the search + fn map(&self, texts: &[String]) -> Vec { + texts + .iter() + .filter_map(|text| self.tokens.get(text)) + .collect() + } + + // search the documents that contain the query + // return the row ids of the documents sorted by bm25 score + fn bm25_search(&self, token_ids: Vec) -> Vec<(u64, f32)> { + const K1: f32 = 1.2; + const B: f32 = 0.75; + + let avgdl = self.docs.average_length(); + let mut bm25 = HashMap::new(); + + token_ids + .into_iter() + .filter_map(|token| self.invert_list.retrieve(token)) + .for_each(|(row_ids, freq)| { + // TODO: this can be optimized by parallelizing the calculation + row_ids + .iter() + .zip(freq.iter()) + .for_each(|(&row_id, &freq)| { + let freq = freq as f32; + let bm25 = bm25.entry(row_id).or_insert(0.0); + *bm25 += self.idf(row_ids.len()) * freq * (K1 + 1.0) + / (freq + + K1 * (1.0 - B + + B * self.docs.num_tokens[row_id as usize] as f32 / avgdl)); + }); + }); + + bm25.into_iter() + .sorted_unstable_by_key(|r| OrderedFloat(-r.1)) + .collect_vec() + } + + #[inline] + fn idf(&self, nq: usize) -> f32 { + let num_docs = self.docs.row_ids.len() as f32; + ((num_docs - nq as f32 + 0.5) / (nq as f32 + 0.5) + 1.0).ln() + } +} + +#[async_trait] +impl Index for InvertedIndex { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_index(self: Arc) -> Arc { + self + } + + fn as_vector_index(self: Arc) -> Result> { + Err(Error::invalid_input( + "inverted index cannot be cast to vector index", + location!(), + )) + } + + fn statistics(&self) -> Result { + Ok(serde_json::json!({ + "num_tokens": self.tokens.tokens.len(), + "num_docs": self.docs.row_ids.len(), + })) + } + + fn index_type(&self) -> crate::IndexType { + crate::IndexType::Scalar + } + + async fn calculate_included_frags(&self) -> Result { + unimplemented!() + } +} + +#[async_trait] +impl ScalarIndex for InvertedIndex { + // return the row ids of the documents that contain the 
query + async fn search(&self, query: &ScalarQuery) -> Result { + let row_ids = match query { + ScalarQuery::FullTextSearch(tokens) => { + let token_ids = self.map(tokens); + self.bm25_search(token_ids) + .into_iter() + .map(|(row_id, _)| row_id) + } + query => { + return Err(Error::invalid_input( + format!("unsupported query {:?} for inverted index", query), + location!(), + )) + } + }; + + // sort the row ids (documents) by bm25 score + + Ok(UInt64Array::from_iter_values(row_ids)) + } + + async fn load(store: Arc) -> Result> + where + Self: Sized, + { + let token_reader = store.open_index_file(TOKENS_FILE).await?; + let invert_list_reader = store.open_index_file(INVERT_LIST_FILE).await?; + let docs_reader = store.open_index_file(DOCS_FILE).await?; + + let tokens = TokenSet::load(token_reader).await?; + let invert_list = InvertedList::load(invert_list_reader).await?; + let docs = DocSet::load(docs_reader).await?; + + Ok(Arc::new(Self { + tokens, + invert_list, + docs, + })) + } + + async fn remap( + &self, + _mapping: &HashMap>, + _dest_store: &dyn IndexStore, + ) -> Result<()> { + unimplemented!() + } + + async fn update( + &self, + new_data: SendableRecordBatchStream, + dest_store: &dyn IndexStore, + ) -> Result<()> { + let mut token_set = self.tokens.clone(); + let mut invert_list = self.invert_list.clone(); + let mut docs = self.docs.clone(); + let mut tokenizer = tantivy::tokenizer::TextAnalyzer::builder( + tantivy::tokenizer::SimpleTokenizer::default(), + ) + .build(); + let mut stream = new_data.peekable(); + while let Some(batch) = stream.try_next().await? 
{ + let doc_col = batch.column(0).as_string::(); + let row_id_col = batch[ROW_ID].as_primitive::(); + + for (doc, row_id) in doc_col.iter().zip(row_id_col.iter()) { + let doc = doc.unwrap(); + let row_id = row_id.unwrap(); + let mut token_stream = tokenizer.token_stream(doc); + let mut token_cnt = 0; + while let Some(token) = token_stream.next() { + let token_id = token_set.add(token.text.clone()); + invert_list.add(token_id, row_id); + token_cnt += 1; + } + docs.add(row_id, token_cnt); + } + } + + let token_set_batch = token_set.to_batch()?; + let mut token_set_writer = dest_store + .new_index_file(TOKENS_FILE, token_set_batch.schema()) + .await?; + token_set_writer.write_record_batch(token_set_batch).await?; + token_set_writer.finish().await?; + + let invert_list_batch = invert_list.to_batch()?; + let mut invert_list_writer = dest_store + .new_index_file(INVERT_LIST_FILE, invert_list_batch.schema()) + .await?; + invert_list_writer + .write_record_batch(invert_list_batch) + .await?; + invert_list_writer.finish().await?; + + let docs_batch = docs.to_batch()?; + let mut docs_writer = dest_store + .new_index_file(DOCS_FILE, docs_batch.schema()) + .await?; + docs_writer.write_record_batch(docs_batch).await?; + docs_writer.finish().await?; + + Ok(()) + } +} + +// TokenSet is a mapping from tokens to token ids +// it also records the frequency of each token +#[derive(Debug, Clone, Default, DeepSizeOf)] +struct TokenSet { + tokens: Vec, + ids: Vec, + frequencies: Vec, +} + +impl TokenSet { + fn to_batch(&self) -> Result { + let token_col = StringArray::from(self.tokens.clone()); + let token_id_col = UInt32Array::from(self.ids.clone()); + let frequency_col = UInt64Array::from(self.frequencies.clone()); + + let schema = arrow_schema::Schema::new(vec![ + arrow_schema::Field::new(TOKEN_COL, DataType::Utf8, false), + arrow_schema::Field::new(TOKEN_ID_COL, DataType::UInt32, false), + arrow_schema::Field::new(FREQUENCY_COL, DataType::UInt64, false), + ]); + + let batch = 
RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(token_col) as ArrayRef, + Arc::new(token_id_col) as ArrayRef, + Arc::new(frequency_col) as ArrayRef, + ], + )?; + Ok(batch) + } + + async fn load(reader: Arc) -> Result { + let mut tokens = Vec::new(); + let mut ids = Vec::new(); + let mut frequencies = Vec::new(); + for i in 0..reader.num_batches().await { + let batch = reader.read_record_batch(i).await?; + let token_col = batch[TOKEN_COL].as_string::(); + let token_id_col = batch[TOKEN_ID_COL].as_primitive::(); + let frequency_col = batch[FREQUENCY_COL].as_primitive::(); + + tokens.extend(token_col.iter().map(|v| v.unwrap().to_owned())); + ids.extend(token_id_col.iter().map(|v| v.unwrap())); + frequencies.extend(frequency_col.iter().map(|v| v.unwrap())); + } + + Ok(Self { + tokens, + ids, + frequencies, + }) + } + + fn add(&mut self, token: String) -> u32 { + let token_id = match self.get(&token) { + Some(token_id) => token_id, + None => self.next_id(), + }; + + // add token if it doesn't exist + if token_id == self.next_id() { + self.tokens.push(token); + self.ids.push(token_id); + self.frequencies.push(0); + } + + self.frequencies[token_id as usize] += 1; + token_id + } + + fn get(&self, token: &String) -> Option { + let pos = self.tokens.binary_search(token).ok()?; + Some(self.ids[pos]) + } + + fn next_id(&self) -> u32 { + self.ids.last().map(|id| id + 1).unwrap_or(0) + } +} + +// InvertedList is a mapping from token ids to row ids +// it's used to retrieve the documents that contain a token +#[derive(Debug, Clone, Default, DeepSizeOf)] +struct InvertedList { + tokens: Vec, + row_ids_list: Vec>, + frequencies_list: Vec>, +} + +impl InvertedList { + fn to_batch(&self) -> Result { + let token_id_col = UInt32Array::from(self.tokens.clone()); + let mut row_ids_col = + ListBuilder::with_capacity(UInt64Builder::new(), self.row_ids_list.len()); + let mut frequencies_col = + ListBuilder::with_capacity(UInt64Builder::new(), self.frequencies_list.len()); + + 
for row_ids in &self.row_ids_list { + let builder = row_ids_col.values(); + for row_id in row_ids { + builder.append_value(*row_id); + } + row_ids_col.append(true); + } + + for frequencies in &self.frequencies_list { + let builder = frequencies_col.values(); + for frequency in frequencies { + builder.append_value(*frequency); + } + frequencies_col.append(true); + } + + let schema = arrow_schema::Schema::new(vec![ + arrow_schema::Field::new(TOKEN_ID_COL, DataType::UInt32, false), + arrow_schema::Field::new( + ROW_ID, + DataType::List(Field::new_list_field(DataType::UInt64, true).into()), + false, + ), + arrow_schema::Field::new( + FREQUENCY_COL, + DataType::List(Field::new_list_field(DataType::UInt64, true).into()), + false, + ), + ]); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(token_id_col) as ArrayRef, + Arc::new(row_ids_col.finish()) as ArrayRef, + Arc::new(frequencies_col.finish()) as ArrayRef, + ], + )?; + Ok(batch) + } + + async fn load(reader: Arc) -> Result { + let mut tokens = Vec::new(); + let mut row_ids_list = Vec::new(); + let mut frequencies_list = Vec::new(); + for i in 0..reader.num_batches().await { + let batch = reader.read_record_batch(i).await?; + let token_col = batch[TOKEN_ID_COL].as_primitive::(); + let row_ids_col = batch[ROW_ID].as_list::(); + let frequencies_col = batch[FREQUENCY_COL].as_list::(); + + tokens.extend(token_col.iter().map(|v| v.unwrap())); + for value in row_ids_col.iter() { + let value = value.unwrap(); + let row_ids = value + .as_primitive::() + .values() + .iter() + .cloned() + .collect_vec(); + row_ids_list.push(row_ids); + } + for value in frequencies_col.iter() { + let value = value.unwrap(); + let frequencies = value + .as_primitive::() + .values() + .iter() + .cloned() + .collect_vec(); + frequencies_list.push(frequencies); + } + } + + Ok(Self { + tokens, + row_ids_list, + frequencies_list, + }) + } + + fn add(&mut self, token_id: u32, row_id: u64) { + let pos = match 
self.tokens.binary_search(&token_id) { + Ok(pos) => pos, + Err(pos) => { + self.tokens.insert(pos, token_id); + self.row_ids_list.insert(pos, Vec::new()); + self.frequencies_list.insert(pos, Vec::new()); + pos + } + }; + + self.row_ids_list[pos].push(row_id); + self.frequencies_list[pos].push(1); + } + + fn retrieve(&self, token_id: u32) -> Option<(&[u64], &[u64])> { + let pos = self.tokens.binary_search(&token_id).ok()?; + Some((&self.row_ids_list[pos], &self.frequencies_list[pos])) + } +} + +// DocSet is a mapping from row ids to the number of tokens in the document +// It's used to sort the documents by the bm25 score +#[derive(Debug, Clone, Default, DeepSizeOf)] +struct DocSet { + row_ids: Vec, + num_tokens: Vec, + total_tokens: u64, +} + +impl DocSet { + fn average_length(&self) -> f32 { + self.total_tokens as f32 / self.row_ids.len() as f32 + } + + fn to_batch(&self) -> Result { + let row_id_col = UInt64Array::from(self.row_ids.clone()); + let num_tokens_col = UInt32Array::from(self.num_tokens.clone()); + + let schema = arrow_schema::Schema::new(vec![ + arrow_schema::Field::new(ROW_ID, DataType::UInt64, false), + arrow_schema::Field::new(NUM_TOKEN_COL, DataType::UInt32, false), + ]); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(row_id_col) as ArrayRef, + Arc::new(num_tokens_col) as ArrayRef, + ], + )?; + Ok(batch) + } + + async fn load(reader: Arc) -> Result { + let mut row_ids = Vec::new(); + let mut num_tokens = Vec::new(); + let mut total_tokens = 0; + for i in 0..reader.num_batches().await { + let batch = reader.read_record_batch(i).await?; + let row_id_col = batch[ROW_ID].as_primitive::(); + let num_tokens_col = batch[NUM_TOKEN_COL].as_primitive::(); + + row_ids.extend(row_id_col.iter().map(|v| v.unwrap())); + num_tokens.extend(num_tokens_col.iter().map(|v| v.unwrap())); + total_tokens += num_tokens.iter().map(|v| *v as u64).sum::(); + } + + Ok(Self { + row_ids, + num_tokens, + total_tokens, + }) + } + + fn add(&mut self, 
row_id: u64, num_tokens: u32) { + self.row_ids.push(row_id); + self.num_tokens.push(num_tokens); + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow_array::{ArrayRef, RecordBatch, StringArray, UInt64Array}; + use datafusion::physical_plan::stream::RecordBatchStreamAdapter; + use futures::stream; + use lance_io::object_store::ObjectStore; + use object_store::path::Path; + + use crate::scalar::lance_format::LanceIndexStore; + use crate::scalar::ScalarIndex; + + #[tokio::test] + async fn test_inverted_index() { + let tempdir = tempfile::tempdir().unwrap(); + let index_dir = Path::from_filesystem_path(tempdir.path()).unwrap(); + let store = LanceIndexStore::new(ObjectStore::local(), index_dir, None); + + let invert_index = super::InvertedIndex::default(); + let row_id_col = UInt64Array::from(vec![0, 1, 2, 3]); + let doc_col = StringArray::from(vec!["a b c", "a b", "a c", "b c"]); + let batch = RecordBatch::try_new( + arrow_schema::Schema::new(vec![ + arrow_schema::Field::new("doc", arrow_schema::DataType::Utf8, false), + arrow_schema::Field::new(super::ROW_ID, arrow_schema::DataType::UInt64, false), + ]) + .into(), + vec![ + Arc::new(doc_col) as ArrayRef, + Arc::new(row_id_col) as ArrayRef, + ], + ) + .unwrap(); + let stream = RecordBatchStreamAdapter::new(batch.schema(), stream::iter(vec![Ok(batch)])); + let stream = Box::pin(stream); + + invert_index + .update(stream, &store) + .await + .expect("failed to update invert index"); + + let invert_index = super::InvertedIndex::load(Arc::new(store)).await.unwrap(); + let row_ids = invert_index + .search(&super::ScalarQuery::FullTextSearch(vec!["a".to_string()])) + .await + .unwrap(); + assert_eq!(row_ids.len(), 3); + assert!(row_ids.values().contains(&0)); + assert!(row_ids.values().contains(&1)); + assert!(row_ids.values().contains(&2)); + + let row_ids = invert_index + .search(&super::ScalarQuery::FullTextSearch(vec!["b".to_string()])) + .await + .unwrap(); + assert_eq!(row_ids.len(), 3); + 
assert!(row_ids.values().contains(&0)); + assert!(row_ids.values().contains(&1)); + assert!(row_ids.values().contains(&3)); + } +} diff --git a/rust/lance-index/src/scalar/lance_format.rs b/rust/lance-index/src/scalar/lance_format.rs index 5775df9be0..4aa222484f 100644 --- a/rust/lance-index/src/scalar/lance_format.rs +++ b/rust/lance-index/src/scalar/lance_format.rs @@ -155,9 +155,10 @@ impl IndexStore for LanceIndexStore { #[cfg(test)] mod tests { - use std::{ops::Bound, path::Path}; + use std::{collections::HashMap, ops::Bound, path::Path}; use crate::scalar::{ + bitmap::{train_bitmap_index, BitmapIndex}, btree::{train_btree_index, BTreeIndex, BtreeTrainingSource}, flat::FlatIndexMetadata, ScalarIndex, ScalarQuery, @@ -167,8 +168,9 @@ mod tests { use arrow_array::{ cast::AsArray, types::{Float32Type, Int32Type, UInt64Type}, - RecordBatchIterator, RecordBatchReader, UInt64Array, + RecordBatchIterator, RecordBatchReader, StringArray, UInt64Array, }; + use arrow_schema::Schema as ArrowSchema; use arrow_schema::{DataType, Field, TimeUnit}; use arrow_select::take::TakeOptions; use datafusion::physical_plan::SendableRecordBatchStream; @@ -654,4 +656,427 @@ mod tests { let row_ids = index.search(&ScalarQuery::IsNull()).await.unwrap(); assert_eq!(row_ids.len(), 4096); } + + async fn train_bitmap( + index_store: &Arc, + data: impl RecordBatchReader + Send + Sync + 'static, + ) { + let data = Box::new(MockTrainingSource::new(data).await); + train_bitmap_index(data, index_store.as_ref()) + .await + .unwrap(); + } + + #[tokio::test] + async fn test_bitmap_working() { + let tempdir = tempdir().unwrap(); + let index_store = test_store(&tempdir); + + let schema = Arc::new(ArrowSchema::new(vec![ + Field::new("values", DataType::Utf8, true), + Field::new("row_ids", DataType::UInt64, false), + ])); + + let batch1 = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(StringArray::from(vec![Some("abcd"), None, Some("abcd")])), + Arc::new(UInt64Array::from(vec![1, 2, 3])), + 
], + ) + .unwrap(); + + let batch2 = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(StringArray::from(vec![ + Some("apple"), + Some("hello"), + Some("abcd"), + ])), + Arc::new(UInt64Array::from(vec![4, 5, 6])), + ], + ) + .unwrap(); + + let batches = vec![batch1, batch2]; + let data = RecordBatchIterator::new(batches.into_iter().map(Ok), schema); + train_bitmap(&index_store, data).await; + + let index = BitmapIndex::load(index_store).await.unwrap(); + + let row_ids = index + .search(&ScalarQuery::Equals(ScalarValue::Utf8(None))) + .await + .unwrap(); + + assert_eq!(1, row_ids.len()); + assert_eq!(2, row_ids.values()[0]); + + let row_ids = index + .search(&ScalarQuery::Equals(ScalarValue::Utf8(Some( + "abcd".to_string(), + )))) + .await + .unwrap(); + + let expected = vec![1, 3, 6]; + let expected_arr = UInt64Array::from_iter_values(expected.into_iter()); + + assert_eq!(3, row_ids.len()); + assert_eq!(expected_arr, row_ids); + } + + #[tokio::test] + async fn test_basic_bitmap() { + let tempdir = tempdir().unwrap(); + let index_store = test_store(&tempdir); + let data = gen() + .col("values", array::step::()) + .col("row_ids", array::step::()) + .into_reader_rows(RowCount::from(4096), BatchCount::from(100)); + train_bitmap(&index_store, data).await; + let index = BitmapIndex::load(index_store).await.unwrap(); + + let row_ids = index + .search(&ScalarQuery::Equals(ScalarValue::Int32(Some(10000)))) + .await + .unwrap(); + + assert_eq!(1, row_ids.len()); + assert_eq!(10000, row_ids.values()[0]); + + let row_ids = index + .search(&ScalarQuery::Range( + Bound::Unbounded, + Bound::Excluded(ScalarValue::Int32(Some(-100))), + )) + .await + .unwrap(); + + assert_eq!(0, row_ids.len()); + + let row_ids = index + .search(&ScalarQuery::Range( + Bound::Unbounded, + Bound::Excluded(ScalarValue::Int32(Some(100))), + )) + .await + .unwrap(); + + assert_eq!(100, row_ids.len()); + } + + async fn check_bitmap(index: &BitmapIndex, query: ScalarQuery, expected: &[u64]) { + let 
results = index.search(&query).await.unwrap(); + let expected_arr = UInt64Array::from_iter_values(expected.iter().copied()); + assert_eq!(results, expected_arr); + } + + #[tokio::test] + async fn test_bitmap_with_gaps() { + let tempdir = tempdir().unwrap(); + let index_store = test_store(&tempdir); + let batch_one = gen() + .col("values", array::cycle::(vec![0, 1, 4, 5])) + .col("row_ids", array::cycle::(vec![0, 1, 2, 3])) + .into_batch_rows(RowCount::from(4)); + let batch_two = gen() + .col("values", array::cycle::(vec![10, 11, 11, 15])) + .col("row_ids", array::cycle::(vec![40, 50, 60, 70])) + .into_batch_rows(RowCount::from(4)); + let batch_three = gen() + .col("values", array::cycle::(vec![15, 15, 15, 15])) + .col( + "row_ids", + array::cycle::(vec![400, 500, 600, 700]), + ) + .into_batch_rows(RowCount::from(4)); + let batch_four = gen() + .col("values", array::cycle::(vec![15, 16, 20, 20])) + .col( + "row_ids", + array::cycle::(vec![4000, 5000, 6000, 7000]), + ) + .into_batch_rows(RowCount::from(4)); + let batches = vec![batch_one, batch_two, batch_three, batch_four]; + let schema = Arc::new(Schema::new(vec![ + Field::new("values", DataType::Int32, false), + Field::new("row_ids", DataType::UInt64, false), + ])); + let data = RecordBatchIterator::new(batches, schema); + train_bitmap(&index_store, data).await; + let index = BitmapIndex::load(index_store).await.unwrap(); + + // The above should create four pages + // + // 0 - 5 + // 10 - 15 + // 15 - 15 + // 15 - 20 + // + // This will help us test various indexing corner cases + + // No results (off the left side) + check_bitmap( + &index, + ScalarQuery::Equals(ScalarValue::Int32(Some(-3))), + &[], + ) + .await; + + check_bitmap( + &index, + ScalarQuery::Range( + Bound::Unbounded, + Bound::Included(ScalarValue::Int32(Some(-3))), + ), + &[], + ) + .await; + + check_bitmap( + &index, + ScalarQuery::Range( + Bound::Included(ScalarValue::Int32(Some(-10))), + Bound::Included(ScalarValue::Int32(Some(-3))), + ), + &[], 
+ ) + .await; + + // Hitting the middle of a bucket + check_bitmap( + &index, + ScalarQuery::Equals(ScalarValue::Int32(Some(4))), + &[2], + ) + .await; + + // Hitting a gap between two buckets + check_bitmap( + &index, + ScalarQuery::Equals(ScalarValue::Int32(Some(7))), + &[], + ) + .await; + + // Hitting the lowest of the overlapping buckets + check_bitmap( + &index, + ScalarQuery::Equals(ScalarValue::Int32(Some(11))), + &[50, 60], + ) + .await; + + // Hitting the 15 shared on all three buckets + check_bitmap( + &index, + ScalarQuery::Equals(ScalarValue::Int32(Some(15))), + &[70, 400, 500, 600, 700, 4000], + ) + .await; + + // Hitting the upper part of the three overlapping buckets + check_bitmap( + &index, + ScalarQuery::Equals(ScalarValue::Int32(Some(20))), + &[6000, 7000], + ) + .await; + + // Ranges that capture multiple buckets + check_bitmap( + &index, + ScalarQuery::Range( + Bound::Unbounded, + Bound::Included(ScalarValue::Int32(Some(11))), + ), + &[0, 1, 2, 3, 40, 50, 60], + ) + .await; + + check_bitmap( + &index, + ScalarQuery::Range( + Bound::Unbounded, + Bound::Excluded(ScalarValue::Int32(Some(11))), + ), + &[0, 1, 2, 3, 40], + ) + .await; + + check_bitmap( + &index, + ScalarQuery::Range( + Bound::Included(ScalarValue::Int32(Some(4))), + Bound::Unbounded, + ), + &[ + 2, 3, 40, 50, 60, 70, 400, 500, 600, 700, 4000, 5000, 6000, 7000, + ], + ) + .await; + + check_bitmap( + &index, + ScalarQuery::Range( + Bound::Included(ScalarValue::Int32(Some(4))), + Bound::Included(ScalarValue::Int32(Some(11))), + ), + &[2, 3, 40, 50, 60], + ) + .await; + + check_bitmap( + &index, + ScalarQuery::Range( + Bound::Included(ScalarValue::Int32(Some(4))), + Bound::Excluded(ScalarValue::Int32(Some(11))), + ), + &[2, 3, 40], + ) + .await; + + check_bitmap( + &index, + ScalarQuery::Range( + Bound::Excluded(ScalarValue::Int32(Some(4))), + Bound::Unbounded, + ), + &[ + 3, 40, 50, 60, 70, 400, 500, 600, 700, 4000, 5000, 6000, 7000, + ], + ) + .await; + + check_bitmap( + &index, + 
ScalarQuery::Range( + Bound::Excluded(ScalarValue::Int32(Some(4))), + Bound::Included(ScalarValue::Int32(Some(11))), + ), + &[3, 40, 50, 60], + ) + .await; + + check_bitmap( + &index, + ScalarQuery::Range( + Bound::Excluded(ScalarValue::Int32(Some(4))), + Bound::Excluded(ScalarValue::Int32(Some(11))), + ), + &[3, 40], + ) + .await; + + check_bitmap( + &index, + ScalarQuery::Range( + Bound::Excluded(ScalarValue::Int32(Some(-50))), + Bound::Excluded(ScalarValue::Int32(Some(1000))), + ), + &[ + 0, 1, 2, 3, 40, 50, 60, 70, 400, 500, 600, 700, 4000, 5000, 6000, 7000, + ], + ) + .await; + } + + #[tokio::test] + async fn test_bitmap_update() { + let index_dir = tempdir().unwrap(); + let index_store = test_store(&index_dir); + let data = gen() + .col("values", array::step::()) + .col("row_ids", array::step::()) + .into_reader_rows(RowCount::from(4096), BatchCount::from(1)); + train_bitmap(&index_store, data).await; + let index = BitmapIndex::load(index_store).await.unwrap(); + + let data = gen() + .col("values", array::step_custom::(4096, 1)) + .col("row_ids", array::step_custom::(4096, 1)) + .into_reader_rows(RowCount::from(4096), BatchCount::from(1)); + + let updated_index_dir = tempdir().unwrap(); + let updated_index_store = test_store(&updated_index_dir); + index + .update( + lance_datafusion::utils::reader_to_stream(Box::new(data)), + updated_index_store.as_ref(), + ) + .await + .unwrap(); + let updated_index = BitmapIndex::load(updated_index_store).await.unwrap(); + + let row_ids = updated_index + .search(&ScalarQuery::Equals(ScalarValue::Int32(Some(5000)))) + .await + .unwrap(); + + assert_eq!(1, row_ids.len()); + assert_eq!( + vec![5000], + row_ids.values().into_iter().copied().collect::>() + ); + } + + #[tokio::test] + async fn test_bitmap_remap() { + let index_dir = tempdir().unwrap(); + let index_store = test_store(&index_dir); + let data = gen() + .col("values", array::step::()) + .col("row_ids", array::step::()) + .into_reader_rows(RowCount::from(50), 
BatchCount::from(1)); + train_bitmap(&index_store, data).await; + let index = BitmapIndex::load(index_store).await.unwrap(); + + let mapping = (0..50) + .map(|i| { + let map_result = if i == 5 { + Some(65) + } else if i == 7 { + None + } else { + Some(i) + }; + (i, map_result) + }) + .collect::>(); + + let remapped_dir = tempdir().unwrap(); + let remapped_store = test_store(&remapped_dir); + index + .remap(&mapping, remapped_store.as_ref()) + .await + .unwrap(); + let remapped_index = BitmapIndex::load(remapped_store).await.unwrap(); + + // Remapped to new value + assert_eq!( + remapped_index + .search(&ScalarQuery::Equals(ScalarValue::Int32(Some(5)))) + .await + .unwrap() + .value(0), + 65 + ); + // Deleted + assert!(remapped_index + .search(&ScalarQuery::Equals(ScalarValue::Int32(Some(7)))) + .await + .unwrap() + .is_empty()); + // Not remapped + assert_eq!( + remapped_index + .search(&ScalarQuery::Equals(ScalarValue::Int32(Some(3)))) + .await + .unwrap() + .value(0), + 3 + ); + } } diff --git a/rust/lance-index/src/vector/ivf.rs b/rust/lance-index/src/vector/ivf.rs index 87fcdf164d..fe5b323e5e 100644 --- a/rust/lance-index/src/vector/ivf.rs +++ b/rust/lance-index/src/vector/ivf.rs @@ -150,7 +150,7 @@ impl IvfTransformer { centroids: FixedSizeListArray, distance_type: DistanceType, vector_column: &str, - pq: Arc, + pq: ProductQuantizer, range: Option>, ) -> Self { let mut transforms: Vec> = vec![]; diff --git a/rust/lance-index/src/vector/ivf/builder.rs b/rust/lance-index/src/vector/ivf/builder.rs index 89ba50c3d1..79d48eec17 100644 --- a/rust/lance-index/src/vector/ivf/builder.rs +++ b/rust/lance-index/src/vector/ivf/builder.rs @@ -47,6 +47,9 @@ pub struct IvfBuildParams { /// Use residual vectors to build sub-vector. pub use_residual: bool, + + /// Storage options used to load precomputed partitions. 
+ pub storage_options: Option>, } impl Default for IvfBuildParams { @@ -61,6 +64,7 @@ impl Default for IvfBuildParams { shuffle_partition_batches: 1024 * 10, shuffle_partition_concurrency: 2, use_residual: true, + storage_options: None, } } } diff --git a/rust/lance-index/src/vector/ivf/shuffler.rs b/rust/lance-index/src/vector/ivf/shuffler.rs index e76387b359..545e08ed2b 100644 --- a/rust/lance-index/src/vector/ivf/shuffler.rs +++ b/rust/lance-index/src/vector/ivf/shuffler.rs @@ -88,8 +88,9 @@ pub async fn shuffle_dataset( shuffler } else { info!( - "Calculating IVF partitions for vectors (num_partitions={})", - num_partitions + "Calculating IVF partitions for vectors (num_partitions={}, precomputed_partitions={})", + num_partitions, + precomputed_partitions.is_some() ); let mut shuffler = IvfShuffler::try_new(num_partitions, None)?; diff --git a/rust/lance-index/src/vector/pq.rs b/rust/lance-index/src/vector/pq.rs index 9eee17c681..9f8938c029 100644 --- a/rust/lance-index/src/vector/pq.rs +++ b/rust/lance-index/src/vector/pq.rs @@ -4,20 +4,22 @@ //! Product Quantization //! 
-use std::any::Any; use std::sync::Arc; +use arrow::datatypes::{self, ArrowPrimitiveType}; use arrow_array::{cast::AsArray, Array, FixedSizeListArray, UInt8Array}; -use arrow_array::{ArrayRef, Float32Array}; +use arrow_array::{ArrayRef, Float32Array, PrimitiveArray}; +use arrow_schema::DataType; use deepsize::DeepSizeOf; use lance_arrow::*; use lance_core::{Error, Result}; -use lance_linalg::distance::{dot_distance_batch, l2_distance_batch, DistanceType, Dot, L2}; -use lance_linalg::kernels::argmin_value_float; +use lance_linalg::distance::{dot_distance_batch, DistanceType, Dot, L2}; use lance_linalg::kmeans::compute_partition; -use lance_linalg::{distance::MetricType, MatrixView}; +use num_traits::Float; +use prost::Message; use rayon::prelude::*; use snafu::{location, Location}; +use storage::{ProductQuantizationMetadata, ProductQuantizationStorage, PQ_METADTA_KEY}; pub mod builder; mod distance; @@ -27,240 +29,132 @@ pub(crate) mod utils; use self::distance::{build_distance_table_l2, compute_l2_distance}; pub use self::utils::num_centroids; -use super::pb; +use super::quantizer::{Quantization, QuantizationMetadata, QuantizationType, Quantizer}; +use super::{pb, PQ_CODE_COLUMN}; pub use builder::PQBuildParams; use utils::get_sub_vector_centroids; -/// Product Quantization -pub trait ProductQuantizer: Send + Sync + DeepSizeOf + std::fmt::Debug { - fn as_any(&self) -> &dyn Any; - - /// Compute the distance between query vector to the PQ code. - /// - fn compute_distances(&self, query: &dyn Array, code: &UInt8Array) -> Result; - - fn transform(&self, data: &dyn Array) -> Result; - - /// Number of sub-vectors - fn num_sub_vectors(&self) -> usize; - - fn num_bits(&self) -> u32; - - fn dimension(&self) -> usize; - - // TODO: move to pub(crate) once the refactor of lance::index to lance-index is done. - fn codebook_as_fsl(&self) -> FixedSizeListArray; - - /// Whether to use residual as input or not. 
- fn use_residual(&self) -> bool; -} - -/// Product Quantization, optimized for [Apache Arrow] buffer memory layout. -/// -// -// TODO: move this to be pub(crate) once we have a better way to test it. #[derive(Debug, Clone)] -pub struct ProductQuantizerImpl -where - T::Native: Dot + L2, -{ - /// Number of bits for the centroids. - /// - /// Only support 8, as one of `u8` byte now. - pub num_bits: u32, - - /// Number of sub-vectors. +pub struct ProductQuantizer { pub num_sub_vectors: usize, - - /// Vector dimension. + pub num_bits: u32, pub dimension: usize, - - /// Distance type. - pub metric_type: MetricType, - - /// PQ codebook - /// - /// ```((2 ^ nbits) * num_subvector * sub_vector_length)``` of `f32` - /// - /// Use a layout that is cache / SIMD friendly to compute centroid. - /// But not sure how to make distance lookup via PQ code lookup - /// be cache friendly tho. - /// - /// Layout: - /// - /// - *row*: all centroids for the same sub-vector. - /// - *column*: the centroid value of the n-th sub-vector. - /// - /// ```text - /// // Centroids for a sub-vector. - /// Codebook[sub_vector_id][pq_code] - /// ``` - pub codebook: Arc, + pub codebook: FixedSizeListArray, + pub distance_type: DistanceType, } -impl DeepSizeOf for ProductQuantizerImpl -where - T::Native: Dot + L2, -{ +impl DeepSizeOf for ProductQuantizer { fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize { self.codebook.get_array_memory_size() + + self.num_sub_vectors.deep_size_of_children(_context) + + self.num_bits.deep_size_of_children(_context) + + self.dimension.deep_size_of_children(_context) + + self.distance_type.deep_size_of_children(_context) } } -impl ProductQuantizerImpl -where - T::Native: Dot + L2, -{ - /// Create a [`ProductQuantizer`] with pre-trained codebook. 
+impl ProductQuantizer { pub fn new( - m: usize, - nbits: u32, + num_sub_vectors: usize, + num_bits: u32, dimension: usize, - codebook: Arc, - metric_type: MetricType, + codebook: FixedSizeListArray, + distance_type: DistanceType, ) -> Self { - assert_ne!( - metric_type, - MetricType::Cosine, - "Product quantization does not support cosine, use normalized L2 instead" - ); - assert_eq!(nbits, 8, "nbits can only be 8"); Self { - num_bits: nbits, - num_sub_vectors: m, + num_bits, + num_sub_vectors, dimension, codebook, - metric_type, + distance_type, } } - pub fn num_centroids(num_bits: u32) -> usize { - 2_usize.pow(num_bits) + pub fn from_proto(proto: &pb::Pq, distance_type: DistanceType) -> Result { + let codebook = match proto.codebook_tensor.as_ref() { + Some(tensor) => FixedSizeListArray::try_from(tensor)?, + None => FixedSizeListArray::try_new_from_values( + Float32Array::from(proto.codebook.clone()), + proto.dimension as i32, + )?, + }; + Ok(Self { + num_bits: proto.num_bits, + num_sub_vectors: proto.num_sub_vectors as usize, + dimension: proto.dimension as usize, + codebook, + distance_type, + }) } - /// Calculate codebook length. - pub fn codebook_length(num_bits: u32, num_sub_vectors: usize) -> usize { - Self::num_centroids(num_bits) * num_sub_vectors + pub fn use_residual(&self) -> bool { + matches!(self.distance_type, DistanceType::L2 | DistanceType::Cosine) } - /// Get the centroids for one sub-vector. - /// - /// Returns a flatten `num_centroids * sub_vector_width` f32 array. 
- pub fn centroids(&self, sub_vector_idx: usize) -> &[T::Native] { - get_sub_vector_centroids( - self.codebook.as_slice(), - self.dimension, - self.num_bits, - self.num_sub_vectors, - sub_vector_idx, - ) - } + fn transform(&self, vectors: &dyn Array) -> Result + where + T::Native: Float + L2 + Dot, + { + let fsl = vectors.as_fixed_size_list_opt().ok_or(Error::Index { + message: format!( + "Expect to be a FixedSizeList vector array, got: {:?} array", + vectors.data_type() + ), + location: location!(), + })?; + let num_sub_vectors = self.num_sub_vectors; + let dim = self.dimension; + let num_bits = self.num_bits; + let codebook = self.codebook.values().as_primitive::(); - /// Reconstruct a vector from its PQ code. - /// - /// It only supports U8 PQ code for now. - #[allow(dead_code)] - pub(crate) fn reconstruct(&self, code: &[u8]) -> Arc { - assert_eq!(code.len(), self.num_sub_vectors); - let mut builder = Vec::with_capacity(self.dimension); - let sub_vector_dim = self.dimension / self.num_sub_vectors; - for (i, sub_code) in code.iter().enumerate() { - let centroids = self.centroids(i); - builder.extend_from_slice( - ¢roids[*sub_code as usize * sub_vector_dim - ..(*sub_code as usize + 1) * sub_vector_dim], - ); - } - Arc::new(T::ArrayType::from(builder)) - } + let distance_type = self.distance_type; - /// Compute the quantization distortion (E). - /// - /// Quantization distortion is the difference between the centroids - /// from the PQ code to the actual vector. - /// - /// This method is just for debugging purpose. 
- #[allow(dead_code)] - pub(crate) fn distortion( - &self, - data: &MatrixView, - distance_type: DistanceType, - ) -> Result { - let sub_vector_width = self.dimension / self.num_sub_vectors; - let total_distortion = data - .iter() + let flatten_data = fsl.values().as_primitive::(); + let sub_dim = dim / num_sub_vectors; + let values = flatten_data + .values() + .par_chunks(dim) .map(|vector| { vector - .chunks_exact(sub_vector_width) + .chunks_exact(sub_dim) .enumerate() - .map(|(sub_vector_idx, sub_vec)| { - let centroids = self.centroids(sub_vector_idx); - let distances = match distance_type { - DistanceType::L2 => { - l2_distance_batch(sub_vec, centroids, sub_vector_width) - } - DistanceType::Dot => { - dot_distance_batch(sub_vec, centroids, sub_vector_width) - } - _ => { - panic!( - "ProductQuantization: distance type {} is not supported", - distance_type - ); - } - }; - argmin_value_float(distances).map(|(_, v)| v).unwrap_or(0.0) + .map(|(sub_idx, sub_vector)| { + let centroids = get_sub_vector_centroids( + codebook.values(), + dim, + num_bits, + num_sub_vectors, + sub_idx, + ); + compute_partition(centroids, sub_vector, distance_type).map(|v| v as u8) }) - .sum::() as f64 + .collect::>() }) - .sum::(); - Ok(total_distortion / data.num_rows() as f64) - } + .flatten() + .collect::>(); - fn build_l2_distance_table(&self, key: &dyn Array) -> Result> { - let key: &T::ArrayType = key.as_any().downcast_ref().ok_or(Error::Index { - message: format!( - "Build L2 distance table, type mismatch: {}", - key.data_type() - ), - location: Default::default(), - })?; - Ok(build_distance_table_l2( - self.codebook.as_slice(), - self.num_bits, - self.num_sub_vectors, - key.as_slice(), - )) + Ok(Arc::new(FixedSizeListArray::try_new_from_values( + UInt8Array::from(values), + self.num_sub_vectors as i32, + )?)) } - /// Compute L2 distance from the query to all code. - /// - /// Type parameters - /// --------------- - /// - C: the tile size of code-book to run at once. 
- /// - V: the tile size of PQ code to run at once. - /// - /// Parameters - /// ---------- - /// - distance_table: the pre-computed L2 distance table. - /// It is a flatten array of [num_sub_vectors, num_centroids] f32. - /// - code: the PQ code to be used to compute the distances. - /// - /// Returns - /// ------- - /// The squared L2 distance. - #[inline] - fn compute_l2_distance( - &self, - distance_table: &[f32], - code: &[u8], - ) -> Float32Array { - Float32Array::from(compute_l2_distance::( - distance_table, - self.num_bits, - self.num_sub_vectors, - code, - )) + pub fn compute_distances(&self, query: &dyn Array, code: &UInt8Array) -> Result { + match self.distance_type { + DistanceType::L2 => self.l2_distances(query, code), + DistanceType::Cosine => { + // L2 over normalized vectors: ||x - y|| = x^2 + y^2 - 2 * xy = 1 + 1 - 2 * xy = 2 * (1 - xy) + // Cosine distance: 1 - |xy| / (||x|| * ||y||) = 1 - xy / (x^2 * y^2) = 1 - xy / (1 * 1) = 1 - xy + // Therefore, Cosine = L2 / 2 + let l2_dists = self.l2_distances(query, code)?; + Ok(l2_dists.values().iter().map(|v| *v / 2.0).collect()) + } + DistanceType::Dot => self.dot_distances(query, code), + _ => panic!( + "ProductQuantization: distance type {} not supported", + self.distance_type + ), + } } /// Pre-compute L2 distance from the query to all code. @@ -285,24 +179,40 @@ where /// - code: the PQ code in one partition. 
/// fn dot_distances(&self, key: &dyn Array, code: &UInt8Array) -> Result { - let key: &T::ArrayType = key.as_any().downcast_ref().ok_or(Error::Index { - message: format!( - "Build Dot distance table, type mismatch: {}", - key.data_type() - ), - location: Default::default(), - })?; + match key.data_type() { + DataType::Float16 => { + self.dot_distances_impl::(key.as_primitive(), code) + } + DataType::Float32 => { + self.dot_distances_impl::(key.as_primitive(), code) + } + DataType::Float64 => { + self.dot_distances_impl::(key.as_primitive(), code) + } + _ => Err(Error::Index { + message: format!("unsupported data type: {}", key.data_type()), + location: location!(), + }), + } + } - // Distance table: `[f32: num_sub_vectors(row) * num_centroids(column)]`. + fn dot_distances_impl( + &self, + key: &PrimitiveArray, + code: &UInt8Array, + ) -> Result + where + T::Native: Dot, + { let capacity = self.num_sub_vectors * num_centroids(self.num_bits); let mut distance_table = Vec::with_capacity(capacity); let sub_vector_length = self.dimension / self.num_sub_vectors; - key.as_slice() + key.values() .chunks_exact(sub_vector_length) .enumerate() .for_each(|(sub_vec_id, sub_vec)| { - let subvec_centroids = self.centroids(sub_vec_id); + let subvec_centroids = self.centroids::(sub_vec_id); let distances = dot_distance_batch(sub_vec, subvec_centroids, sub_vector_length); distance_table.extend(distances); }); @@ -319,130 +229,169 @@ where }), )) } -} - -impl ProductQuantizer for ProductQuantizerImpl -where - T::Native: Dot + L2, -{ - fn as_any(&self) -> &dyn Any { - self - } - fn transform(&self, data: &dyn Array) -> Result { - let fsl = data - .as_fixed_size_list_opt() - .ok_or(Error::Index { - message: format!( - "Expect to be a FixedSizeList vector array, got: {:?} array", - data.data_type() - ), + fn build_l2_distance_table(&self, key: &dyn Array) -> Result> { + match key.data_type() { + DataType::Float16 => { + Ok(self.build_l2_distance_table_impl::(key.as_primitive())) + } + 
DataType::Float32 => { + Ok(self.build_l2_distance_table_impl::(key.as_primitive())) + } + DataType::Float64 => { + Ok(self.build_l2_distance_table_impl::(key.as_primitive())) + } + _ => Err(Error::Index { + message: format!("unsupported data type: {}", key.data_type()), location: location!(), - })? - .clone(); + }), + } + } - let num_sub_vectors = self.num_sub_vectors; - let dim = self.dimension; - let num_bits = self.num_bits; - let codebook = self.codebook.clone(); - - let distance_type = self.metric_type; - - let flatten_data = - fsl.values() - .as_any() - .downcast_ref::() - .ok_or(Error::Index { - message: format!( - "Expect to be a float vector array, got: {:?}", - fsl.value_type() - ), - location: location!(), - })?; + fn build_l2_distance_table_impl( + &self, + key: &PrimitiveArray, + ) -> Vec + where + T::Native: L2, + { + build_distance_table_l2( + self.codebook.values().as_primitive::().values(), + self.num_bits, + self.num_sub_vectors, + key.values(), + ) + } - let sub_dim = dim / num_sub_vectors; - let values = flatten_data - .as_slice() - .par_chunks(dim) - .map(|vector| { - vector - .chunks_exact(sub_dim) - .enumerate() - .map(|(sub_idx, sub_vector)| { - let centroids = get_sub_vector_centroids( - codebook.as_slice(), - dim, - num_bits, - num_sub_vectors, - sub_idx, - ); - compute_partition(centroids, sub_vector, distance_type).map(|v| v as u8) - }) - .collect::>() - }) - .flatten() - .collect::>(); + /// Compute L2 distance from the query to all code. + /// + /// Type parameters + /// --------------- + /// - C: the tile size of code-book to run at once. + /// - V: the tile size of PQ code to run at once. + /// + /// Parameters + /// ---------- + /// - distance_table: the pre-computed L2 distance table. + /// It is a flatten array of [num_sub_vectors, num_centroids] f32. + /// - code: the PQ code to be used to compute the distances. + /// + /// Returns + /// ------- + /// The squared L2 distance. 
+ #[inline] + fn compute_l2_distance( + &self, + distance_table: &[f32], + code: &[u8], + ) -> Float32Array { + Float32Array::from(compute_l2_distance::( + distance_table, + self.num_bits, + self.num_sub_vectors, + code, + )) + } - Ok(Arc::new(FixedSizeListArray::try_new_from_values( - UInt8Array::from(values), - self.num_sub_vectors as i32, - )?)) + /// Get the centroids for one sub-vector. + /// + /// Returns a flatten `num_centroids * sub_vector_width` f32 array. + pub fn centroids(&self, sub_vector_idx: usize) -> &[T::Native] { + get_sub_vector_centroids( + self.codebook.values().as_primitive::().values(), + self.dimension, + self.num_bits, + self.num_sub_vectors, + sub_vector_idx, + ) } +} - fn compute_distances(&self, query: &dyn Array, code: &UInt8Array) -> Result { - match self.metric_type { - DistanceType::L2 => self.l2_distances(query, code), - DistanceType::Cosine => { - // L2 over normalized vectors: ||x - y|| = x^2 + y^2 - 2 * xy = 1 + 1 - 2 * xy = 2 * (1 - xy) - // Cosine distance: 1 - |xy| / (||x|| * ||y||) = 1 - xy / (x^2 * y^2) = 1 - xy / (1 * 1) = 1 - xy - // Therefore, Cosine = L2 / 2 - let l2_dists = self.l2_distances(query, code)?; - Ok(l2_dists.values().iter().map(|v| *v / 2.0).collect()) - } - DistanceType::Dot => self.dot_distances(query, code), - _ => panic!( - "ProductQuantization: metric type {} not supported", - self.metric_type - ), - } +impl Quantization for ProductQuantizer { + type BuildParams = PQBuildParams; + type Metadata = ProductQuantizationMetadata; + type Storage = ProductQuantizationStorage; + + fn build(_: &dyn Array, _: DistanceType, _: &Self::BuildParams) -> Result { + unimplemented!("ProductQuantizer cannot be built with new index builder") } - fn num_sub_vectors(&self) -> usize { + fn code_dim(&self) -> usize { self.num_sub_vectors } - fn num_bits(&self) -> u32 { - self.num_bits + fn column(&self) -> &'static str { + PQ_CODE_COLUMN + } + + fn quantize(&self, vectors: &dyn Array) -> Result { + let fsl = 
vectors.as_fixed_size_list_opt().ok_or(Error::Index { + message: format!( + "Expect to be a FixedSizeList vector array, got: {:?} array", + vectors.data_type() + ), + location: location!(), + })?; + + match fsl.value_type() { + DataType::Float16 => self.transform::(vectors), + DataType::Float32 => self.transform::(vectors), + DataType::Float64 => self.transform::(vectors), + _ => Err(Error::Index { + message: format!("unsupported data type: {}", fsl.value_type()), + location: location!(), + }), + } + } + + fn metadata_key() -> &'static str { + PQ_METADTA_KEY } - fn dimension(&self) -> usize { - self.dimension + fn quantization_type() -> QuantizationType { + QuantizationType::Product } - fn codebook_as_fsl(&self) -> FixedSizeListArray { - FixedSizeListArray::try_new_from_values( - self.codebook.as_ref().clone(), - self.dimension as i32, - ) - .unwrap() + fn metadata(&self, args: Option) -> Result { + let codebook_position = match args { + Some(args) => args.codebook_position, + None => Some(0), + }; + let codebook_position = codebook_position.ok_or(Error::Index { + message: "codebook_position not found".to_owned(), + location: location!(), + })?; + let tensor = pb::Tensor::try_from(&self.codebook)?; + Ok(serde_json::to_value(ProductQuantizationMetadata { + codebook_position, + num_bits: self.num_bits, + num_sub_vectors: self.num_sub_vectors, + dimension: self.dimension, + codebook: None, + codebook_tensor: tensor.encode_to_vec(), + })?) 
} - fn use_residual(&self) -> bool { - matches!(self.metric_type, DistanceType::L2 | DistanceType::Cosine) + fn from_metadata(metadata: &Self::Metadata, distance_type: DistanceType) -> Result { + Ok(Quantizer::Product(Self::new( + metadata.num_sub_vectors, + metadata.num_bits, + metadata.dimension, + metadata.codebook.as_ref().unwrap().clone(), + distance_type, + ))) } } -#[allow(clippy::fallible_impl_from)] -impl TryFrom<&dyn ProductQuantizer> for pb::Pq { +impl TryFrom<&ProductQuantizer> for pb::Pq { type Error = Error; - fn try_from(pq: &dyn ProductQuantizer) -> Result { - let fsl = pq.codebook_as_fsl(); - let tensor = pb::Tensor::try_from(&fsl)?; + fn try_from(pq: &ProductQuantizer) -> Result { + let tensor = pb::Tensor::try_from(&pq.codebook)?; Ok(Self { - num_bits: pq.num_bits(), - num_sub_vectors: pq.num_sub_vectors() as u32, - dimension: pq.dimension() as u32, + num_bits: pq.num_bits, + num_sub_vectors: pq.num_sub_vectors as u32, + dimension: pq.dimension as u32, codebook: vec![], codebook_tensor: Some(tensor), }) @@ -457,27 +406,27 @@ mod tests { use approx::assert_relative_eq; use arrow::datatypes::UInt8Type; - use arrow_array::{ - types::{Float16Type, Float32Type}, - Float16Array, - }; + use arrow_array::Float16Array; use half::f16; + use lance_linalg::distance::l2_distance_batch; use lance_linalg::kernels::argmin; use lance_testing::datagen::generate_random_array; use num_traits::Zero; #[test] fn test_f16_pq_to_protobuf() { - let pq = ProductQuantizerImpl:: { - num_bits: 8, - num_sub_vectors: 4, - dimension: 16, - codebook: Arc::new(Float16Array::from_iter_values( - repeat(f16::zero()).take(256 * 16), - )), - metric_type: DistanceType::L2, - }; - let proto: pb::Pq = pb::Pq::try_from(&pq as &dyn ProductQuantizer).unwrap(); + let pq = ProductQuantizer::new( + 4, + 8, + 16, + FixedSizeListArray::try_new_from_values( + Float16Array::from_iter_values(repeat(f16::zero()).take(256 * 16)), + 16, + ) + .unwrap(), + DistanceType::L2, + ); + let proto: pb::Pq = 
pb::Pq::try_from(&pq).unwrap(); assert_eq!(proto.num_bits, 8); assert_eq!(proto.num_sub_vectors, 4); assert_eq!(proto.dimension, 16); @@ -493,14 +442,14 @@ mod tests { fn test_l2_distance() { const DIM: usize = 512; const TOTAL: usize = 66; // 64 + 2 to make sure reminder is handled correctly. - let codebook = Arc::new(generate_random_array(256 * DIM)); - let pq = ProductQuantizerImpl:: { - num_bits: 8, - num_sub_vectors: 16, - dimension: DIM, - codebook: codebook.clone(), - metric_type: DistanceType::L2, - }; + let codebook = generate_random_array(256 * DIM); + let pq = ProductQuantizer::new( + 16, + 8, + DIM, + FixedSizeListArray::try_new_from_values(codebook, DIM as i32).unwrap(), + DistanceType::L2, + ); let pq_code = UInt8Array::from_iter_values((0..16 * TOTAL).map(|v| v as u8)); let query = generate_random_array(DIM); @@ -514,7 +463,7 @@ mod tests { code.iter() .enumerate() .flat_map(|(sub_idx, c)| { - let subvec_centroids = pq.centroids(sub_idx); + let subvec_centroids = pq.centroids::(sub_idx); let subvec = &query.values()[sub_idx * sub_vec_len..(sub_idx + 1) * sub_vec_len]; l2_distance_batch( @@ -541,24 +490,24 @@ mod tests { const DIM: usize = 16; const TOTAL: usize = 64; let codebook = generate_random_array(DIM * 256); - let pq = ProductQuantizerImpl:: { - num_bits: 8, - num_sub_vectors: 4, - dimension: DIM, - codebook: Arc::new(codebook), - metric_type: MetricType::L2, - }; + let pq = ProductQuantizer::new( + 4, + 8, + DIM, + FixedSizeListArray::try_new_from_values(codebook, DIM as i32).unwrap(), + DistanceType::L2, + ); let vectors = generate_random_array(DIM * TOTAL); let fsl = FixedSizeListArray::try_new_from_values(vectors.clone(), DIM as i32).unwrap(); - let pq_code = pq.transform(&fsl).unwrap(); + let pq_code = pq.quantize(&fsl).unwrap(); let mut expected = Vec::with_capacity(TOTAL * 4); vectors.values().chunks_exact(DIM).for_each(|vec| { vec.chunks_exact(DIM / 4) .enumerate() .for_each(|(sub_idx, sub_vec)| { - let centroids = 
pq.centroids(sub_idx); + let centroids = pq.centroids::(sub_idx); let dists = l2_distance_batch(sub_vec, centroids, DIM / 4); let code = argmin(dists).unwrap() as u8; expected.push(code); diff --git a/rust/lance-index/src/vector/pq/builder.rs b/rust/lance-index/src/vector/pq/builder.rs index 0d6f308ac1..a0b118d1be 100644 --- a/rust/lance-index/src/vector/pq/builder.rs +++ b/rust/lance-index/src/vector/pq/builder.rs @@ -4,18 +4,14 @@ //! Product Quantizer Builder //! -use std::sync::Arc; - -use crate::pb; +use crate::vector::quantizer::QuantizerBuildParams; use arrow::datatypes::ArrowPrimitiveType; use arrow_array::types::{Float16Type, Float64Type}; -use arrow_array::{ - cast::AsArray, types::Float32Type, Array, ArrayRef, Float32Array, PrimitiveArray, -}; -use arrow_array::{ArrowNumericType, FixedSizeListArray}; +use arrow_array::FixedSizeListArray; +use arrow_array::{cast::AsArray, types::Float32Type, Array, ArrayRef}; use arrow_schema::DataType; use futures::{stream, StreamExt, TryStreamExt}; -use lance_arrow::{ArrowFloatType, FloatArray}; +use lance_arrow::{ArrowFloatType, FixedSizeListArrayExt, FloatArray}; use lance_core::{Error, Result}; use lance_linalg::distance::{Dot, Normalize, L2}; use lance_linalg::{distance::MetricType, MatrixView}; @@ -24,8 +20,7 @@ use snafu::{location, Location}; use super::utils::divide_to_subvectors; use super::ProductQuantizer; -use crate::pb::Pq; -use crate::vector::{kmeans::train_kmeans, pq::ProductQuantizerImpl}; +use crate::vector::kmeans::train_kmeans; /// Parameters for building product quantizer. 
#[derive(Debug, Clone)] @@ -58,6 +53,12 @@ impl Default for PQBuildParams { } } +impl QuantizerBuildParams for PQBuildParams { + fn sample_size(&self) -> usize { + self.sample_rate * 2_usize.pow(self.num_bits as u32) + } +} + impl PQBuildParams { pub fn new(num_sub_vectors: usize, num_bits: usize) -> Self { Self { @@ -80,7 +81,7 @@ impl PQBuildParams { &self, data: &MatrixView, metric_type: MetricType, - ) -> Result> + ) -> Result where ::Native: Dot + L2 + Normalize, { @@ -123,13 +124,13 @@ impl PQBuildParams { let pd_centroids = T::ArrayType::from(codebook_builder); - Ok(Arc::new(ProductQuantizerImpl::::new( + Ok(ProductQuantizer::new( self.num_sub_vectors, self.num_bits as u32, dimension, - Arc::new(pd_centroids), + FixedSizeListArray::try_new_from_values(pd_centroids, dimension as i32)?, metric_type, - ))) + )) } /// Build a [ProductQuantizer] from the given data. @@ -139,7 +140,7 @@ impl PQBuildParams { &self, data: &dyn Array, metric_type: MetricType, - ) -> Result> { + ) -> Result { assert_eq!(data.null_count(), 0); let fsl = data.as_fixed_size_list_opt().ok_or(Error::Index { message: format!( @@ -169,62 +170,3 @@ impl PQBuildParams { } } } - -fn create_typed_pq> + ArrowNumericType>( - proto: &Pq, - metric_type: MetricType, - array: &dyn Array, -) -> Arc -where - ::Native: Dot + L2, -{ - Arc::new(ProductQuantizerImpl::::new( - proto.num_sub_vectors as usize, - proto.num_bits, - proto.dimension as usize, - Arc::new(array.as_primitive::().clone()), - metric_type, - )) -} - -/// Load ProductQuantizer from Protobuf -pub fn from_proto(proto: &Pq, metric_type: MetricType) -> Result> { - let mt = if metric_type == MetricType::Cosine { - MetricType::L2 - } else { - metric_type - }; - - if let Some(tensor) = &proto.codebook_tensor { - let fsl = FixedSizeListArray::try_from(tensor)?; - - match pb::tensor::DataType::try_from(tensor.data_type)? 
{ - pb::tensor::DataType::Bfloat16 => { - unimplemented!() - } - pb::tensor::DataType::Float16 => { - Ok(create_typed_pq::(proto, mt, fsl.values())) - } - pb::tensor::DataType::Float32 => { - Ok(create_typed_pq::(proto, mt, fsl.values())) - } - pb::tensor::DataType::Float64 => { - Ok(create_typed_pq::(proto, mt, fsl.values())) - } - _ => Err(Error::Index { - message: format!("PQ builder: unsupported data type: {:?}", tensor.data_type), - location: location!(), - }), - } - } else { - Ok(Arc::new(ProductQuantizerImpl::::new( - proto.num_sub_vectors as usize, - proto.num_bits, - proto.dimension as usize, - Arc::new(Float32Array::from_iter_values( - proto.codebook.iter().copied(), - )), - metric_type, - ))) - } -} diff --git a/rust/lance-index/src/vector/pq/storage.rs b/rust/lance-index/src/vector/pq/storage.rs index 4171fc40da..996721dec7 100644 --- a/rust/lance-index/src/vector/pq/storage.rs +++ b/rust/lance-index/src/vector/pq/storage.rs @@ -7,13 +7,14 @@ use std::{cmp::min, collections::HashMap, sync::Arc}; +use arrow::datatypes::{self}; use arrow_array::{ cast::AsArray, types::{Float32Type, UInt64Type, UInt8Type}, - FixedSizeListArray, Float32Array, RecordBatch, UInt64Array, UInt8Array, + FixedSizeListArray, RecordBatch, UInt64Array, UInt8Array, }; use arrow_array::{Array, ArrayRef}; -use arrow_schema::SchemaRef; +use arrow_schema::{DataType, SchemaRef}; use async_trait::async_trait; use deepsize::DeepSizeOf; use lance_arrow::FixedSizeListArrayExt; @@ -24,6 +25,7 @@ use lance_io::{ traits::{WriteExt, Writer}, utils::read_message, }; +use lance_linalg::distance::L2; use lance_linalg::{distance::DistanceType, MatrixView}; use lance_table::{format::SelfDescribingFileReader, io::manifest::ManifestDescribing}; use object_store::path::Path; @@ -31,7 +33,8 @@ use prost::Message; use serde::{Deserialize, Serialize}; use snafu::{location, Location}; -use super::{distance::build_distance_table_l2, num_centroids, ProductQuantizerImpl}; +use super::ProductQuantizer; +use 
super::{distance::build_distance_table_l2, num_centroids}; use crate::vector::storage::STORAGE_METADATA_KEY; use crate::{ pb, @@ -105,7 +108,7 @@ impl QuantizerMetadata for ProductQuantizationMetadata { /// TODO: support f16/f64 later. #[derive(Clone, Debug)] pub struct ProductQuantizationStorage { - codebook: Arc, + codebook: FixedSizeListArray, batch: RecordBatch, // Metadata @@ -140,10 +143,9 @@ impl PartialEq for ProductQuantizationStorage { } } -#[allow(dead_code)] impl ProductQuantizationStorage { pub fn new( - codebook: Arc, + codebook: FixedSizeListArray, batch: RecordBatch, num_bits: u32, num_sub_vectors: usize, @@ -218,7 +220,7 @@ impl ProductQuantizationStorage { /// vector_col: &str /// The name of the column containing the vectors. pub async fn build( - quantizer: Arc>, + quantizer: ProductQuantizer, batch: &RecordBatch, vector_col: &str, ) -> Result { @@ -226,7 +228,7 @@ impl ProductQuantizationStorage { let num_bits = quantizer.num_bits; let dimension = quantizer.dimension; let num_sub_vectors = quantizer.num_sub_vectors; - let metric_type = quantizer.metric_type; + let metric_type = quantizer.distance_type; let transform = PQTransformer::new(quantizer, vector_col, PQ_CODE_COLUMN); let batch = transform.transform(batch)?; @@ -310,7 +312,7 @@ impl ProductQuantizationStorage { /// Write the PQ storage to disk. pub async fn write_full(&self, writer: &mut FileWriter) -> Result<()> { let pos = writer.object_writer.tell().await?; - let mat = MatrixView::::new(self.codebook.clone(), self.dimension); + let mat = MatrixView::::try_from(&self.codebook)?; let codebook_tensor = pb::Tensor::from(&mat); writer .object_writer @@ -362,18 +364,18 @@ impl QuantizerStorage for ProductQuantizationStorage { metadata: &Self::Metadata, ) -> Result { // Hard coded to float32 for now - let codebook = Arc::new( - metadata - .codebook - .as_ref() - .ok_or(Error::Index { - message: "Codebook not found in PQ metadata".to_string(), - location: location!(), - })? 
- .values() - .as_primitive::() - .clone(), - ); + let codebook = metadata + .codebook + .as_ref() + .ok_or(Error::Index { + message: "Codebook not found in PQ metadata".to_string(), + location: location!(), + })? + .values() + .as_primitive::() + .clone(); + let codebook = + FixedSizeListArray::try_new_from_values(codebook, metadata.dimension as i32)?; let schema = reader.schema(); let batch = reader.read_range(range, schema).await?; @@ -430,7 +432,7 @@ impl VectorStore for ProductQuantizationStorage { .clone(); Ok(Self { - codebook: Arc::new(codebook.values().as_primitive::().clone()), + codebook, batch, num_bits: metadata.num_bits, num_sub_vectors: metadata.num_sub_vectors, @@ -442,11 +444,7 @@ impl VectorStore for ProductQuantizationStorage { } fn to_batches(&self) -> Result> { - let codebook_fsl = FixedSizeListArray::try_new_from_values( - self.codebook.as_ref().clone(), - self.dimension as i32, - )?; - let codebook = pb::Tensor::try_from(&codebook_fsl)?.encode_to_vec(); + let codebook = pb::Tensor::try_from(&self.codebook)?.encode_to_vec(); let metadata = ProductQuantizationMetadata { codebook_position: 0, // deprecated in new format num_bits: self.num_bits, @@ -497,14 +495,42 @@ impl VectorStore for ProductQuantizationStorage { } fn dist_calculator(&self, query: ArrayRef) -> Self::DistanceCalculator<'_> { - PQDistCalculator::new( - self.codebook.values(), - self.num_bits, - self.num_sub_vectors, - self.pq_code.clone(), - query.as_primitive::().values(), - self.distance_type, - ) + match self.codebook.value_type() { + DataType::Float16 => PQDistCalculator::new( + self.codebook + .values() + .as_primitive::() + .values(), + self.num_bits, + self.num_sub_vectors, + self.pq_code.clone(), + query.as_primitive::().values(), + self.distance_type, + ), + DataType::Float32 => PQDistCalculator::new( + self.codebook + .values() + .as_primitive::() + .values(), + self.num_bits, + self.num_sub_vectors, + self.pq_code.clone(), + query.as_primitive::().values(), + 
self.distance_type, + ), + DataType::Float64 => PQDistCalculator::new( + self.codebook + .values() + .as_primitive::() + .values(), + self.num_bits, + self.num_sub_vectors, + self.pq_code.clone(), + query.as_primitive::().values(), + self.distance_type, + ), + _ => unimplemented!("Unsupported data type: {:?}", self.codebook.value_type()), + } } fn dist_calculator_from_id(&self, _: u32) -> Self::DistanceCalculator<'_> { @@ -525,12 +551,12 @@ pub struct PQDistCalculator { } impl PQDistCalculator { - fn new( - codebook: &[f32], + fn new( + codebook: &[T], num_bits: u32, num_sub_vectors: usize, pq_code: Arc, - query: &[f32], + query: &[T], distance_type: DistanceType, ) -> Self { let distance_table = if matches!(distance_type, DistanceType::Cosine | DistanceType::L2) { @@ -566,8 +592,11 @@ impl DistCalculator for PQDistCalculator { #[cfg(test)] mod tests { + use crate::vector::storage::StorageBuilder; + use super::*; + use arrow_array::Float32Array; use arrow_schema::{DataType, Field, Schema as ArrowSchema}; use lance_arrow::FixedSizeListArrayExt; use lance_core::datatypes::Schema; @@ -578,16 +607,9 @@ mod tests { const NUM_SUB_VECTORS: usize = 16; async fn create_pq_storage() -> ProductQuantizationStorage { - let codebook = Arc::new(Float32Array::from_iter_values( - (0..256 * DIM).map(|v| v as f32), - )); - let pq = Arc::new(ProductQuantizerImpl::::new( - NUM_SUB_VECTORS, - 8, - DIM, - codebook, - DistanceType::L2, - )); + let codebook = Float32Array::from_iter_values((0..256 * DIM).map(|v| v as f32)); + let codebook = FixedSizeListArray::try_new_from_values(codebook, DIM as i32).unwrap(); + let pq = ProductQuantizer::new(NUM_SUB_VECTORS, 8, DIM, codebook, DistanceType::L2); let schema = ArrowSchema::new(vec![ Field::new( @@ -606,8 +628,8 @@ mod tests { let batch = RecordBatch::try_new(schema.into(), vec![Arc::new(fsl), Arc::new(row_ids)]).unwrap(); - ProductQuantizationStorage::build(pq.clone(), &batch, "vectors") - .await + StorageBuilder::new("vectors".to_owned(), 
pq.distance_type, pq) + .build(&batch) .unwrap() } @@ -616,7 +638,7 @@ mod tests { let storage = create_pq_storage().await; assert_eq!(storage.len(), TOTAL); assert_eq!(storage.num_sub_vectors, NUM_SUB_VECTORS); - assert_eq!(storage.codebook.len(), 256 * DIM); + assert_eq!(storage.codebook.values().len(), 256 * DIM); assert_eq!(storage.pq_code.len(), TOTAL * NUM_SUB_VECTORS); assert_eq!(storage.row_ids.len(), TOTAL); } diff --git a/rust/lance-index/src/vector/pq/transform.rs b/rust/lance-index/src/vector/pq/transform.rs index 7e7eb67f42..6bb3e2b5bc 100644 --- a/rust/lance-index/src/vector/pq/transform.rs +++ b/rust/lance-index/src/vector/pq/transform.rs @@ -11,23 +11,20 @@ use lance_core::{Error, Result}; use snafu::{location, Location}; use super::ProductQuantizer; +use crate::vector::quantizer::Quantization; use crate::vector::transform::Transformer; /// Product Quantizer Transformer /// /// It transforms a column of vectors into a column of PQ codes. pub struct PQTransformer { - quantizer: Arc, + quantizer: ProductQuantizer, input_column: String, output_column: String, } impl PQTransformer { - pub fn new( - quantizer: Arc, - input_column: &str, - output_column: &str, - ) -> Self { + pub fn new(quantizer: ProductQuantizer, input_column: &str, output_column: &str) -> Self { Self { quantizer, input_column: input_column.to_owned(), @@ -65,7 +62,7 @@ impl Transformer for PQTransformer { ), location: location!(), })?; - let pq_code = self.quantizer.transform(&data)?; + let pq_code = self.quantizer.quantize(&data)?; let pq_field = Field::new(&self.output_column, pq_code.data_type().clone(), false); let batch = batch.try_with_column(pq_field, Arc::new(pq_code))?; let batch = batch.drop_column(&self.input_column)?; diff --git a/rust/lance-index/src/vector/quantizer.rs b/rust/lance-index/src/vector/quantizer.rs index 83ea913d54..2f0b226abf 100644 --- a/rust/lance-index/src/vector/quantizer.rs +++ b/rust/lance-index/src/vector/quantizer.rs @@ -7,15 +7,14 @@ use 
std::sync::Arc; use arrow::array::AsArray; use arrow::datatypes::{Float16Type, Float32Type, Float64Type}; -use arrow_array::{Array, ArrayRef, FixedSizeListArray, Float32Array}; +use arrow_array::{Array, ArrayRef, FixedSizeListArray}; use arrow_schema::DataType; use async_trait::async_trait; use deepsize::DeepSizeOf; -use lance_arrow::ArrowFloatType; use lance_core::{Error, Result}; use lance_file::reader::FileReader; use lance_io::traits::Reader; -use lance_linalg::distance::{DistanceType, Dot, L2}; +use lance_linalg::distance::DistanceType; use lance_table::format::SelfDescribingFileReader; use serde::{Deserialize, Serialize}; use snafu::{location, Location}; @@ -23,23 +22,18 @@ use snafu::{location, Location}; use crate::{IndexMetadata, INDEX_METADATA_SCHEMA_KEY}; use super::flat::index::FlatQuantizer; -use super::pq::storage::PQ_METADTA_KEY; use super::pq::ProductQuantizer; use super::sq::builder::SQBuildParams; use super::sq::storage::SQ_METADATA_KEY; +use super::SQ_CODE_COLUMN; use super::{ ivf::storage::IvfModel, - pq::{ - storage::{ProductQuantizationMetadata, ProductQuantizationStorage}, - ProductQuantizerImpl, - }, sq::{ storage::{ScalarQuantizationMetadata, ScalarQuantizationStorage}, ScalarQuantizer, }, storage::VectorStore, }; -use super::{PQ_CODE_COLUMN, SQ_CODE_COLUMN}; pub trait Quantization: Send + Sync + Debug + DeepSizeOf + Into { type BuildParams: QuantizerBuildParams; @@ -94,7 +88,7 @@ impl QuantizerBuildParams for () { #[derive(Debug, Clone, DeepSizeOf)] pub enum Quantizer { Flat(FlatQuantizer), - Product(Arc), + Product(ProductQuantizer), Scalar(ScalarQuantizer), } @@ -102,8 +96,8 @@ impl Quantizer { pub fn code_dim(&self) -> usize { match self { Self::Flat(fq) => fq.code_dim(), - Self::Product(pq) => pq.num_sub_vectors(), - Self::Scalar(sq) => sq.dim, + Self::Product(pq) => pq.code_dim(), + Self::Scalar(sq) => sq.code_dim(), } } @@ -118,7 +112,7 @@ impl Quantizer { pub fn metadata_key(&self) -> &'static str { match self { Self::Flat(_) => 
FlatQuantizer::metadata_key(), - Self::Product(_) => ProductQuantizerImpl::::metadata_key(), + Self::Product(_) => ProductQuantizer::metadata_key(), Self::Scalar(_) => ScalarQuantizer::metadata_key(), } } @@ -140,17 +134,8 @@ impl Quantizer { } } -impl From> for Quantizer -where - T::Native: Dot + L2, -{ - fn from(pq: ProductQuantizerImpl) -> Self { - Self::Product(Arc::new(pq)) - } -} - -impl From> for Quantizer { - fn from(pq: Arc) -> Self { +impl From for Quantizer { + fn from(pq: ProductQuantizer) -> Self { Self::Product(pq) } } @@ -269,147 +254,6 @@ impl Quantization for ScalarQuantizer { } } -impl Quantization for Arc { - type BuildParams = (); - type Metadata = ProductQuantizationMetadata; - type Storage = ProductQuantizationStorage; - - fn build(_: &dyn Array, _: DistanceType, _: &Self::BuildParams) -> Result { - unimplemented!("ProductQuantizer cannot be built with new index builder") - } - - fn code_dim(&self) -> usize { - self.num_sub_vectors() - } - - fn column(&self) -> &'static str { - PQ_CODE_COLUMN - } - - fn quantize(&self, vectors: &dyn Array) -> Result { - let code_array = self.transform(vectors)?; - Ok(code_array) - } - - fn metadata_key() -> &'static str { - PQ_METADTA_KEY - } - - fn quantization_type() -> QuantizationType { - QuantizationType::Product - } - - fn metadata(&self, args: Option) -> Result { - let args = args.unwrap_or_default(); - - let codebook_position = args.codebook_position.ok_or(Error::Index { - message: "codebook_position not found".to_owned(), - location: location!(), - })?; - Ok(serde_json::to_value(ProductQuantizationMetadata { - codebook_position, - num_bits: self.num_bits(), - num_sub_vectors: self.num_sub_vectors(), - dimension: self.dimension(), - codebook: args.codebook, - codebook_tensor: Vec::new(), - })?) 
- } - - fn from_metadata(metadata: &Self::Metadata, distance_type: DistanceType) -> Result { - Ok(Quantizer::Product(Arc::new(ProductQuantizerImpl::< - Float32Type, - >::new( - metadata.num_sub_vectors, - metadata.num_bits, - metadata.dimension, - Arc::new( - metadata - .codebook - .as_ref() - .unwrap() - .values() - .as_any() - .downcast_ref::() - .unwrap() - .clone(), - ), - distance_type, - )))) - } -} - -impl Quantization for ProductQuantizerImpl -where - T::Native: Dot + L2, -{ - type BuildParams = (); - type Metadata = ProductQuantizationMetadata; - type Storage = ProductQuantizationStorage; - - fn build(_: &dyn Array, _: DistanceType, _: &Self::BuildParams) -> Result { - unimplemented!("ProductQuantizer cannot be built with new index builder") - } - - fn code_dim(&self) -> usize { - self.num_sub_vectors() - } - - fn column(&self) -> &'static str { - PQ_CODE_COLUMN - } - - fn quantize(&self, vectors: &dyn Array) -> Result { - let code_array = self.transform(vectors)?; - Ok(code_array) - } - - fn metadata_key() -> &'static str { - PQ_METADTA_KEY - } - - fn quantization_type() -> QuantizationType { - QuantizationType::Product - } - - fn metadata(&self, args: Option) -> Result { - let args = args.unwrap_or_default(); - - let codebook_position = args.codebook_position.ok_or(Error::Index { - message: "codebook_position not found".to_owned(), - location: location!(), - })?; - Ok(serde_json::to_value(ProductQuantizationMetadata { - codebook_position, - num_bits: self.num_bits(), - num_sub_vectors: self.num_sub_vectors(), - dimension: self.dimension(), - codebook: args.codebook, - codebook_tensor: Vec::new(), - })?) 
- } - - fn from_metadata(metadata: &Self::Metadata, distance_type: DistanceType) -> Result { - Ok(Quantizer::Product(Arc::new(Self::new( - metadata.num_sub_vectors, - metadata.num_bits, - metadata.dimension, - Arc::new( - metadata - .codebook - .as_ref() - .unwrap() - .values() - .as_any() - .downcast_ref::() - .unwrap() - .clone(), - ), - distance_type, - )))) - } -} - /// Loader to load partitioned [VectorStore] from disk. pub struct IvfQuantizationStorage { reader: FileReader, diff --git a/rust/lance-index/src/vector/v3/shuffler.rs b/rust/lance-index/src/vector/v3/shuffler.rs index 59ee0c93a7..4449746375 100644 --- a/rust/lance-index/src/vector/v3/shuffler.rs +++ b/rust/lance-index/src/vector/v3/shuffler.rs @@ -159,7 +159,6 @@ impl Shuffler for IvfShuffler { let writer = object_store.create(&part_path).await?; FileWriter::try_new( writer, - part_path.to_string(), lance_core::datatypes::Schema::try_from(schema.as_ref())?, Default::default(), ) diff --git a/rust/lance-io/Cargo.toml b/rust/lance-io/Cargo.toml index 2ee08948c0..75039fb965 100644 --- a/rust/lance-io/Cargo.toml +++ b/rust/lance-io/Cargo.toml @@ -48,13 +48,15 @@ async-priority-channel = "0.2.0" [dev-dependencies] criterion.workspace = true parquet.workspace = true -pprof.workspace = true tempfile.workspace = true mockall.workspace = true [build-dependencies] prost-build.workspace = true +[target.'cfg(target_os = "linux")'.dev-dependencies] +pprof.workspace = true + [[bench]] name = "scheduler" harness = false diff --git a/rust/lance-linalg/benches/norm_l2.rs b/rust/lance-linalg/benches/norm_l2.rs index 2528ae5db5..516abb727c 100644 --- a/rust/lance-linalg/benches/norm_l2.rs +++ b/rust/lance-linalg/benches/norm_l2.rs @@ -105,7 +105,7 @@ fn bench_distance(c: &mut Criterion) { run_bench::( c, target.as_slice(), - |vec| norm_l2_impl::(vec) as f32, + norm_l2_impl::, None, // TODO: implement SIMD for f64 ); } diff --git a/rust/lance-linalg/build.rs b/rust/lance-linalg/build.rs index 733ba4e8af..3d8d30a249 
100644 --- a/rust/lance-linalg/build.rs +++ b/rust/lance-linalg/build.rs @@ -25,9 +25,10 @@ fn main() -> Result<(), String> { } if cfg!(target_os = "windows") { - return Err( - "cargo:warning=fp16 kernels are not supported on Windows. Please remove fp16kernels feature".to_string() + println!( + "cargo:warning=fp16 kernels are not supported on Windows. Skipping compilation of kernels." ); + return Ok(()); } if cfg!(all(target_arch = "aarch64", target_os = "macos")) { diff --git a/rust/lance-linalg/src/distance/cosine.rs b/rust/lance-linalg/src/distance/cosine.rs index e75e7dc0d0..efcb711bc9 100644 --- a/rust/lance-linalg/src/distance/cosine.rs +++ b/rust/lance-linalg/src/distance/cosine.rs @@ -65,6 +65,8 @@ pub trait Cosine: Dot + Normalize { } } +impl Cosine for u8 {} + impl Cosine for bf16 {} #[cfg(feature = "fp16kernels")] diff --git a/rust/lance-linalg/src/distance/hamming.rs b/rust/lance-linalg/src/distance/hamming.rs index 07a861544a..0b94f867bc 100644 --- a/rust/lance-linalg/src/distance/hamming.rs +++ b/rust/lance-linalg/src/distance/hamming.rs @@ -3,6 +3,11 @@ //! Hamming distance. +pub trait Hamming { + /// Hamming distance between two vectors. + fn hamming(x: &[u8], y: &[u8]) -> f32; +} + /// Hamming distance between two vectors. 
#[inline] pub fn hamming(x: &[u8], y: &[u8]) -> f32 { diff --git a/rust/lance-linalg/src/distance/l2.rs b/rust/lance-linalg/src/distance/l2.rs index e487518a62..706005c0c3 100644 --- a/rust/lance-linalg/src/distance/l2.rs +++ b/rust/lance-linalg/src/distance/l2.rs @@ -19,7 +19,7 @@ use lance_arrow::{ArrowFloatType, FloatArray}; #[cfg(feature = "fp16kernels")] use lance_core::utils::cpu::SimdSupport; use lance_core::utils::cpu::FP16_SIMD_SUPPORT; -use num_traits::{AsPrimitive, Float, Num}; +use num_traits::{AsPrimitive, Num}; use crate::simd::{ f32::{f32x16, f32x8}, @@ -64,7 +64,7 @@ pub fn l2_distance_uint_scalar(key: &[u8], target: &[u8]) -> f32 { #[inline] pub fn l2_scalar< T: AsPrimitive, - Output: Float + Sum + AddAssign + 'static, + Output: Num + Copy + Sum + AddAssign + 'static, const LANES: usize, >( from: &[T], @@ -98,6 +98,13 @@ pub fn l2_scalar< s + sums.iter().copied().sum() } +impl L2 for u8 { + #[inline] + fn l2(x: &[Self], y: &[Self]) -> f32 { + l2_distance_uint_scalar(x, y) + } +} + impl L2 for bf16 { #[inline] fn l2(x: &[Self], y: &[Self]) -> f32 { diff --git a/rust/lance-linalg/src/distance/norm_l2.rs b/rust/lance-linalg/src/distance/norm_l2.rs index 0efdbc2ab3..636b46514e 100644 --- a/rust/lance-linalg/src/distance/norm_l2.rs +++ b/rust/lance-linalg/src/distance/norm_l2.rs @@ -36,6 +36,13 @@ mod kernel { } } +impl Normalize for u8 { + #[inline] + fn norm_l2(vector: &[Self]) -> f32 { + norm_l2_impl::(vector) + } +} + impl Normalize for f16 { #[inline] fn norm_l2(vector: &[Self]) -> f32 { diff --git a/rust/lance-testing/src/datagen.rs b/rust/lance-testing/src/datagen.rs index 3547522d84..e4ba440b50 100644 --- a/rust/lance-testing/src/datagen.rs +++ b/rust/lance-testing/src/datagen.rs @@ -7,10 +7,14 @@ use std::collections::HashSet; use std::sync::Arc; use std::{iter::repeat_with, ops::Range}; -use arrow_array::{Float32Array, Int32Array, RecordBatch, RecordBatchIterator, RecordBatchReader}; +use arrow_array::types::ArrowPrimitiveType; +use 
arrow_array::{ + Float32Array, Int32Array, PrimitiveArray, RecordBatch, RecordBatchIterator, RecordBatchReader, +}; use arrow_schema::{DataType, Field, Schema as ArrowSchema}; use lance_arrow::{fixed_size_list_type, ArrowFloatType, FixedSizeListArrayExt}; use num_traits::{real::Real, FromPrimitive}; +use rand::distributions::uniform::SampleUniform; use rand::{ distributions::Uniform, prelude::Distribution, rngs::StdRng, seq::SliceRandom, Rng, SeedableRng, }; @@ -218,12 +222,18 @@ pub fn generate_random_array(n: usize) -> Float32Array { Float32Array::from_iter_values(repeat_with(|| rng.gen::()).take(n)) } -/// Create a random float32 array where each element is uniformly distributed a +/// Create a random primitive array where each element is uniformly distributed a /// given range. -pub fn generate_random_array_with_range(n: usize, range: Range) -> Float32Array { +pub fn generate_random_array_with_range( + n: usize, + range: Range, +) -> PrimitiveArray +where + T::Native: SampleUniform, +{ let mut rng = StdRng::from_seed([13; 32]); let distribution = Uniform::new(range.start, range.end); - Float32Array::from_iter_values(repeat_with(|| distribution.sample(&mut rng)).take(n)) + PrimitiveArray::::from_iter_values(repeat_with(|| distribution.sample(&mut rng)).take(n)) } /// Create a random float32 array where each element is uniformly diff --git a/rust/lance/src/dataset/fragment/write.rs b/rust/lance/src/dataset/fragment/write.rs index 28edf4f21e..409758acb1 100644 --- a/rust/lance/src/dataset/fragment/write.rs +++ b/rust/lance/src/dataset/fragment/write.rs @@ -87,7 +87,6 @@ impl<'a> FragmentCreateBuilder<'a> { let obj_writer = object_store.create(&full_path).await?; let mut writer = lance_file::v2::writer::FileWriter::try_new( obj_writer, - full_path.to_string(), schema, FileWriterOptions::default(), )?; @@ -147,7 +146,11 @@ impl<'a> FragmentCreateBuilder<'a> { Self::validate_schema(&schema, stream.schema().as_ref())?; - let (object_store, base_path) = 
ObjectStore::from_uri(self.dataset_uri).await?; + let (object_store, base_path) = ObjectStore::from_uri_and_params( + self.dataset_uri, + ¶ms.store_params.clone().unwrap_or_default(), + ) + .await?; let filename = format!("{}.lance", Uuid::new_v4()); let mut fragment = Fragment::with_file_legacy(id, &filename, &schema, None); let full_path = base_path.child(DATA_DIR).child(filename.clone()); diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index ac821360fe..a5cd1c4af0 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -125,7 +125,7 @@ pub struct Scanner { dataset: Arc, /// The physical schema (before dynamic projection) that must be loaded from the table - phyical_columns: Schema, + physical_columns: Schema, /// The expressions for all the columns to be in the output /// Note: this doesn't include _distance, and _rowid @@ -179,6 +179,14 @@ pub struct Scanner { /// If set, this scanner serves only these fragments. fragments: Option>, + + /// Only search the data being indexed (weak consistency search). + /// + /// Default value is false. + /// + /// This is essentially a weak consistency search. Users can run index or optimize index + /// to make the index catch up with the latest data. 
+ fast_search: bool, } fn escape_column_name(name: &str) -> String { @@ -194,7 +202,7 @@ impl Scanner { Self { dataset, - phyical_columns: projection, + physical_columns: projection, requested_output_expr: None, prefilter: false, filter: None, @@ -210,6 +218,7 @@ impl Scanner { with_row_address: false, ordered: true, fragments: None, + fast_search: false, } } @@ -297,7 +306,10 @@ impl Scanner { } output.insert(output_name.as_ref().to_string(), expr); } - self.phyical_columns = self.dataset.schema().project(&physical_cols)?; + + let physical_schema = self.physical_schema(true)?; + let schema_with_meta_columns = self.dataset.schema().merge(physical_schema.as_ref())?; + self.physical_columns = schema_with_meta_columns.project(&physical_cols)?; let mut output_cols = vec![]; for (name, _) in columns { @@ -497,6 +509,19 @@ impl Scanner { self } + /// Only search the data being indexed. + /// + /// Default value is false. + /// + /// This is essentially a weak consistency search, only on the indexed data. + pub fn fast_search(&mut self) -> &mut Self { + if let Some(q) = self.nearest.as_mut() { + q.use_index = true; + } + self.fast_search = true; + self + } + /// Apply a refine step to the vector search. 
/// /// A refine improves query accuracy but also makes search slower, by reading extra elements @@ -582,20 +607,14 @@ impl Scanner { } /// The schema of the Scanner from lance physical takes - pub(crate) fn physical_schema(&self) -> Result> { + pub(crate) fn physical_schema(&self, in_projection: bool) -> Result> { let mut extra_columns = vec![]; - if let Some(q) = self.nearest.as_ref() { - let vector_field = self.dataset.schema().field(&q.column).ok_or(Error::io( - format!("Column {} not found", q.column), - location!(), - ))?; - let vector_field = ArrowField::from(vector_field); - extra_columns.push(vector_field); + if self.nearest.as_ref().is_some() { extra_columns.push(ArrowField::new(DIST_COL, DataType::Float32, true)); }; - if self.with_row_id { + if self.with_row_id || in_projection { extra_columns.push(ROW_ID_FIELD.clone()); } @@ -604,10 +623,10 @@ impl Scanner { } let schema = if !extra_columns.is_empty() { - self.phyical_columns + self.physical_columns .merge(&ArrowSchema::new(extra_columns))? 
} else { - self.phyical_columns.clone() + self.physical_columns.clone() }; // drop metadata @@ -625,7 +644,7 @@ impl Scanner { } pub(crate) fn output_expr(&self) -> Result, String)>> { - let physical_schema = self.physical_schema()?; + let physical_schema = self.physical_schema(false)?; // reset the id ordering becuase we will take the batch after projection let physical_schema = ArrowSchema::new( @@ -638,7 +657,7 @@ impl Scanner { let df_schema = Arc::new(DFSchema::try_from(physical_schema.clone())?); let project_star = self - .phyical_columns + .physical_columns .fields .iter() .map(|f| { @@ -669,17 +688,17 @@ impl Scanner { let mut output_expr = output_expr.unwrap_or(Ok(project_star))?; // distance goes before the row_id column - if self.nearest.is_some() { + if self.nearest.is_some() && output_expr.iter().all(|(_, name)| name != DIST_COL) { let vector_expr = expressions::col(DIST_COL, &physical_schema)?; output_expr.push((vector_expr, DIST_COL.to_string())); } - if self.with_row_id { + if self.with_row_id && output_expr.iter().all(|(_, name)| name != ROW_ID) { let row_id_expr = expressions::col(ROW_ID, &physical_schema)?; output_expr.push((row_id_expr, ROW_ID.to_string())); } - if self.with_row_address { + if self.with_row_address && output_expr.iter().all(|(_, name)| name != ROW_ADDR) { let row_addr_expr = expressions::col(ROW_ADDR, &physical_schema)?; output_expr.push((row_addr_expr, ROW_ADDR.to_string())); } @@ -808,7 +827,7 @@ impl Scanner { /// -> Take(remaining_cols) -> Projection() /// ``` /// - /// In general, a plan has 4 stages: + /// In general, a plan has 5 stages: /// /// 1. Source (from dataset Scan or from index, may include prefilter) /// 2. Filter @@ -816,7 +835,7 @@ impl Scanner { /// 4. Limit / Offset /// 5. 
Take remaining columns / Projection pub async fn create_plan(&self) -> Result> { - if self.phyical_columns.fields.is_empty() && !self.with_row_id && !self.with_row_address { + if self.physical_columns.fields.is_empty() && !self.with_row_id && !self.with_row_address { return Err(Error::InvalidInput { source: "no columns were selected and with_row_id is false, there is nothing to scan" @@ -889,7 +908,7 @@ impl Scanner { }; match (&filter_plan.index_query, &mut filter_plan.refine_expr) { (Some(index_query), None) => { - self.scalar_indexed_scan(&self.phyical_columns, index_query) + self.scalar_indexed_scan(&self.physical_columns, index_query) .await? } // TODO: support combined pushdown and scalar index scan @@ -913,7 +932,7 @@ impl Scanner { let columns = filter_plan.refine_columns(); Arc::new(self.dataset.schema().project(&columns)?) } else { - Arc::new(self.phyical_columns.clone()) + Arc::new(self.physical_columns.clone()) }; self.scan(with_row_id, self.with_row_address, false, schema) } @@ -989,12 +1008,11 @@ impl Scanner { } // Stage 5: take remaining columns required for projection - let physical_schema = self.physical_schema()?; + let physical_schema = self.physical_schema(false)?; let remaining_schema = physical_schema.exclude(plan.schema().as_ref())?; if !remaining_schema.fields.is_empty() { plan = self.take(plan, &remaining_schema, self.batch_readahead)?; } - // Stage 6: physical projection -- reorder physical columns needed before final projection let output_arrow_schema = physical_schema.as_ref().into(); if plan.schema().as_ref() != &output_arrow_schema { @@ -1021,7 +1039,7 @@ impl Scanner { return Err(Error::io("No nearest query".to_string(), location!())); }; - // Santity check + // Sanity check let schema = self.dataset.schema(); if let Some(field) = schema.field(&q.column) { match field.data_type() { @@ -1064,9 +1082,10 @@ impl Scanner { let deltas = self.dataset.load_indices_by_name(&index.name).await?; let ann_node = self.ann(q, &deltas, 
filter_plan).await?; // _distance, _rowid - let with_vector = self.dataset.schema().project(&[&q.column])?; - let knn_node_with_vector = self.take(ann_node, &with_vector, self.batch_readahead)?; let mut knn_node = if q.refine_factor.is_some() { + let with_vector = self.dataset.schema().project(&[&q.column])?; + let knn_node_with_vector = + self.take(ann_node, &with_vector, self.batch_readahead)?; // TODO: now we just open an index to get its metric type. let idx = self .dataset @@ -1076,10 +1095,12 @@ impl Scanner { q.metric_type = idx.metric_type(); self.flat_knn(knn_node_with_vector, &q)? } else { - knn_node_with_vector + ann_node }; // vector, _distance, _rowid - knn_node = self.knn_combined(&q, index, knn_node, filter_plan).await?; + if !self.fast_search { + knn_node = self.knn_combined(q, index, knn_node, filter_plan).await?; + } Ok(knn_node) } else { @@ -1108,14 +1129,21 @@ impl Scanner { /// Combine ANN results with KNN results for data appended after index creation async fn knn_combined( &self, - q: &&Query, + q: &Query, index: &Index, - knn_node: Arc, + mut knn_node: Arc, filter_plan: &FilterPlan, ) -> Result> { - // Check if we've created new versions since the index + // Check if we've created new versions since the index was built. let unindexed_fragments = self.dataset.unindexed_fragments(&index.name).await?; if !unindexed_fragments.is_empty() { + // If the vector column is not present, we need to take the vector column, so + // that the distance value is comparable with the flat search ones. 
+ if knn_node.schema().column_with_name(&q.column).is_none() { + let with_vector = self.dataset.schema().project(&[&q.column])?; + knn_node = self.take(knn_node, &with_vector, self.batch_readahead)?; + } + let mut columns = vec![q.column.clone()]; if let Some(expr) = filter_plan.full_expr.as_ref() { let filter_columns = Planner::column_names_in_expr(expr); @@ -1361,7 +1389,7 @@ impl Scanner { Ok(Arc::new(LancePushdownScanExec::try_new( self.dataset.clone(), fragments, - self.phyical_columns.clone().into(), + self.physical_columns.clone().into(), predicate, config, )?)) @@ -3907,7 +3935,7 @@ mod test { &dataset.dataset, |scan| scan.nearest("vec", &q, 42), "Projection: fields=[i, s, vec, _distance] - Take: columns=\"_distance, _rowid, vec, i, s\" + Take: columns=\"_distance, _rowid, i, s, vec\" SortExec: TopK(fetch=42), expr=... ANNSubIndex: name=..., k=42, deltas=1 ANNIvfPartition: uuid=..., nprobes=1, deltas=1", @@ -3953,9 +3981,9 @@ mod test { .with_row_id()) }, "Projection: fields=[s, vec, _distance, _rowid] - Take: columns=\"_distance, _rowid, vec, i, s\" - FilterExec: i@3 > 10 - Take: columns=\"_distance, _rowid, vec, i\" + Take: columns=\"_distance, _rowid, i, s, vec\" + FilterExec: i@2 > 10 + Take: columns=\"_distance, _rowid, i\" SortExec: TopK(fetch=17), expr=... ANNSubIndex: name=..., k=17, deltas=1 ANNIvfPartition: uuid=..., nprobes=1, deltas=1", @@ -3972,7 +4000,7 @@ mod test { .prefilter(true)) }, "Projection: fields=[i, s, vec, _distance] - Take: columns=\"_distance, _rowid, vec, i, s\" + Take: columns=\"_distance, _rowid, i, s, vec\" SortExec: TopK(fetch=17), expr=... ANNSubIndex: name=..., k=17, deltas=1 ANNIvfPartition: uuid=..., nprobes=1, deltas=1 @@ -4079,7 +4107,7 @@ mod test { .prefilter(true)) }, "Projection: fields=[i, s, vec, _distance] - Take: columns=\"_distance, _rowid, vec, i, s\" + Take: columns=\"_distance, _rowid, i, s, vec\" SortExec: TopK(fetch=5), expr=... 
ANNSubIndex: name=..., k=5, deltas=1 ANNIvfPartition: uuid=..., nprobes=1, deltas=1 @@ -4197,4 +4225,72 @@ mod test { Ok(()) } + + #[tokio::test] + async fn test_fast_search_plan() { + // Create a vector dataset + let mut dataset = TestVectorDataset::new(false, true).await.unwrap(); + dataset.make_vector_index().await.unwrap(); + dataset.append_new_data().await.unwrap(); + + let q: Float32Array = (32..64).map(|v| v as f32).collect(); + + assert_plan_equals( + &dataset.dataset, + |scan| { + scan.nearest("vec", &q, 32)? + .fast_search() + .project(&["_rowid", "_distance"]) + }, + "Projection: fields=[_rowid, _distance] + SortExec: TopK(fetch=32), expr=[_distance@0 ASC NULLS LAST] + ANNSubIndex: name=idx, k=32, deltas=1 + ANNIvfPartition: uuid=..., nprobes=1, deltas=1", + ) + .await + .unwrap(); + + assert_plan_equals( + &dataset.dataset, + |scan| { + scan.nearest("vec", &q, 33)? + .fast_search() + .with_row_id() + .project(&["_rowid", "_distance"]) + }, + "Projection: fields=[_rowid, _distance] + SortExec: TopK(fetch=33), expr=[_distance@0 ASC NULLS LAST] + ANNSubIndex: name=idx, k=33, deltas=1 + ANNIvfPartition: uuid=..., nprobes=1, deltas=1", + ) + .await + .unwrap(); + + // Not `fast_scan` case + assert_plan_equals( + &dataset.dataset, + |scan| { + scan.nearest("vec", &q, 34)? 
+ .with_row_id() + .project(&["_rowid", "_distance"]) + }, + "Projection: fields=[_rowid, _distance] + FilterExec: _distance@2 IS NOT NULL + SortExec: TopK(fetch=34), expr=[_distance@2 ASC NULLS LAST] + KNNVectorDistance: metric=l2 + RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=2 + UnionExec + Projection: fields=[_distance, _rowid, vec] + FilterExec: _distance@2 IS NOT NULL + SortExec: TopK(fetch=34), expr=[_distance@2 ASC NULLS LAST] + KNNVectorDistance: metric=l2 + LanceScan: uri=..., projection=[vec], row_id=true, row_addr=false, ordered=false + Take: columns=\"_distance, _rowid, vec\" + SortExec: TopK(fetch=34), expr=[_distance@0 ASC NULLS LAST] + ANNSubIndex: name=idx, k=34, deltas=1 + ANNIvfPartition: uuid=..., nprobes=1, deltas=1", + ) + .await + .unwrap(); + } } diff --git a/rust/lance/src/dataset/write.rs b/rust/lance/src/dataset/write.rs index 95f32dcbeb..2caa25c725 100644 --- a/rust/lance/src/dataset/write.rs +++ b/rust/lance/src/dataset/write.rs @@ -308,39 +308,46 @@ impl GenericWriter for (FileWriter, String } } +struct V2WriterAdapter { + writer: v2::writer::FileWriter, + path: String, +} + #[async_trait::async_trait] -impl GenericWriter for v2::writer::FileWriter { +impl GenericWriter for V2WriterAdapter { fn multipart_id(&self) -> &str { - self.multipart_id() + self.writer.multipart_id() } async fn write(&mut self, batches: &[RecordBatch]) -> Result<()> { for batch in batches { - self.write_batch(batch).await?; + self.writer.write_batch(batch).await?; } Ok(()) } async fn tell(&mut self) -> Result { - Ok(self.tell().await?) + Ok(self.writer.tell().await?) 
} async fn finish(&mut self) -> Result<(u32, DataFile)> { let field_ids = self + .writer .field_id_to_column_indices() .iter() .map(|(field_id, _)| *field_id) .collect::>(); let column_indices = self + .writer .field_id_to_column_indices() .iter() .map(|(_, column_index)| *column_index) .collect::>(); let data_file = DataFile::new( - self.path(), + std::mem::take(&mut self.path), field_ids, column_indices, MAJOR_VERSION as u32, MINOR_VERSION_NEXT as u32, ); - let num_rows = self.finish().await? as u32; + let num_rows = self.writer.finish().await? as u32; Ok((num_rows, data_file)) } } @@ -368,12 +375,13 @@ pub async fn open_writer( )) } else { let writer = object_store.create(&full_path).await?; - Box::new(v2::writer::FileWriter::try_new( - writer, - filename, - schema.clone(), - FileWriterOptions::default(), - )?) as Box + let file_writer = + v2::writer::FileWriter::try_new(writer, schema.clone(), FileWriterOptions::default())?; + let writer_adapter = V2WriterAdapter { + writer: file_writer, + path: filename, + }; + Box::new(writer_adapter) as Box }; Ok(writer) } diff --git a/rust/lance/src/index.rs b/rust/lance/src/index.rs index d536b6a936..8ac993a298 100644 --- a/rust/lance/src/index.rs +++ b/rust/lance/src/index.rs @@ -33,6 +33,7 @@ use lance_table::format::Index as IndexMetadata; use lance_table::format::{Fragment, SelfDescribingFileReader}; use lance_table::io::manifest::read_manifest_indexes; use roaring::RoaringBitmap; +use scalar::ScalarIndexParams; use serde_json::json; use snafu::{location, Location}; use tracing::instrument; @@ -201,7 +202,14 @@ impl DatasetIndexExt for Dataset { let index_id = Uuid::new_v4(); match (index_type, params.index_name()) { (IndexType::Scalar, LANCE_SCALAR_INDEX) => { - build_scalar_index(self, column, &index_id.to_string()).await?; + let params = params + .as_any() + .downcast_ref::() + .ok_or_else(|| Error::Index { + message: "Scalar index type must take a ScalarIndexParams".to_string(), + location: location!(), + })?; + 
build_scalar_index(self, column, &index_id.to_string(), params).await?; } (IndexType::Vector, LANCE_VECTOR_INDEX) => { // Vector index params. diff --git a/rust/lance/src/index/prefilter.rs b/rust/lance/src/index/prefilter.rs index a7a8df9cd9..9aa84fd188 100644 --- a/rust/lance/src/index/prefilter.rs +++ b/rust/lance/src/index/prefilter.rs @@ -47,7 +47,7 @@ pub struct DatasetPreFilter { pub(super) deleted_ids: Option>>>, pub(super) filtered_ids: Option>>, // When the tasks are finished this is the combined filter - pub(super) final_mask: Mutex>, + pub(super) final_mask: Mutex>>, } impl DatasetPreFilter { @@ -233,7 +233,7 @@ impl PreFilter for DatasetPreFilter { if let Some(deleted_ids) = &self.deleted_ids { combined = combined & (*deleted_ids.get_ready()).clone(); } - combined + Arc::new(combined) }); Ok(()) @@ -251,11 +251,14 @@ impl PreFilter for DatasetPreFilter { /// This method must be called after `wait_for_ready` #[instrument(level = "debug", skip_all)] fn filter_row_ids<'a>(&self, row_ids: Box + 'a>) -> Vec { - let final_mask = self.final_mask.lock().unwrap(); - final_mask + let final_mask = self + .final_mask + .lock() + .unwrap() .get() .expect("filter_row_ids called without call to wait_for_ready") - .selected_indices(row_ids) + .clone(); + final_mask.selected_indices(row_ids) } } diff --git a/rust/lance/src/index/scalar.rs b/rust/lance/src/index/scalar.rs index 1e9ff03a85..9acb722d7f 100644 --- a/rust/lance/src/index/scalar.rs +++ b/rust/lance/src/index/scalar.rs @@ -11,6 +11,7 @@ use datafusion::physical_plan::SendableRecordBatchStream; use lance_datafusion::{chunker::chunk_concat_stream, exec::LanceExecutionOptions}; use lance_index::{ scalar::{ + bitmap::{train_bitmap_index, BitmapIndex, BITMAP_LOOKUP_NAME}, btree::{train_btree_index, BTreeIndex, BtreeTrainingSource}, flat::FlatIndexMetadata, lance_format::LanceIndexStore, @@ -32,8 +33,16 @@ use super::IndexParams; pub const LANCE_SCALAR_INDEX: &str = "__lance_scalar_index"; +pub enum ScalarIndexType 
{ + BTree, + Bitmap, +} + #[derive(Default)] -pub struct ScalarIndexParams {} +pub struct ScalarIndexParams { + /// If set then always use the given index type and skip auto-detection + pub force_index_type: Option, +} impl IndexParams for ScalarIndexParams { fn as_any(&self) -> &dyn std::any::Any { @@ -78,9 +87,14 @@ impl BtreeTrainingSource for TrainingRequest { } } -/// Build a Vector Index -#[instrument(level = "debug", skip(dataset))] -pub async fn build_scalar_index(dataset: &Dataset, column: &str, uuid: &str) -> Result<()> { +/// Build a Scalar Index +#[instrument(level = "debug", skip_all)] +pub async fn build_scalar_index( + dataset: &Dataset, + column: &str, + uuid: &str, + params: &ScalarIndexParams, +) -> Result<()> { let training_request = Box::new(TrainingRequest { dataset: Arc::new(dataset.clone()), column: column.to_string(), @@ -97,15 +111,27 @@ pub async fn build_scalar_index(dataset: &Dataset, column: &str, uuid: &str) -> location: location!(), }); } - let flat_index_trainer = FlatIndexMetadata::new(field.data_type()); let index_store = LanceIndexStore::from_dataset(dataset, uuid); - train_btree_index(training_request, &flat_index_trainer, &index_store).await + match params.force_index_type { + Some(ScalarIndexType::Bitmap) => train_bitmap_index(training_request, &index_store).await, + _ => { + let flat_index_trainer = FlatIndexMetadata::new(field.data_type()); + train_btree_index(training_request, &flat_index_trainer, &index_store).await + } + } } pub async fn open_scalar_index(dataset: &Dataset, uuid: &str) -> Result> { let index_store = Arc::new(LanceIndexStore::from_dataset(dataset, uuid)); - // Currently we assume all scalar indices are btree indices. 
In the future, if this is not the - // case, we may need to store a metadata file in the index directory with scalar index metadata - let btree_index = BTreeIndex::load(index_store).await?; - Ok(btree_index as Arc) + let index_dir = dataset.indices_dir().child(uuid); + // This works at the moment, since we only have two index types, may need to introduce better + // detection method in the future. + let bitmap_page_lookup = index_dir.child(BITMAP_LOOKUP_NAME); + if dataset.object_store.exists(&bitmap_page_lookup).await? { + let bitmap_index = BitmapIndex::load(index_store).await?; + Ok(bitmap_index as Arc) + } else { + let btree_index = BTreeIndex::load(index_store).await?; + Ok(btree_index as Arc) + } } diff --git a/rust/lance/src/index/vector.rs b/rust/lance/src/index/vector.rs index 23c77f32a6..acd39dbbb5 100644 --- a/rust/lance/src/index/vector.rs +++ b/rust/lance/src/index/vector.rs @@ -16,13 +16,12 @@ mod utils; #[cfg(test)] mod fixture_test; -use arrow::datatypes::Float32Type; use builder::IvfIndexBuilder; use lance_file::reader::FileReader; use lance_index::vector::flat::index::{FlatIndex, FlatQuantizer}; use lance_index::vector::hnsw::HNSW; use lance_index::vector::ivf::storage::IvfModel; -use lance_index::vector::pq::ProductQuantizerImpl; +use lance_index::vector::pq::ProductQuantizer; use lance_index::vector::v3::shuffler::IvfShuffler; use lance_index::vector::{ hnsw::{ @@ -260,6 +259,7 @@ pub(crate) async fn build_vector_index( location: location!(), }); }; + build_ivf_pq_index( dataset, column, @@ -422,7 +422,7 @@ pub(crate) async fn open_vector_index( location: location!(), }); }; - let pq = lance_index::vector::pq::builder::from_proto(pq_proto, metric_type)?; + let pq = ProductQuantizer::from_proto(pq_proto, metric_type)?; last_stage = Some(Arc::new(PQIndex::new(pq, metric_type))); } Some(Stage::Diskann(_)) => { @@ -473,7 +473,7 @@ pub(crate) async fn open_vector_index_v2( let ivf_data = IvfModel::load(&reader).await?; let options = HNSWIndexOptions 
{ use_residual: true }; - let hnsw = HNSWIndex::>::try_new( + let hnsw = HNSWIndex::::try_new( reader.object_reader.clone(), aux_reader.into(), options, diff --git a/rust/lance/src/index/vector/builder.rs b/rust/lance/src/index/vector/builder.rs index 34199cd092..49ae55a084 100644 --- a/rust/lance/src/index/vector/builder.rs +++ b/rust/lance/src/index/vector/builder.rs @@ -412,7 +412,6 @@ impl IvfIndexBuilde let writer = object_store.create(&path).await?; let mut writer = FileWriter::try_new( writer, - path.to_string(), storage.schema().as_ref().try_into()?, Default::default(), )?; @@ -424,19 +423,14 @@ impl IvfIndexBuilde // build the sub index, with in-memory storage let index_len = { - let distance_type = match self.distance_type { - DistanceType::Cosine | DistanceType::Dot => DistanceType::L2, - _ => self.distance_type, - }; let vectors = batch[&self.column].as_fixed_size_list(); - let flat_storage = FlatStorage::new(vectors.clone(), distance_type); + let flat_storage = FlatStorage::new(vectors.clone(), self.distance_type); let sub_index = S::index_vectors(&flat_storage, self.sub_index_params.clone())?; let path = self.temp_dir.child(format!("index_part{}", part_id)); let writer = object_store.create(&path).await?; let index_batch = sub_index.to_batch()?; let mut writer = FileWriter::try_new( writer, - path.to_string(), index_batch.schema_ref().as_ref().try_into()?, Default::default(), )?; @@ -467,7 +461,6 @@ impl IvfIndexBuilde let mut storage_writer = None; let mut index_writer = FileWriter::try_new( self.dataset.object_store().create(&index_path).await?, - index_path.to_string(), S::schema().as_ref().try_into()?, Default::default(), )?; @@ -504,7 +497,6 @@ impl IvfIndexBuilde if storage_writer.is_none() { storage_writer = Some(FileWriter::try_new( self.dataset.object_store().create(&storage_path).await?, - storage_path.to_string(), batch.schema_ref().as_ref().try_into()?, Default::default(), )?); @@ -616,6 +608,7 @@ impl IvfIndexBuilde mod tests { use 
std::{collections::HashMap, ops::Range, sync::Arc}; + use arrow::datatypes::Float32Type; use arrow_array::{FixedSizeListArray, RecordBatch, RecordBatchIterator}; use arrow_schema::{DataType, Field, Schema}; use lance_arrow::FixedSizeListArrayExt; @@ -642,7 +635,7 @@ mod tests { test_uri: &str, range: Range, ) -> (Dataset, Arc) { - let vectors = generate_random_array_with_range(1000 * DIM, range); + let vectors = generate_random_array_with_range::(1000 * DIM, range); let metadata: HashMap = vec![("test".to_string(), "ivf_pq".to_string())] .into_iter() .collect(); diff --git a/rust/lance/src/index/vector/ivf.rs b/rust/lance/src/index/vector/ivf.rs index c41ede6d80..6558686fe1 100644 --- a/rust/lance/src/index/vector/ivf.rs +++ b/rust/lance/src/index/vector/ivf.rs @@ -24,6 +24,7 @@ use futures::{ stream::{self, StreamExt}, TryStreamExt, }; +use io::write_hnsw_quantization_index_partitions; use lance_arrow::*; use lance_core::{datatypes::Field, traits::DatasetTakeRows, Error, Result, ROW_ID_FIELD}; use lance_file::{ @@ -57,14 +58,11 @@ use lance_io::{ stream::RecordBatchStream, traits::{Reader, WriteExt, Writer}, }; +use lance_linalg::distance::{DistanceType, Dot, MetricType, L2}; use lance_linalg::{ distance::Normalize, kernels::{normalize_arrow, normalize_fsl}, }; -use lance_linalg::{ - distance::{DistanceType, Dot, MetricType, L2}, - MatrixView, -}; use log::info; use object_store::path::Path; use rand::{rngs::SmallRng, SeedableRng}; @@ -75,9 +73,7 @@ use snafu::{location, Location}; use tracing::instrument; use uuid::Uuid; -use self::io::write_hnsw_quantization_index_partitions; - -use super::builder::IvfIndexBuilder; +use super::{builder::IvfIndexBuilder, utils::PartitionLoadLock}; use super::{ pq::{build_pq_model, PQIndex}, utils::maybe_sample_training_data, @@ -94,7 +90,7 @@ use crate::{ session::Session, }; -mod builder; +pub mod builder; mod io; pub mod v2; @@ -110,6 +106,8 @@ pub struct IVFIndex { /// Index in each partition. 
sub_index: Arc, + partition_locks: PartitionLoadLock, + metric_type: MetricType, // The session cache holds an Arc to this object so we need to @@ -143,6 +141,8 @@ impl IVFIndex { location: location!(), }); } + + let num_partitions = ivf.num_partitions(); Ok(Self { uuid: uuid.to_owned(), session: Arc::downgrade(&session), @@ -150,6 +150,7 @@ impl IVFIndex { reader, sub_index, metric_type, + partition_locks: PartitionLoadLock::new(num_partitions), }) } @@ -174,32 +175,40 @@ impl IVFIndex { let part_index = if let Some(part_idx) = session.index_cache.get_vector(&cache_key) { part_idx } else { - if partition_id >= self.ivf.num_partitions() { - return Err(Error::Index { - message: format!( - "partition id {} is out of range of {} partitions", - partition_id, - self.ivf.num_partitions() - ), - location: location!(), - }); - } + let mtx = self.partition_locks.get_partition_mutex(partition_id); + let _guard = mtx.lock().await; + // check the cache again, as the partition may have been loaded by another + // thread that held the lock on loading the partition + if let Some(part_idx) = session.index_cache.get_vector(&cache_key) { + part_idx + } else { + if partition_id >= self.ivf.num_partitions() { + return Err(Error::Index { + message: format!( + "partition id {} is out of range of {} partitions", + partition_id, + self.ivf.num_partitions() + ), + location: location!(), + }); + } - let range = self.ivf.row_range(partition_id); - let idx = self - .sub_index - .load_partition( - self.reader.clone(), - range.start, - range.end - range.start, - partition_id, - ) - .await?; - let idx: Arc = idx.into(); - if write_cache { - session.index_cache.insert_vector(&cache_key, idx.clone()); + let range = self.ivf.row_range(partition_id); + let idx = self + .sub_index + .load_partition( + self.reader.clone(), + range.start, + range.end - range.start, + partition_id, + ) + .await?; + let idx: Arc = idx.into(); + if write_cache { + session.index_cache.insert_vector(&cache_key, 
idx.clone()); + } + idx } - idx }; Ok(part_index) } @@ -236,7 +245,7 @@ pub(crate) async fn optimize_vector_indices( existing_indices: &[Arc], options: &OptimizeOptions, ) -> Result<(Uuid, usize)> { - // Senity check the indices + // Sanity check the indices if existing_indices.is_empty() { return Err(Error::Index { message: "optimizing vector index: no existing index found".to_string(), @@ -589,16 +598,7 @@ async fn optimize_ivf_hnsw_indices( let quantization_metadata = match &quantizer { Quantizer::Flat(_) => None, Quantizer::Product(pq) => { - let mat = MatrixView::::new( - Arc::new( - pq.codebook_as_fsl() - .values() - .as_primitive::() - .clone(), - ), - pq.dimension(), - ); - let codebook_tensor = pb::Tensor::from(&mat); + let codebook_tensor = pb::Tensor::try_from(&pq.codebook)?; let codebook_pos = aux_writer.tell().await?; aux_writer .object_writer @@ -903,7 +903,7 @@ pub struct IvfPQIndexMetadata { pub(crate) ivf: IvfModel, /// Product Quantizer - pub(crate) pq: Arc, + pub(crate) pq: ProductQuantizer, /// Transforms to be applied before search. 
transforms: Vec, @@ -931,9 +931,9 @@ impl TryFrom<&IvfPQIndexMetadata> for pb::Index { )?)), }, pb::VectorIndexStage { - stage: Some(pb::vector_index_stage::Stage::Pq( - idx.pq.as_ref().try_into()?, - )), + stage: Some(pb::vector_index_stage::Stage::Pq(pb::Pq::try_from( + &idx.pq, + )?)), }, ]); @@ -1044,7 +1044,7 @@ fn sanity_check_params(ivf: &IvfBuildParams, pq: &PQBuildParams) -> Result<()> { /// /// Visibility: pub(super) for testing #[instrument(level = "debug", skip_all, name = "build_ivf_model")] -pub(super) async fn build_ivf_model( +pub async fn build_ivf_model( dataset: &Dataset, column: &str, dim: usize, @@ -1103,7 +1103,7 @@ async fn build_ivf_model_and_pq( metric_type: MetricType, ivf_params: &IvfBuildParams, pq_params: &PQBuildParams, -) -> Result<(IvfModel, Arc)> { +) -> Result<(IvfModel, ProductQuantizer)> { sanity_check_params(ivf_params, pq_params)?; info!( @@ -1154,7 +1154,11 @@ async fn load_precomputed_partitions_if_available( match &ivf_params.precomputed_partitons_file { Some(file) => { info!("Loading precomputed partitions from file: {}", file); - let ds = DatasetBuilder::from_uri(file).load().await?; + let mut builder = DatasetBuilder::from_uri(file); + if let Some(storage_options) = &ivf_params.storage_options { + builder = builder.with_storage_options(storage_options.clone()); + } + let ds = builder.load().await?; let stream = ds.scan().try_into_stream().await?; Ok(Some( load_precomputed_partitions(stream, ds.count_rows(None).await?).await?, @@ -1362,7 +1366,7 @@ async fn write_ivf_pq_file( uuid: &str, transformers: &[Box], mut ivf: IvfModel, - pq: Arc, + pq: ProductQuantizer, metric_type: MetricType, stream: impl RecordBatchStream + Unpin + 'static, precomputed_partitons: Option>, @@ -1402,7 +1406,7 @@ async fn write_ivf_pq_file( let metadata = IvfPQIndexMetadata { name: index_name.to_string(), column: column.to_string(), - dimension: pq.dimension() as u32, + dimension: pq.dimension as u32, dataset_version: dataset.version().version, 
metric_type, ivf, @@ -1486,16 +1490,7 @@ async fn write_ivf_hnsw_file( let quantization_metadata = match &quantizer { Quantizer::Flat(_) => None, Quantizer::Product(pq) => { - let mat = MatrixView::::new( - Arc::new( - pq.codebook_as_fsl() - .values() - .as_primitive::() - .clone(), - ), - pq.dimension(), - ); - let codebook_tensor = pb::Tensor::from(&mat); + let codebook_tensor = pb::Tensor::try_from(&pq.codebook)?; let codebook_pos = aux_writer.tell().await?; aux_writer .object_writer @@ -1632,6 +1627,7 @@ mod tests { use lance_core::ROW_ID; use lance_index::vector::sq::builder::SQBuildParams; use lance_linalg::distance::l2_distance_batch; + use lance_linalg::MatrixView; use lance_testing::datagen::{ generate_random_array, generate_random_array_with_range, generate_random_array_with_seed, generate_scaled_random_array, sample_without_replacement, @@ -1856,7 +1852,7 @@ mod tests { test_uri: &str, range: Range, ) -> (Dataset, Arc) { - let vectors = generate_random_array_with_range(1000 * DIM, range); + let vectors = generate_random_array_with_range::(1000 * DIM, range); let metadata: HashMap = vec![("test".to_string(), "ivf_pq".to_string())] .into_iter() .collect(); @@ -2654,7 +2650,7 @@ mod tests { // PQ code is on residual space pq_idx .pq - .codebook_as_fsl() + .codebook .values() .as_primitive::() .values() diff --git a/rust/lance/src/index/vector/ivf/builder.rs b/rust/lance/src/index/vector/ivf/builder.rs index 5b06d60619..f3bf9e2e34 100644 --- a/rust/lance/src/index/vector/ivf/builder.rs +++ b/rust/lance/src/index/vector/ivf/builder.rs @@ -36,7 +36,7 @@ pub(super) async fn build_partitions( data: impl RecordBatchStream + Unpin + 'static, column: &str, ivf: &mut IvfModel, - pq: Arc, + pq: ProductQuantizer, metric_type: MetricType, part_range: Range, precomputed_partitons: Option>, diff --git a/rust/lance/src/index/vector/ivf/io.rs b/rust/lance/src/index/vector/ivf/io.rs index 0e7e4ebd7e..eaf0014a7a 100644 --- a/rust/lance/src/index/vector/ivf/io.rs +++ 
b/rust/lance/src/index/vector/ivf/io.rs @@ -202,7 +202,7 @@ pub(super) async fn write_pq_partitions( let fsl = Arc::new( FixedSizeListArray::try_new_from_values( pq_code.as_ref().clone(), - pq_index.pq.num_sub_vectors() as i32, + pq_index.pq.code_dim() as i32, ) .unwrap(), ); @@ -523,7 +523,7 @@ async fn build_and_write_pq_storage( metric_type: MetricType, row_ids: Arc, code_array: Vec>, - pq: Arc, + pq: ProductQuantizer, mut writer: FileWriter, ) -> Result<()> { let storage = spawn_cpu(move || { diff --git a/rust/lance/src/index/vector/ivf/v2.rs b/rust/lance/src/index/vector/ivf/v2.rs index 0e913ec02a..07fd3fbc08 100644 --- a/rust/lance/src/index/vector/ivf/v2.rs +++ b/rust/lance/src/index/vector/ivf/v2.rs @@ -47,11 +47,15 @@ use moka::sync::Cache; use object_store::path::Path; use prost::Message; use roaring::RoaringBitmap; +use serde_json::json; use snafu::{location, Location}; use tracing::instrument; use crate::{ - index::{vector::VectorIndex, PreFilter}, + index::{ + vector::{utils::PartitionLoadLock, VectorIndex}, + PreFilter, + }, session::Session, }; @@ -78,6 +82,8 @@ pub struct IVFIndex { /// Index in each partition. 
partition_cache: Cache>>, + partition_locks: PartitionLoadLock, + distance_type: DistanceType, // The session cache holds an Arc to this object so we need to @@ -167,12 +173,14 @@ impl IVFIndex { .await?; let storage = IvfQuantizationStorage::try_new(storage_reader).await?; + let num_partitions = ivf.num_partitions(); Ok(Self { uuid, ivf, reader: index_reader, storage, partition_cache: Cache::new(DEFAULT_INDEX_CACHE_SIZE as u64), + partition_locks: PartitionLoadLock::new(num_partitions), sub_index_metadata, distance_type, session, @@ -201,38 +209,48 @@ impl IVFIndex { }); } - let schema = Arc::new(self.reader.schema().as_ref().into()); - let batch = match self.reader.metadata().num_rows { - 0 => RecordBatch::new_empty(schema), - _ => { - let batches = self - .reader - .read_stream( - ReadBatchParams::Range(self.ivf.row_range(partition_id)), - u32::MAX, - 1, - FilterExpression::no_filter(), - )? - .try_collect::>() - .await?; - concat_batches(&schema, batches.iter())? + let mtx = self.partition_locks.get_partition_mutex(partition_id); + let _guard = mtx.lock().await; + + // check the cache again, as the partition may have been loaded by another + // thread that held the lock on loading the partition + if let Some(part_idx) = self.partition_cache.get(&cache_key) { + part_idx + } else { + let schema = Arc::new(self.reader.schema().as_ref().into()); + let batch = match self.reader.metadata().num_rows { + 0 => RecordBatch::new_empty(schema), + _ => { + let batches = self + .reader + .read_stream( + ReadBatchParams::Range(self.ivf.row_range(partition_id)), + u32::MAX, + 1, + FilterExpression::no_filter(), + )? + .try_collect::>() + .await?; + concat_batches(&schema, batches.iter())? 
+ } + }; + let batch = batch.add_metadata( + S::metadata_key().to_owned(), + self.sub_index_metadata[partition_id].clone(), + )?; + let idx = S::load(batch)?; + let storage = self.load_partition_storage(partition_id).await?; + let partition_entry = Arc::new(PartitionEntry { + index: idx, + storage, + }); + if write_cache { + self.partition_cache + .insert(cache_key.clone(), partition_entry.clone()); } - }; - let batch = batch.add_metadata( - S::metadata_key().to_owned(), - self.sub_index_metadata[partition_id].clone(), - )?; - let idx = S::load(batch)?; - let storage = self.load_partition_storage(partition_id).await?; - let partition_entry = Arc::new(PartitionEntry { - index: idx, - storage, - }); - if write_cache { - self.partition_cache - .insert(cache_key.clone(), partition_entry.clone()); + + partition_entry } - partition_entry }; Ok(part_entry) @@ -304,7 +322,11 @@ impl Index for IVFIndex, ) -> (Dataset, Arc) { - let vectors = generate_random_array_with_range(1000 * DIM, range); + let vectors = generate_random_array_with_range::(1000 * DIM, range); let metadata: HashMap = vec![("test".to_string(), "ivf_pq".to_string())] .into_iter() .collect(); @@ -721,4 +743,47 @@ mod tests { assert_eq!(index["sub_index"]["index_type"].as_str().unwrap(), "HNSW"); } } + + #[tokio::test] + async fn test_index_stats_empty_partition() { + let test_dir = tempdir().unwrap(); + let test_uri = test_dir.path().to_str().unwrap(); + + let nlist = 1000; + let (mut dataset, _) = generate_test_dataset(test_uri, 0.0..1.0).await; + + let ivf_params = IvfBuildParams::new(nlist); + let sq_params = SQBuildParams::default(); + let hnsw_params = HnswBuildParams::default(); + let params = VectorIndexParams::with_ivf_hnsw_sq_params( + DistanceType::L2, + ivf_params, + hnsw_params, + sq_params, + ); + + dataset + .create_index( + &["vector"], + IndexType::Vector, + Some("test_index".to_owned()), + ¶ms, + true, + ) + .await + .unwrap(); + + let stats = 
dataset.index_statistics("test_index").await.unwrap(); + let stats: serde_json::Value = serde_json::from_str(stats.as_str()).unwrap(); + + assert_eq!(stats["index_type"].as_str().unwrap(), "IVF_HNSW_SQ"); + for index in stats["indices"].as_array().unwrap() { + assert_eq!(index["index_type"].as_str().unwrap(), "IVF_HNSW_SQ"); + assert_eq!( + index["num_partitions"].as_number().unwrap(), + &serde_json::Number::from(nlist) + ); + assert_eq!(index["sub_index"]["index_type"].as_str().unwrap(), "HNSW"); + } + } } diff --git a/rust/lance/src/index/vector/pq.rs b/rust/lance/src/index/vector/pq.rs index 81ece189f0..af846ea44d 100644 --- a/rust/lance/src/index/vector/pq.rs +++ b/rust/lance/src/index/vector/pq.rs @@ -5,7 +5,6 @@ use std::sync::Arc; use std::{any::Any, collections::HashMap}; use arrow::compute::concat; -use arrow_array::types::{Float16Type, Float32Type, Float64Type}; use arrow_array::UInt32Array; use arrow_array::{ cast::{as_primitive_array, AsArray}, @@ -36,7 +35,7 @@ use snafu::{location, Location}; use tracing::{instrument, span, Level}; // Re-export -pub use lance_index::vector::pq::{PQBuildParams, ProductQuantizerImpl}; +pub use lance_index::vector::pq::PQBuildParams; use lance_linalg::kernels::normalize_fsl; use super::VectorIndex; @@ -50,7 +49,7 @@ use crate::{Error, Result}; #[derive(Clone)] pub struct PQIndex { /// Product quantizer. - pub pq: Arc, + pub pq: ProductQuantizer, /// PQ code pub code: Option>, @@ -83,8 +82,8 @@ impl std::fmt::Debug for PQIndex { write!( f, "PQ(m={}, nbits={}, {})", - self.pq.num_sub_vectors(), - self.pq.num_bits(), + self.pq.code_dim(), + self.pq.num_bits, self.metric_type ) } @@ -92,7 +91,7 @@ impl std::fmt::Debug for PQIndex { impl PQIndex { /// Load a PQ index (page) from the disk. 
- pub(crate) fn new(pq: Arc, metric_type: MetricType) -> Self { + pub(crate) fn new(pq: ProductQuantizer, metric_type: MetricType) -> Self { Self { code: None, row_ids: None, @@ -145,9 +144,9 @@ impl Index for PQIndex { fn statistics(&self) -> Result { Ok(json!({ "index_type": "PQ", - "nbits": self.pq.num_bits(), - "num_sub_vectors": self.pq.num_sub_vectors(), - "dimension": self.pq.dimension(), + "nbits": self.pq.num_bits, + "num_sub_vectors": self.pq.code_dim(), + "dimension": self.pq.dimension, "metric_type": self.metric_type.to_string(), })) } @@ -190,7 +189,7 @@ impl VectorIndex for PQIndex { let pq = self.pq.clone(); let query = query.clone(); - let num_sub_vectors = self.pq.num_sub_vectors() as i32; + let num_sub_vectors = self.pq.code_dim() as i32; spawn_cpu(move || { let (code, row_ids) = if pre_filter.is_empty() { Ok((code, row_ids)) @@ -245,7 +244,7 @@ impl VectorIndex for PQIndex { offset: usize, length: usize, ) -> Result> { - let pq_code_length = self.pq.num_sub_vectors() * length; + let pq_code_length = self.pq.code_dim() * length; let pq_code = read_fixed_stride_array( reader.as_ref(), &DataType::UInt8, @@ -287,7 +286,7 @@ impl VectorIndex for PQIndex { .as_ref() .unwrap() .values() - .chunks_exact(self.pq.num_sub_vectors()); + .chunks_exact(self.pq.code_dim()); let row_ids = self.row_ids.as_ref().unwrap().values().iter(); let remapped = row_ids .zip(code) @@ -334,16 +333,16 @@ impl VectorIndex for PQIndex { /// - `metric_type`: The metric type of the vectors. /// - `params`: The parameters to train the PQ model. /// - `ivf`: If provided, the IVF model to compute the residual for PQ training. 
-pub(super) async fn build_pq_model( +pub async fn build_pq_model( dataset: &Dataset, column: &str, dim: usize, metric_type: MetricType, params: &PQBuildParams, ivf: Option<&IvfModel>, -) -> Result> { +) -> Result { if let Some(codebook) = ¶ms.codebook { - let mt = if metric_type == MetricType::Cosine { + let dt = if metric_type == MetricType::Cosine { info!("Normalize training data for PQ training: Cosine"); MetricType::L2 } else { @@ -351,33 +350,20 @@ pub(super) async fn build_pq_model( }; return match codebook.data_type() { - DataType::Float16 => Ok(Arc::new(ProductQuantizerImpl::::new( + DataType::Float16 | DataType::Float32 | DataType::Float64 => Ok(ProductQuantizer::new( params.num_sub_vectors, params.num_bits as u32, dim, - Arc::new(codebook.as_primitive().clone()), - mt, - ))), - DataType::Float32 => Ok(Arc::new(ProductQuantizerImpl::::new( - params.num_sub_vectors, - params.num_bits as u32, - dim, - Arc::new(codebook.as_primitive().clone()), - mt, - ))), - DataType::Float64 => Ok(Arc::new(ProductQuantizerImpl::::new( - params.num_sub_vectors, - params.num_bits as u32, - dim, - Arc::new(codebook.as_primitive().clone()), - mt, - ))), - _ => { - return Err(Error::Index { - message: format!("Wrong codebook data type: {:?}", codebook.data_type()), - location: location!(), - }); - } + FixedSizeListArray::try_new_from_values( + codebook.slice(0, codebook.len()), + dim as i32, + )?, + dt, + )), + _ => Err(Error::Index { + message: format!("Wrong codebook data type: {:?}", codebook.data_type()), + location: location!(), + }), }; } info!( @@ -432,7 +418,7 @@ pub(crate) fn build_pq_storage( distance_type: DistanceType, row_ids: Arc, code_array: Vec>, - pq: Arc, + pq: ProductQuantizer, ) -> Result { let pq_arrs = code_array.iter().map(|a| a.as_ref()).collect::>(); let pq_column = concat(&pq_arrs)?; @@ -443,15 +429,11 @@ pub(crate) fn build_pq_storage( (pq.column(), pq_column, false), ])?; let pq_store = ProductQuantizationStorage::new( - pq.codebook_as_fsl() - 
.values() - .as_primitive::() - .clone() - .into(), + pq.codebook.clone(), pq_batch.clone(), - pq.num_bits(), - pq.num_sub_vectors(), - pq.dimension(), + pq.num_bits, + pq.code_dim(), + pq.dimension, distance_type, )?; @@ -461,6 +443,7 @@ pub(crate) fn build_pq_storage( mod tests { use super::*; use crate::index::vector::ivf::build_ivf_model; + use arrow::datatypes::Float32Type; use arrow_array::RecordBatchIterator; use arrow_schema::{Field, Schema}; use lance_index::vector::ivf::IvfBuildParams; @@ -473,7 +456,7 @@ mod tests { test_uri: &str, range: Range, ) -> (Dataset, Arc) { - let vectors = generate_random_array_with_range(1000 * DIM, range); + let vectors = generate_random_array_with_range::(1000 * DIM, range); let metadata: HashMap = vec![("test".to_string(), "ivf_pq".to_string())] .into_iter() .collect(); @@ -503,7 +486,7 @@ mod tests { let (dataset, _) = generate_dataset(test_uri, 100.0..120.0).await; - let centroids = generate_random_array_with_range(4 * DIM, -1.0..1.0); + let centroids = generate_random_array_with_range::(4 * DIM, -1.0..1.0); let fsl = FixedSizeListArray::try_new_from_values(centroids, DIM as i32).unwrap(); let ivf = IvfModel::new(fsl); let params = PQBuildParams::new(16, 8); @@ -511,11 +494,11 @@ mod tests { .await .unwrap(); - assert_eq!(pq.num_sub_vectors(), 16); - assert_eq!(pq.num_bits(), 8); - assert_eq!(pq.dimension(), DIM); + assert_eq!(pq.code_dim(), 16); + assert_eq!(pq.num_bits, 8); + assert_eq!(pq.dimension, DIM); - let codebook = pq.codebook_as_fsl(); + let codebook = pq.codebook.clone(); assert_eq!(codebook.len(), 256); codebook .values() @@ -550,11 +533,11 @@ mod tests { .await .unwrap(); - assert_eq!(pq.num_sub_vectors(), 16); - assert_eq!(pq.num_bits(), 8); - assert_eq!(pq.dimension(), DIM); + assert_eq!(pq.code_dim(), 16); + assert_eq!(pq.num_bits, 8); + assert_eq!(pq.dimension, DIM); - let codebook = pq.codebook_as_fsl(); + let codebook = pq.codebook.clone(); assert_eq!(codebook.len(), 256); codebook .values() @@ -575,7 
+558,7 @@ mod tests { ); let residual_query = ivf2.compute_residual(&row).unwrap(); - let pq_code = pq.transform(&residual_query).unwrap(); + let pq_code = pq.quantize(&residual_query).unwrap(); let distances = pq .compute_distances( &residual_query.value(0), diff --git a/rust/lance/src/index/vector/utils.rs b/rust/lance/src/index/vector/utils.rs index 2dc1c92cd8..661877ed53 100644 --- a/rust/lance/src/index/vector/utils.rs +++ b/rust/lance/src/index/vector/utils.rs @@ -8,6 +8,7 @@ use arrow_schema::Schema as ArrowSchema; use arrow_select::concat::concat_batches; use futures::stream::TryStreamExt; use snafu::{location, Location}; +use tokio::sync::Mutex; use crate::dataset::Dataset; use crate::{Error, Result}; @@ -65,3 +66,24 @@ pub async fn maybe_sample_training_data( })?; Ok(array.as_fixed_size_list().clone()) } + +#[derive(Debug)] +pub struct PartitionLoadLock { + partition_locks: Vec>>, +} + +impl PartitionLoadLock { + pub fn new(num_partitions: usize) -> Self { + Self { + partition_locks: (0..num_partitions) + .map(|_| Arc::new(Mutex::new(()))) + .collect(), + } + } + + pub fn get_partition_mutex(&self, partition_id: usize) -> Arc> { + let mtx = &self.partition_locks[partition_id]; + + mtx.clone() + } +} diff --git a/rust/lance/src/io/exec/knn.rs b/rust/lance/src/io/exec/knn.rs index f72f0be451..d17c1af0f4 100644 --- a/rust/lance/src/io/exec/knn.rs +++ b/rust/lance/src/io/exec/knn.rs @@ -306,7 +306,7 @@ pub fn new_knn_exec( let sub_index = ANNIvfSubIndexExec::try_new( Arc::new(ivf_node), dataset, - Arc::new(indices.to_vec()), + indices.to_vec(), query.clone(), prefilter_source, )?; @@ -479,7 +479,7 @@ pub struct ANNIvfSubIndexExec { dataset: Arc, - indices: Arc>, + indices: Vec, /// Vector Query. 
query: Query, @@ -495,7 +495,7 @@ impl ANNIvfSubIndexExec { pub fn try_new( input: Arc, dataset: Arc, - indices: Arc>, + indices: Vec, query: Query, prefilter_source: PreFilterSource, ) -> Result { diff --git a/rust/lance/src/session.rs b/rust/lance/src/session.rs index d2e781e3e5..aa1163d2fa 100644 --- a/rust/lance/src/session.rs +++ b/rust/lance/src/session.rs @@ -117,12 +117,13 @@ impl Default for Session { mod tests { use super::*; - use arrow_array::types::Float32Type; + use arrow_array::{FixedSizeListArray, Float32Array}; + use lance_arrow::FixedSizeListArrayExt; use std::sync::Arc; use crate::index::vector::pq::PQIndex; - use lance_index::vector::pq::ProductQuantizerImpl; - use lance_linalg::distance::MetricType; + use lance_index::vector::pq::ProductQuantizer; + use lance_linalg::distance::DistanceType; #[test] fn test_disable_index_cache() { @@ -130,14 +131,15 @@ mod tests { assert!(no_cache.index_cache.get_vector("abc").is_none()); let no_cache = Arc::new(no_cache); - let pq = Arc::new(ProductQuantizerImpl::::new( + let pq = ProductQuantizer::new( 1, 8, 1, - Arc::new(vec![0.0f32; 8].into()), - MetricType::L2, - )); - let idx = Arc::new(PQIndex::new(pq, MetricType::L2)); + FixedSizeListArray::try_new_from_values(Float32Array::from(vec![0.0f32; 8]), 1) + .unwrap(), + DistanceType::L2, + ); + let idx = Arc::new(PQIndex::new(pq, DistanceType::L2)); no_cache.index_cache.insert_vector("abc", idx); assert!(no_cache.index_cache.get_vector("abc").is_none()); @@ -149,14 +151,15 @@ mod tests { let session = Session::new(10, 1); let session = Arc::new(session); - let pq = Arc::new(ProductQuantizerImpl::::new( + let pq = ProductQuantizer::new( 1, 8, 1, - Arc::new(vec![0.0f32; 8].into()), - MetricType::L2, - )); - let idx = Arc::new(PQIndex::new(pq, MetricType::L2)); + FixedSizeListArray::try_new_from_values(Float32Array::from(vec![0.0f32; 8]), 1) + .unwrap(), + DistanceType::L2, + ); + let idx = Arc::new(PQIndex::new(pq, DistanceType::L2)); 
assert_eq!(session.index_cache.get_size(), 0); assert_eq!(session.index_cache.hit_rate(), 1.0); @@ -173,14 +176,15 @@ mod tests { assert_eq!(session.index_cache.get_size(), 1); for iter_idx in 0..100 { - let pq_other = Arc::new(ProductQuantizerImpl::::new( + let pq_other = ProductQuantizer::new( 1, 8, 1, - Arc::new(vec![0.0f32; 8].into()), - MetricType::L2, - )); - let idx_other = Arc::new(PQIndex::new(pq_other, MetricType::L2)); + FixedSizeListArray::try_new_from_values(Float32Array::from(vec![0.0f32; 8]), 1) + .unwrap(), + DistanceType::L2, + ); + let idx_other = Arc::new(PQIndex::new(pq_other, DistanceType::L2)); session .index_cache .insert_vector(format!("{iter_idx}").as_str(), idx_other.clone());