From 0e7a37a844b878285c9e252be6ea2082f4a932c2 Mon Sep 17 00:00:00 2001
From: usamoi
Date: Wed, 8 Jan 2025 20:42:42 +0800
Subject: [PATCH] feat: unify vchordrq and vchordrqfscan

Signed-off-by: usamoi
---
 Cargo.lock | 273 +---
 Cargo.toml | 12 +-
 crates/algorithm/Cargo.toml | 7 -
 crates/rabitq/src/binary.rs | 72 +-
 crates/rabitq/src/block.rs | 160 +--
 crates/rabitq/src/lib.rs | 94 +-
 crates/rabitq/src/utils.rs | 20 -
 crates/simd/src/f16.rs | 4 +-
 crates/simd/src/fast_scan/mod.rs | 306 +++-
 src/{vchordrq => }/algorithm/build.rs | 139 +-
 src/algorithm/freepages.rs | 61 +
 src/algorithm/insert.rs | 190 +++
 .../src/lib.rs => src/algorithm/mod.rs | 13 +-
 src/algorithm/operator.rs | 628 +++++++++
 src/algorithm/prewarm.rs | 99 ++
 src/algorithm/scan.rs | 171 +++
 src/algorithm/tape.rs | 349 +++++
 src/algorithm/tuples.rs | 1205 +++++++++++++++++
 src/algorithm/vacuum.rs | 311 +++++
 src/algorithm/vectors.rs | 133 ++
 src/{vchordrq => }/gucs/executing.rs | 0
 src/{vchordrq => }/gucs/mod.rs | 0
 src/{vchordrq => }/gucs/prewarm.rs | 0
 src/{vchordrq => }/index/am.rs | 448 ++++--
 src/{vchordrq => }/index/am_options.rs | 6 +-
 src/{vchordrq => }/index/am_scan.rs | 64 +-
 src/{vchordrq => }/index/functions.rs | 22 +-
 src/{vchordrq => }/index/mod.rs | 0
 src/{vchordrq => }/index/opclass.rs | 0
 src/index/utils.rs | 34 +
 src/lib.rs | 11 +-
 src/postgres.rs | 55 +-
 src/sql/finalize.sql | 30 -
 src/{vchordrq => }/types.rs | 0
 src/utils/k_means.rs | 76 +-
 src/utils/mod.rs | 1 +
 src/utils/pipe.rs | 14 +
 src/vchordrq/algorithm/insert.rs | 221 ---
 src/vchordrq/algorithm/mod.rs | 8 -
 src/vchordrq/algorithm/prewarm.rs | 82 --
 src/vchordrq/algorithm/rabitq.rs | 21 -
 src/vchordrq/algorithm/scan.rs | 189 ---
 src/vchordrq/algorithm/tuples.rs | 270 ----
 src/vchordrq/algorithm/vacuum.rs | 112 --
 src/vchordrq/algorithm/vectors.rs | 76 --
 src/vchordrq/index/utils.rs | 20 -
 src/vchordrq/mod.rs | 11 -
 src/vchordrqfscan/algorithm/build.rs | 445 ------
 src/vchordrqfscan/algorithm/insert.rs | 299 ----
 src/vchordrqfscan/algorithm/mod.rs | 7 -
 src/vchordrqfscan/algorithm/prewarm.rs | 102 --
 src/vchordrqfscan/algorithm/rabitq.rs | 22 -
 src/vchordrqfscan/algorithm/scan.rs | 193 ---
 src/vchordrqfscan/algorithm/tuples.rs | 140 --
 src/vchordrqfscan/algorithm/vacuum.rs | 139 --
 src/vchordrqfscan/gucs/executing.rs | 72 -
 src/vchordrqfscan/gucs/mod.rs | 14 -
 src/vchordrqfscan/gucs/prewarm.rs | 32 -
 src/vchordrqfscan/index/am.rs | 856 ------------
 src/vchordrqfscan/index/am_options.rs | 218 ---
 src/vchordrqfscan/index/am_scan.rs | 125 --
 src/vchordrqfscan/index/functions.rs | 26 -
 src/vchordrqfscan/index/mod.rs | 12 -
 src/vchordrqfscan/index/opclass.rs | 14 -
 src/vchordrqfscan/index/utils.rs | 20 -
 src/vchordrqfscan/mod.rs | 11 -
 src/vchordrqfscan/types.rs | 153 ---
 67 files changed, 4144 insertions(+), 4774 deletions(-)
 delete mode 100644 crates/algorithm/Cargo.toml
 delete mode 100644 crates/rabitq/src/utils.rs
 rename src/{vchordrq => }/algorithm/build.rs (78%)
 create mode 100644 src/algorithm/freepages.rs
 create mode 100644 src/algorithm/insert.rs
 rename crates/algorithm/src/lib.rs => src/algorithm/mod.rs (85%)
 create mode 100644 src/algorithm/operator.rs
 create mode 100644 src/algorithm/prewarm.rs
 create mode 100644 src/algorithm/scan.rs
 create mode 100644 src/algorithm/tape.rs
 create mode 100644 src/algorithm/tuples.rs
 create mode 100644 src/algorithm/vacuum.rs
 create mode 100644 src/algorithm/vectors.rs
 rename src/{vchordrq => }/gucs/executing.rs (100%)
 rename src/{vchordrq => }/gucs/mod.rs (100%)
 rename src/{vchordrq => }/gucs/prewarm.rs (100%)
 rename src/{vchordrq => }/index/am.rs (67%)
 rename src/{vchordrq => }/index/am_options.rs (97%)
 rename src/{vchordrq => }/index/am_scan.rs (64%)
 rename src/{vchordrq => }/index/functions.rs (58%)
 rename src/{vchordrq => }/index/mod.rs (100%)
 rename src/{vchordrq => }/index/opclass.rs (100%)
 create mode 100644 src/index/utils.rs
 rename src/{vchordrq => }/types.rs (100%)
 create mode 100644 src/utils/pipe.rs
 delete mode 100644 src/vchordrq/algorithm/insert.rs
 delete mode 100644 src/vchordrq/algorithm/mod.rs
 delete mode 100644 src/vchordrq/algorithm/prewarm.rs
 delete mode 100644 src/vchordrq/algorithm/rabitq.rs
 delete mode 100644 src/vchordrq/algorithm/scan.rs
 delete mode 100644 src/vchordrq/algorithm/tuples.rs
 delete mode 100644 src/vchordrq/algorithm/vacuum.rs
 delete mode 100644 src/vchordrq/algorithm/vectors.rs
 delete mode 100644 src/vchordrq/index/utils.rs
 delete mode 100644 src/vchordrq/mod.rs
 delete mode 100644 src/vchordrqfscan/algorithm/build.rs
 delete mode 100644 src/vchordrqfscan/algorithm/insert.rs
 delete mode 100644 src/vchordrqfscan/algorithm/mod.rs
 delete mode 100644 src/vchordrqfscan/algorithm/prewarm.rs
 delete mode 100644 src/vchordrqfscan/algorithm/rabitq.rs
 delete mode 100644 src/vchordrqfscan/algorithm/scan.rs
 delete mode 100644 src/vchordrqfscan/algorithm/tuples.rs
 delete mode 100644 src/vchordrqfscan/algorithm/vacuum.rs
 delete mode 100644 src/vchordrqfscan/gucs/executing.rs
 delete mode 100644 src/vchordrqfscan/gucs/mod.rs
 delete mode 100644 src/vchordrqfscan/gucs/prewarm.rs
 delete mode 100644 src/vchordrqfscan/index/am.rs
 delete mode 100644 src/vchordrqfscan/index/am_options.rs
 delete mode 100644 src/vchordrqfscan/index/am_scan.rs
 delete mode 100644 src/vchordrqfscan/index/functions.rs
 delete mode 100644 src/vchordrqfscan/index/mod.rs
 delete mode 100644 src/vchordrqfscan/index/opclass.rs
 delete mode 100644 src/vchordrqfscan/index/utils.rs
 delete mode 100644 src/vchordrqfscan/mod.rs
 delete mode 100644 src/vchordrqfscan/types.rs

diff --git a/Cargo.lock b/Cargo.lock
index 5d17e1a..f5214d8 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2,17 +2,6 @@
 # It is not intended for manual editing.
version = 4 -[[package]] -name = "ahash" -version = "0.7.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "891477e0c6a8957309ee5c45a6368af3ae14bb510732d2684ffa19af310920f9" -dependencies = [ - "getrandom", - "once_cell", - "version_check", -] - [[package]] name = "aho-corasick" version = "1.1.3" @@ -22,10 +11,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "algorithm" -version = "0.0.0" - [[package]] name = "always_equal" version = "0.0.0" @@ -87,14 +72,14 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn 2.0.94", + "syn", ] [[package]] name = "bitflags" -version = "2.6.0" +version = "2.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" +checksum = "1be3f42a67d6d345ecd59f675f3f012d6974981560836e938c22b424b85ce1be" [[package]] name = "bitvec" @@ -108,28 +93,6 @@ dependencies = [ "wyz", ] -[[package]] -name = "bytecheck" -version = "0.6.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23cdc57ce23ac53c931e88a43d06d070a6fd142f2617be5855eb75efc9beb1c2" -dependencies = [ - "bytecheck_derive", - "ptr_meta", - "simdutf8", -] - -[[package]] -name = "bytecheck_derive" -version = "0.6.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3db406d29fbcd95542e92559bed4d8ad92636d1ca8b3b72ede10b4bcc010e659" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] - [[package]] name = "bytemuck" version = "1.21.0" @@ -142,12 +105,6 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" -[[package]] -name = "bytes" -version = "1.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b" - [[package]] name = "cargo_toml" version = "0.19.2" @@ -160,9 +117,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.6" +version = "1.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d6dbb628b8f8555f86d0323c2eb39e3ec81901f4b83e091db8a6a76d316a333" +checksum = "c8293772165d9345bdaaa39b45b2109591e63fe5e6fbc23c6ff930a048aa310b" dependencies = [ "shlex", ] @@ -264,7 +221,7 @@ dependencies = [ "proc-macro2", "quote", "strsim", - "syn 2.0.94", + "syn", ] [[package]] @@ -275,7 +232,7 @@ checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" dependencies = [ "darling_core", "quote", - "syn 2.0.94", + "syn", ] [[package]] @@ -286,7 +243,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.94", + "syn", ] [[package]] @@ -316,7 +273,7 @@ checksum = "f282cfdfe92516eb26c2af8589c274c7c17681f5ecc03c18255fe741c6aa64eb" dependencies = [ "proc-macro2", "quote", - "syn 2.0.94", + "syn", ] [[package]] @@ -388,12 +345,13 @@ checksum = "1b43ede17f21864e81be2fa654110bf1e793774238d86ef8555c37e6519c0403" [[package]] name = "half" version = "2.4.1" -source = "git+https://github.com/tensorchord/half-rs.git#2d8a66092bee436aebb26ea7ac47d11150cda31d" +source = "git+https://github.com/tensorchord/half-rs.git?rev=3f9a8843d6722bd1833de2289347640ad8770146#3f9a8843d6722bd1833de2289347640ad8770146" dependencies = [ "cfg-if", "crunchy", - "rkyv", "serde", + "zerocopy 0.8.14", + "zerocopy-derive 0.8.14", ] [[package]] @@ -405,15 +363,6 @@ dependencies = [ "byteorder", ] -[[package]] -name = "hashbrown" 
-version = "0.12.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" -dependencies = [ - "ahash", -] - [[package]] name = "hashbrown" version = "0.15.2" @@ -560,7 +509,7 @@ checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.94", + "syn", ] [[package]] @@ -603,7 +552,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "62f822373a4fe84d4bb149bf54e584a7f4abec90e072ed49cda0edea5b95471f" dependencies = [ "equivalent", - "hashbrown 0.15.2", + "hashbrown", ] [[package]] @@ -668,9 +617,9 @@ checksum = "4ee93343901ab17bd981295f2cf0026d4ad018c7c31ba84549a4ddbb47a45104" [[package]] name = "log" -version = "0.4.22" +version = "0.4.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f" [[package]] name = "matrixmultiply" @@ -718,7 +667,7 @@ checksum = "254a5372af8fc138e36684761d3c0cdb758a4410e938babcff1c860ce14ddbfc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.94", + "syn", ] [[package]] @@ -825,7 +774,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b7cafe60d6cf8e62e1b9b2ea516a089c008945bb5a275416789e7db0bc199dc" dependencies = [ "memchr", - "thiserror 2.0.9", + "thiserror 2.0.11", "ucd-trie", ] @@ -877,7 +826,7 @@ dependencies = [ "proc-macro2", "quote", "shlex", - "syn 2.0.94", + "syn", "walkdir", ] @@ -900,7 +849,7 @@ dependencies = [ "pgrx-sql-entity-graph", "proc-macro2", "quote", - "syn 2.0.94", + "syn", ] [[package]] @@ -947,7 +896,7 @@ dependencies = [ "petgraph", "proc-macro2", "quote", - "syn 2.0.94", + "syn", "thiserror 1.0.69", "unescape", ] @@ -958,7 +907,7 @@ version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" dependencies = [ - "zerocopy", + "zerocopy 0.7.35", ] [[package]] @@ -980,38 +929,18 @@ dependencies = [ "proc-macro-error-attr2", "proc-macro2", "quote", - "syn 2.0.94", + "syn", ] [[package]] name = "proc-macro2" -version = "1.0.92" +version = "1.0.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37d3544b3f2748c54e147655edb5025752e2303145b5aefb3c3ea2c78b973bb0" +checksum = "60946a68e5f9d28b0dc1c21bb8a97ee7d018a8b322fa57838ba31cc878e22d99" dependencies = [ "unicode-ident", ] -[[package]] -name = "ptr_meta" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0738ccf7ea06b608c10564b31debd4f5bc5e197fc8bfe088f68ae5ce81e7a4f1" -dependencies = [ - "ptr_meta_derive", -] - -[[package]] -name = "ptr_meta_derive" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16b845dbfca988fa33db069c0e230574d15a3088f147a87b64c7589eb662c9ac" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] - [[package]] name = "quote" version = "1.0.38" @@ -1140,44 +1069,6 @@ version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" -[[package]] -name = "rend" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71fe3824f5629716b1589be05dacd749f6aa084c87e00e016714a8cdfccc997c" -dependencies = [ - "bytecheck", -] - 
-[[package]] -name = "rkyv" -version = "0.7.45" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9008cd6385b9e161d8229e1f6549dd23c3d022f132a2ea37ac3a10ac4935779b" -dependencies = [ - "bitvec", - "bytecheck", - "bytes", - "hashbrown 0.12.3", - "ptr_meta", - "rend", - "rkyv_derive", - "seahash", - "tinyvec", - "uuid", -] - -[[package]] -name = "rkyv_derive" -version = "0.7.45" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "503d1d27590a2b0a3a4ca4c94755aa2875657196ecbf401a42eff41d7de532c0" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] - [[package]] name = "rustc-hash" version = "1.1.0" @@ -1268,14 +1159,14 @@ checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.94", + "syn", ] [[package]] name = "serde_json" -version = "1.0.134" +version = "1.0.135" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d00f4175c42ee48b15416f6193a959ba3a0d67fc699a0db9ad12df9f83991c7d" +checksum = "2b0d7ba2887406110130a978386c4e1befb98c674b4fba677954e4db976630d9" dependencies = [ "itoa", "memchr", @@ -1328,15 +1219,9 @@ version = "0.0.0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.94", + "syn", ] -[[package]] -name = "simdutf8" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e" - [[package]] name = "smallvec" version = "1.13.2" @@ -1382,20 +1267,9 @@ dependencies = [ [[package]] name = "syn" -version = "1.0.109" +version = "2.0.96" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] - -[[package]] -name = "syn" -version = "2.0.94" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "987bc0be1cdea8b10216bd06e2ca407d40b9543468fafd3ddfb02f36e77f71f3" +checksum = "d5d0adab1ae378d7f53bdebc67a39f1f151407ef230f0ce2883572f5d8985c80" dependencies = [ "proc-macro2", "quote", @@ -1410,7 +1284,7 @@ checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" dependencies = [ "proc-macro2", "quote", - "syn 2.0.94", + "syn", ] [[package]] @@ -1430,11 +1304,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.9" +version = "2.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f072643fd0190df67a8bab670c20ef5d8737177d6ac6b2e9a236cb096206b2cc" +checksum = "d452f284b73e6d76dd36758a0c8684b1d5be31f92b89d07fd5822175732206fc" dependencies = [ - "thiserror-impl 2.0.9", + "thiserror-impl 2.0.11", ] [[package]] @@ -1445,18 +1319,18 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.94", + "syn", ] [[package]] name = "thiserror-impl" -version = "2.0.9" +version = "2.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b50fa271071aae2e6ee85f842e2e28ba8cd2c5fb67f11fcb1fd70b276f9e7d4" +checksum = "26afc1baea8a989337eeb52b6e72a039780ce45c3edfcc9c5b9d112feeb173c2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.94", + "syn", ] [[package]] @@ -1469,21 +1343,6 @@ dependencies = [ "zerovec", ] -[[package]] -name = "tinyvec" -version = "1.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "022db8904dfa342efe721985167e9fcd16c29b226db4397ed752a761cfce81e8" 
-dependencies = [ - "tinyvec_macros", -] - -[[package]] -name = "tinyvec_macros" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" - [[package]] name = "toml" version = "0.8.19" @@ -1579,9 +1438,9 @@ checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" [[package]] name = "uuid" -version = "1.11.0" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a" +checksum = "b913a3b5fe84142e269d63cc62b64319ccaf89b748fc31fe025177f767a756c4" dependencies = [ "getrandom", ] @@ -1613,14 +1472,13 @@ dependencies = [ "proc-macro-error2", "proc-macro2", "quote", - "syn 2.0.94", + "syn", ] [[package]] name = "vchord" version = "0.0.0" dependencies = [ - "algorithm", "always_equal", "distance", "half 2.4.1", @@ -1632,12 +1490,13 @@ dependencies = [ "rand", "random_orthogonal_matrix", "rayon", - "rkyv", "serde", "simd", "toml", "validator", "vector", + "zerocopy 0.8.14", + "zerocopy-derive 0.8.14", ] [[package]] @@ -1650,12 +1509,6 @@ dependencies = [ "simd", ] -[[package]] -name = "version_check" -version = "0.9.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" - [[package]] name = "walkdir" version = "2.5.0" @@ -1674,9 +1527,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wide" -version = "0.7.30" +version = "0.7.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "58e6db2670d2be78525979e9a5f9c69d296fd7d670549fe9ebf70f8708cb5019" +checksum = "41b5576b9a81633f3e8df296ce0063042a73507636cbe956c61133dd7034ab22" dependencies = [ "bytemuck", "safe_arch", @@ -1797,9 +1650,9 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winnow" -version = "0.6.21" +version = "0.6.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6f5bb5257f2407a5425c6e749bfd9692192a73e70a6060516ac04f889087d68" +checksum = "c8d71a593cc5c42ad7876e2c1fda56f314f3754c084128833e64f1345ff8a03a" dependencies = [ "memchr", ] @@ -1854,7 +1707,7 @@ checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" dependencies = [ "proc-macro2", "quote", - "syn 2.0.94", + "syn", "synstructure", ] @@ -1865,7 +1718,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" dependencies = [ "byteorder", - "zerocopy-derive", + "zerocopy-derive 0.7.35", +] + +[[package]] +name = "zerocopy" +version = "0.8.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a367f292d93d4eab890745e75a778da40909cab4d6ff8173693812f79c4a2468" +dependencies = [ + "zerocopy-derive 0.8.14", ] [[package]] @@ -1876,7 +1738,18 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.94", + "syn", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3931cb58c62c13adec22e38686b559c86a30565e16ad6e8510a337cedc611e1" +dependencies = [ + "proc-macro2", + "quote", + "syn", ] [[package]] @@ -1896,7 +1769,7 @@ checksum = "595eed982f7d355beb85837f651fa22e90b3c044842dc7f2c2842c086f295808" dependencies = 
[ "proc-macro2", "quote", - "syn 2.0.94", + "syn", "synstructure", ] @@ -1919,5 +1792,5 @@ checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.94", + "syn", ] diff --git a/Cargo.toml b/Cargo.toml index 13272c2..3583ae2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,6 @@ pg16 = ["pgrx/pg16", "pgrx-catalog/pg16"] pg17 = ["pgrx/pg17", "pgrx-catalog/pg17"] [dependencies] -algorithm = { path = "./crates/algorithm" } always_equal = { path = "./crates/always_equal" } distance = { path = "./crates/distance" } rabitq = { path = "./crates/rabitq" } @@ -28,11 +27,8 @@ random_orthogonal_matrix = { path = "./crates/random_orthogonal_matrix" } simd = { path = "./crates/simd" } vector = { path = "./crates/vector" } -# lock rkyv version forever so that data is always compatible -rkyv = { version = "=0.7.45", features = ["validation"] } - half.workspace = true -log = "0.4.22" +log = "0.4.25" paste = "1" pgrx = { version = "=0.12.9", default-features = false, features = ["cshim"] } pgrx-catalog = "0.1.0" @@ -41,9 +37,11 @@ rayon = "1.10.0" serde.workspace = true toml = "0.8.19" validator = { version = "0.19.0", features = ["derive"] } +zerocopy = "0.8.14" +zerocopy-derive = "0.8.14" [patch.crates-io] -half = { git = "https://github.com/tensorchord/half-rs.git" } +half = { git = "https://github.com/tensorchord/half-rs.git", rev = "3f9a8843d6722bd1833de2289347640ad8770146" } [lints] workspace = true @@ -57,7 +55,7 @@ version = "0.0.0" edition = "2021" [workspace.dependencies] -half = { version = "2.4.1", features = ["rkyv", "serde"] } +half = { version = "2.4.1", features = ["serde", "zerocopy"] } rand = "0.8.5" serde = "1" diff --git a/crates/algorithm/Cargo.toml b/crates/algorithm/Cargo.toml deleted file mode 100644 index 56da6ce..0000000 --- a/crates/algorithm/Cargo.toml +++ /dev/null @@ -1,7 +0,0 @@ -[package] -name = "algorithm" -version.workspace = true -edition.workspace = true - -[lints] -workspace = true diff --git a/crates/rabitq/src/binary.rs b/crates/rabitq/src/binary.rs index 8916ffd..04bcb6f 100644 --- a/crates/rabitq/src/binary.rs +++ b/crates/rabitq/src/binary.rs @@ -1,67 +1,9 @@ use distance::Distance; use simd::Floating; -#[derive(Debug, Clone)] -pub struct Code { - pub dis_u_2: f32, - pub factor_ppc: f32, - pub factor_ip: f32, - pub factor_err: f32, - pub signs: Vec, -} - -impl Code { - pub fn t(&self) -> Vec { - use crate::utils::InfiniteByteChunks; - let mut result = Vec::new(); - for x in InfiniteByteChunks::<_, 64>::new(self.signs.iter().copied()) - .take(self.signs.len().div_ceil(64)) - { - let mut r = 0_u64; - for i in 0..64 { - r |= (x[i] as u64) << i; - } - result.push(r); - } - result - } -} - -pub fn code(dims: u32, vector: &[f32]) -> Code { - let sum_of_abs_x = f32::reduce_sum_of_abs_x(vector); - let sum_of_x_2 = f32::reduce_sum_of_x2(vector); - let dis_u = sum_of_x_2.sqrt(); - let x0 = sum_of_abs_x / (sum_of_x_2 * (dims as f32)).sqrt(); - let x_x0 = dis_u / x0; - let fac_norm = (dims as f32).sqrt(); - let max_x1 = 1.0f32 / (dims as f32 - 1.0).sqrt(); - let factor_err = 2.0f32 * max_x1 * (x_x0 * x_x0 - dis_u * dis_u).sqrt(); - let factor_ip = -2.0f32 / fac_norm * x_x0; - let cnt_pos = vector - .iter() - .map(|x| x.is_sign_positive() as i32) - .sum::(); - let cnt_neg = vector - .iter() - .map(|x| x.is_sign_negative() as i32) - .sum::(); - let factor_ppc = factor_ip * (cnt_pos - cnt_neg) as f32; - let mut signs = Vec::new(); - for i in 0..dims { - signs.push(vector[i as usize].is_sign_positive() as 
u8); - } - Code { - dis_u_2: sum_of_x_2, - factor_ppc, - factor_ip, - factor_err, - signs, - } -} - -pub type Lut = (f32, f32, f32, f32, (Vec, Vec, Vec, Vec)); - -pub fn preprocess(vector: &[f32]) -> Lut { +pub fn preprocess( + vector: &[f32], +) -> (f32, f32, f32, f32, (Vec, Vec, Vec, Vec)) { let dis_v_2 = f32::reduce_sum_of_x2(vector); let (k, b, qvector) = simd::quantize::quantize(vector, 15.0); let qvector_sum = if vector.len() <= 4369 { @@ -73,8 +15,7 @@ pub fn preprocess(vector: &[f32]) -> Lut { } pub fn process_lowerbound_l2( - _: u32, - lut: &Lut, + lut: &(f32, f32, f32, f32, (Vec, Vec, Vec, Vec)), (dis_u_2, factor_ppc, factor_ip, factor_err, t): (f32, f32, f32, f32, &[u64]), epsilon: f32, ) -> Distance { @@ -87,8 +28,7 @@ pub fn process_lowerbound_l2( } pub fn process_lowerbound_dot( - _: u32, - lut: &Lut, + lut: &(f32, f32, f32, f32, (Vec, Vec, Vec, Vec)), (_, factor_ppc, factor_ip, factor_err, t): (f32, f32, f32, f32, &[u64]), epsilon: f32, ) -> Distance { @@ -99,7 +39,7 @@ pub fn process_lowerbound_dot( Distance::from_f32(rough - epsilon * err) } -fn binarize(vector: &[u8]) -> (Vec, Vec, Vec, Vec) { +pub fn binarize(vector: &[u8]) -> (Vec, Vec, Vec, Vec) { let n = vector.len(); let mut t0 = vec![0u64; n.div_ceil(64)]; let mut t1 = vec![0u64; n.div_ceil(64)]; diff --git a/crates/rabitq/src/block.rs b/crates/rabitq/src/block.rs index 0a8dab7..9f26fce 100644 --- a/crates/rabitq/src/block.rs +++ b/crates/rabitq/src/block.rs @@ -1,85 +1,7 @@ use distance::Distance; use simd::Floating; -#[derive(Debug, Clone)] -pub struct Code { - pub dis_u_2: f32, - pub factor_ppc: f32, - pub factor_ip: f32, - pub factor_err: f32, - pub signs: Vec, -} - -pub fn code(dims: u32, vector: &[f32]) -> Code { - let sum_of_abs_x = f32::reduce_sum_of_abs_x(vector); - let sum_of_x_2 = f32::reduce_sum_of_x2(vector); - let dis_u = sum_of_x_2.sqrt(); - let x0 = sum_of_abs_x / (sum_of_x_2 * (dims as f32)).sqrt(); - let x_x0 = dis_u / x0; - let fac_norm = (dims as f32).sqrt(); - let max_x1 = 1.0f32 / (dims as f32 - 1.0).sqrt(); - let factor_err = 2.0f32 * max_x1 * (x_x0 * x_x0 - dis_u * dis_u).sqrt(); - let factor_ip = -2.0f32 / fac_norm * x_x0; - let cnt_pos = vector - .iter() - .map(|x| x.is_sign_positive() as i32) - .sum::(); - let cnt_neg = vector - .iter() - .map(|x| x.is_sign_negative() as i32) - .sum::(); - let factor_ppc = factor_ip * (cnt_pos - cnt_neg) as f32; - let mut signs = Vec::new(); - for i in 0..dims { - signs.push(vector[i as usize].is_sign_positive() as u8); - } - Code { - dis_u_2: sum_of_x_2, - factor_ppc, - factor_ip, - factor_err, - signs, - } -} - -pub fn dummy_code(dims: u32) -> Code { - Code { - dis_u_2: 0.0, - factor_ppc: 0.0, - factor_ip: 0.0, - factor_err: 0.0, - signs: vec![0; dims as _], - } -} - -pub struct PackedCodes { - pub dis_u_2: [f32; 32], - pub factor_ppc: [f32; 32], - pub factor_ip: [f32; 32], - pub factor_err: [f32; 32], - pub t: Vec, -} - -pub fn pack_codes(dims: u32, codes: [Code; 32]) -> PackedCodes { - use crate::utils::InfiniteByteChunks; - PackedCodes { - dis_u_2: std::array::from_fn(|i| codes[i].dis_u_2), - factor_ppc: std::array::from_fn(|i| codes[i].factor_ppc), - factor_ip: std::array::from_fn(|i| codes[i].factor_ip), - factor_err: std::array::from_fn(|i| codes[i].factor_err), - t: { - let signs = codes.map(|code| { - InfiniteByteChunks::new(code.signs.into_iter()) - .map(|[b0, b1, b2, b3]| b0 | b1 << 1 | b2 << 2 | b3 << 3) - .take(dims.div_ceil(4) as usize) - .collect::>() - }); - simd::fast_scan::pack(dims.div_ceil(4), signs).collect() - }, - } -} - -pub fn 
fscan_preprocess(vector: &[f32]) -> (f32, f32, f32, f32, Vec) { +pub fn preprocess(vector: &[f32]) -> (f32, f32, f32, f32, Vec<[u64; 2]>) { let dis_v_2 = f32::reduce_sum_of_x2(vector); let (k, b, qvector) = simd::quantize::quantize(vector, 15.0); let qvector_sum = if vector.len() <= 4369 { @@ -90,20 +12,19 @@ pub fn fscan_preprocess(vector: &[f32]) -> (f32, f32, f32, f32, Vec) { (dis_v_2, b, k, qvector_sum, compress(qvector)) } -pub fn fscan_process_lowerbound_l2( - dims: u32, - lut: &(f32, f32, f32, f32, Vec), +pub fn process_lowerbound_l2( + lut: &(f32, f32, f32, f32, Vec<[u64; 2]>), (dis_u_2, factor_ppc, factor_ip, factor_err, t): ( &[f32; 32], &[f32; 32], &[f32; 32], &[f32; 32], - &[u8], + &[[u64; 2]], ), epsilon: f32, ) -> [Distance; 32] { let &(dis_v_2, b, k, qvector_sum, ref s) = lut; - let r = simd::fast_scan::fast_scan(dims.div_ceil(4), t, s); + let r = simd::fast_scan::fast_scan(t, s); std::array::from_fn(|i| { let rough = dis_u_2[i] + dis_v_2 @@ -114,20 +35,19 @@ pub fn fscan_process_lowerbound_l2( }) } -pub fn fscan_process_lowerbound_dot( - dims: u32, - lut: &(f32, f32, f32, f32, Vec), +pub fn process_lowerbound_dot( + lut: &(f32, f32, f32, f32, Vec<[u64; 2]>), (_, factor_ppc, factor_ip, factor_err, t): ( &[f32; 32], &[f32; 32], &[f32; 32], &[f32; 32], - &[u8], + &[[u64; 2]], ), epsilon: f32, ) -> [Distance; 32] { let &(dis_v_2, b, k, qvector_sum, ref s) = lut; - let r = simd::fast_scan::fast_scan(dims.div_ceil(4), t, s); + let r = simd::fast_scan::fast_scan(t, s); std::array::from_fn(|i| { let rough = 0.5 * b * factor_ppc[i] + 0.5 * ((2.0 * r[i] as f32) - qvector_sum) * factor_ip[i] * k; @@ -136,37 +56,41 @@ pub fn fscan_process_lowerbound_dot( }) } -fn compress(mut qvector: Vec) -> Vec { - let dims = qvector.len() as u32; - let width = dims.div_ceil(4); - qvector.resize(qvector.len().next_multiple_of(4), 0); - let mut t = vec![0u8; width as usize * 16]; - for i in 0..width as usize { +pub fn compress(mut vector: Vec) -> Vec<[u64; 2]> { + let width = vector.len().div_ceil(4); + vector.resize(width * 4, 0); + let mut result = vec![[0u64, 0u64]; width]; + for i in 0..width { unsafe { // this hint is used to skip bound checks - std::hint::assert_unchecked(4 * i + 3 < qvector.len()); - std::hint::assert_unchecked(16 * i + 15 < t.len()); + std::hint::assert_unchecked(4 * i + 3 < vector.len()); } - let t0 = qvector[4 * i + 0]; - let t1 = qvector[4 * i + 1]; - let t2 = qvector[4 * i + 2]; - let t3 = qvector[4 * i + 3]; - t[16 * i + 0b0000] = 0; - t[16 * i + 0b0001] = t0; - t[16 * i + 0b0010] = t1; - t[16 * i + 0b0011] = t1 + t0; - t[16 * i + 0b0100] = t2; - t[16 * i + 0b0101] = t2 + t0; - t[16 * i + 0b0110] = t2 + t1; - t[16 * i + 0b0111] = t2 + t1 + t0; - t[16 * i + 0b1000] = t3; - t[16 * i + 0b1001] = t3 + t0; - t[16 * i + 0b1010] = t3 + t1; - t[16 * i + 0b1011] = t3 + t1 + t0; - t[16 * i + 0b1100] = t3 + t2; - t[16 * i + 0b1101] = t3 + t2 + t0; - t[16 * i + 0b1110] = t3 + t2 + t1; - t[16 * i + 0b1111] = t3 + t2 + t1 + t0; + let t_0 = vector[4 * i + 0]; + let t_1 = vector[4 * i + 1]; + let t_2 = vector[4 * i + 2]; + let t_3 = vector[4 * i + 3]; + result[i] = [ + u64::from_le_bytes([ + 0, + t_0, + t_1, + t_1 + t_0, + t_2, + t_2 + t_0, + t_2 + t_1, + t_2 + t_1 + t_0, + ]), + u64::from_le_bytes([ + t_3, + t_3 + t_0, + t_3 + t_1, + t_3 + t_1 + t_0, + t_3 + t_2, + t_3 + t_2 + t_0, + t_3 + t_2 + t_1, + t_3 + t_2 + t_1 + t_0, + ]), + ]; } - t + result } diff --git a/crates/rabitq/src/lib.rs b/crates/rabitq/src/lib.rs index 0796b7a..1e379a3 100644 --- a/crates/rabitq/src/lib.rs +++ 
b/crates/rabitq/src/lib.rs @@ -2,4 +2,96 @@ pub mod binary; pub mod block; -mod utils; + +use simd::Floating; + +#[derive(Debug, Clone)] +pub struct Code { + pub dis_u_2: f32, + pub factor_ppc: f32, + pub factor_ip: f32, + pub factor_err: f32, + pub signs: Vec, +} + +pub fn pack_to_u4(signs: &[bool]) -> Vec { + fn f(x: [bool; 4]) -> u8 { + x[0] as u8 | (x[1] as u8) << 1 | (x[2] as u8) << 2 | (x[3] as u8) << 3 + } + let mut result = Vec::with_capacity(signs.len().div_ceil(4)); + for i in 0..signs.len().div_ceil(4) { + let x = std::array::from_fn(|j| signs.get(i * 4 + j).copied().unwrap_or_default()); + result.push(f(x)); + } + result +} + +pub fn pack_to_u64(signs: &[bool]) -> Vec { + fn f(x: [bool; 64]) -> u64 { + let mut result = 0_u64; + for i in 0..64 { + result |= (x[i] as u64) << i; + } + result + } + let mut result = Vec::with_capacity(signs.len().div_ceil(64)); + for i in 0..signs.len().div_ceil(64) { + let x = std::array::from_fn(|j| signs.get(i * 64 + j).copied().unwrap_or_default()); + result.push(f(x)); + } + result +} + +pub fn code(dims: u32, vector: &[f32]) -> Code { + let sum_of_abs_x = f32::reduce_sum_of_abs_x(vector); + let sum_of_x_2 = f32::reduce_sum_of_x2(vector); + let dis_u = sum_of_x_2.sqrt(); + let x0 = sum_of_abs_x / (sum_of_x_2 * (dims as f32)).sqrt(); + let x_x0 = dis_u / x0; + let fac_norm = (dims as f32).sqrt(); + let max_x1 = 1.0f32 / (dims as f32 - 1.0).sqrt(); + let factor_err = 2.0f32 * max_x1 * (x_x0 * x_x0 - dis_u * dis_u).sqrt(); + let factor_ip = -2.0f32 / fac_norm * x_x0; + let cnt_pos = vector + .iter() + .map(|x| x.is_sign_positive() as i32) + .sum::(); + let cnt_neg = vector + .iter() + .map(|x| x.is_sign_negative() as i32) + .sum::(); + let factor_ppc = factor_ip * (cnt_pos - cnt_neg) as f32; + let mut signs = Vec::new(); + for i in 0..dims { + signs.push(vector[i as usize].is_sign_positive()); + } + Code { + dis_u_2: sum_of_x_2, + factor_ppc, + factor_ip, + factor_err, + signs, + } +} + +pub fn compute_lut( + vector: &[f32], +) -> ( + (f32, f32, f32, f32, Vec<[u64; 2]>), + (f32, f32, f32, f32, (Vec, Vec, Vec, Vec)), +) { + use simd::Floating; + let dis_v_2 = f32::reduce_sum_of_x2(vector); + let (k, b, qvector) = simd::quantize::quantize(vector, 15.0); + let qvector_sum = if vector.len() <= 4369 { + simd::u8::reduce_sum_of_x_as_u16(&qvector) as f32 + } else { + simd::u8::reduce_sum_of_x(&qvector) as f32 + }; + let binary = binary::binarize(&qvector); + let block = block::compress(qvector); + ( + (dis_v_2, b, k, qvector_sum, block), + (dis_v_2, b, k, qvector_sum, binary), + ) +} diff --git a/crates/rabitq/src/utils.rs b/crates/rabitq/src/utils.rs deleted file mode 100644 index c61b87e..0000000 --- a/crates/rabitq/src/utils.rs +++ /dev/null @@ -1,20 +0,0 @@ -#[derive(Debug, Clone)] -pub struct InfiniteByteChunks { - iter: I, -} - -impl InfiniteByteChunks { - pub fn new(iter: I) -> Self { - Self { iter } - } -} - -impl, const N: usize> Iterator for InfiniteByteChunks { - type Item = [u8; N]; - - fn next(&mut self) -> Option { - Some(std::array::from_fn::(|_| { - self.iter.next().unwrap_or(0) - })) - } -} diff --git a/crates/simd/src/f16.rs b/crates/simd/src/f16.rs index 53bddb7..ce90618 100644 --- a/crates/simd/src/f16.rs +++ b/crates/simd/src/f16.rs @@ -517,7 +517,7 @@ mod reduce_sum_of_xy { } // temporarily disables this for uncertain precision - #[expect(dead_code)] + #[cfg_attr(not(test), expect(dead_code))] #[inline] #[cfg(target_arch = "aarch64")] #[crate::target_cpu(enable = "v8.3a")] @@ -894,7 +894,7 @@ mod reduce_sum_of_d2 { } // temporarily 
disables this for uncertain precision - #[expect(dead_code)] + #[cfg_attr(not(test), expect(dead_code))] #[inline] #[cfg(target_arch = "aarch64")] #[crate::target_cpu(enable = "v8.3a")] diff --git a/crates/simd/src/fast_scan/mod.rs b/crates/simd/src/fast_scan/mod.rs index c2579af..84ad990 100644 --- a/crates/simd/src/fast_scan/mod.rs +++ b/crates/simd/src/fast_scan/mod.rs @@ -1,16 +1,16 @@ /* -## codes layout for 4-bit quantizer +## code layout for 4-bit quantizer -group i = | vector i | (total bytes = width/2) +group i = | vector i | (total bytes = n/2) -byte: | 0 | 1 | 2 | ... | width/2 - 1 | -bits 0..3: | code 0 | code 2 | code 4 | ... | code width-2 | -bits 4..7: | code 1 | code 3 | code 5 | ... | code width-1 | +byte: | 0 | 1 | 2 | ... | n/2 - 1 | +bits 0..3: | code 0 | code 2 | code 4 | ... | code n-2 | +bits 4..7: | code 1 | code 3 | code 5 | ... | code n-1 | -## packed_codes layout for 4-bit quantizer +## packed_code layout for 4-bit quantizer -group i = | vector 32i | vector 32i+1 | vector 32i+2 | ... | vector 32i+31 | (total bytes = width * 16) +group i = | vector 32i | vector 32i+1 | vector 32i+2 | ... | vector 32i+31 | (total bytes = n * 16) byte | 0 | 1 | 2 | ... | 14 | 15 | bits 0..3 | code 0,vector 0 | code 0,vector 8 | code 0,vector 1 | ... | code 0,vector 14 | code 0,vector 15 | @@ -26,34 +26,105 @@ bits 4..7 | code 2,vector 16 | code 2,vector 24 | code 2,vector 17 | ... | code ... -byte | width*32-32 | width*32-31 | ... | width*32-1 | -bits 0..3 | code (width-1),vector 0 | code (width-1),vector 8 | ... | code (width-1),vector 15 | -bits 4..7 | code (width-1),vector 16 | code (width-1),vector 24 | ... | code (width-1),vector 31 | +byte | n*32-32 | n*32-31 | ... | n*32-1 | +bits 0..3 | code (n-1),vector 0 | code (n-1),vector 8 | ... | code (n-1),vector 15 | +bits 4..7 | code (n-1),vector 16 | code (n-1),vector 24 | ... 
| code (n-1),vector 31 | */ -pub fn pack(width: u32, r: [Vec; 32]) -> impl Iterator { - (0..width as usize).flat_map(move |i| { - [ - r[0][i] | (r[16][i] << 4), - r[8][i] | (r[24][i] << 4), - r[1][i] | (r[17][i] << 4), - r[9][i] | (r[25][i] << 4), - r[2][i] | (r[18][i] << 4), - r[10][i] | (r[26][i] << 4), - r[3][i] | (r[19][i] << 4), - r[11][i] | (r[27][i] << 4), - r[4][i] | (r[20][i] << 4), - r[12][i] | (r[28][i] << 4), - r[5][i] | (r[21][i] << 4), - r[13][i] | (r[29][i] << 4), - r[6][i] | (r[22][i] << 4), - r[14][i] | (r[30][i] << 4), - r[7][i] | (r[23][i] << 4), - r[15][i] | (r[31][i] << 4), - ] - .into_iter() - }) +pub fn pack(x: [&[u8]; 32]) -> Vec<[u64; 2]> { + let n = { + let l = x.each_ref().map(|i| i.len()); + for i in 1..32 { + assert!(l[0] == l[i]); + } + l[0] + }; + let mut result = Vec::with_capacity(n); + for i in 0..n { + result.push([ + u64::from_le_bytes([ + x[0][i] | (x[16][i] << 4), + x[8][i] | (x[24][i] << 4), + x[1][i] | (x[17][i] << 4), + x[9][i] | (x[25][i] << 4), + x[2][i] | (x[18][i] << 4), + x[10][i] | (x[26][i] << 4), + x[3][i] | (x[19][i] << 4), + x[11][i] | (x[27][i] << 4), + ]), + u64::from_le_bytes([ + x[4][i] | (x[20][i] << 4), + x[12][i] | (x[28][i] << 4), + x[5][i] | (x[21][i] << 4), + x[13][i] | (x[29][i] << 4), + x[6][i] | (x[22][i] << 4), + x[14][i] | (x[30][i] << 4), + x[7][i] | (x[23][i] << 4), + x[15][i] | (x[31][i] << 4), + ]), + ]); + } + result +} + +pub fn unpack(x: &[[u64; 2]]) -> [Vec; 32] { + let n = x.len(); + let mut result = std::array::from_fn(|_| Vec::with_capacity(n)); + for i in 0..n { + let a = x[i][0].to_le_bytes(); + let b = x[i][1].to_le_bytes(); + result[0].push(a[0] & 0xf); + result[1].push(a[2] & 0xf); + result[2].push(a[4] & 0xf); + result[3].push(a[6] & 0xf); + result[4].push(b[0] & 0xf); + result[5].push(b[2] & 0xf); + result[6].push(b[4] & 0xf); + result[7].push(b[6] & 0xf); + result[8].push(a[1] & 0xf); + result[9].push(a[3] & 0xf); + result[10].push(a[5] & 0xf); + result[11].push(a[7] & 0xf); + result[12].push(b[1] & 0xf); + result[13].push(b[3] & 0xf); + result[14].push(b[5] & 0xf); + result[15].push(b[7] & 0xf); + result[16].push(a[0] >> 4); + result[17].push(a[2] >> 4); + result[18].push(a[4] >> 4); + result[19].push(a[6] >> 4); + result[20].push(b[0] >> 4); + result[21].push(b[2] >> 4); + result[22].push(b[4] >> 4); + result[23].push(b[6] >> 4); + result[24].push(a[1] >> 4); + result[25].push(a[3] >> 4); + result[26].push(a[5] >> 4); + result[27].push(a[7] >> 4); + result[28].push(b[1] >> 4); + result[29].push(b[3] >> 4); + result[30].push(b[5] >> 4); + result[31].push(b[7] >> 4); + } + result +} + +pub fn padding_pack(x: impl IntoIterator>) -> Vec<[u64; 2]> { + let x = x.into_iter().collect::>(); + let x = x.iter().map(|x| x.as_ref()).collect::>(); + if x.is_empty() || x.len() > 32 { + panic!("too few or too many slices"); + } + let n = x[0].len(); + let t = vec![0; n]; + pack(std::array::from_fn(|i| { + if i < x.len() { x[i] } else { t.as_slice() } + })) +} + +pub fn any_pack(mut x: impl Iterator) -> [T; 32] { + std::array::from_fn(|_| x.next()).map(|x| x.unwrap_or_default()) } #[allow(clippy::module_inception)] @@ -61,10 +132,10 @@ mod fast_scan { #[inline] #[cfg(target_arch = "x86_64")] #[crate::target_cpu(enable = "v4")] - fn fast_scan_v4(width: u32, codes: &[u8], lut: &[u8]) -> [u16; 32] { + fn fast_scan_v4(code: &[[u64; 2]], lut: &[[u64; 2]]) -> [u16; 32] { // bounds checking is not enforced by compiler, so check it manually - assert_eq!(codes.len(), width as usize * 16); - assert_eq!(lut.len(), width as usize * 
16); + assert_eq!(code.len(), lut.len()); + let n = code.len(); unsafe { use std::arch::x86_64::*; @@ -99,14 +170,14 @@ mod fast_scan { let mut accu_3 = _mm512_setzero_si512(); let mut i = 0_usize; - while i + 4 <= width as usize { - let c = _mm512_loadu_si512(codes.as_ptr().add(i * 16).cast()); + while i + 4 <= n { + let c = _mm512_loadu_si512(code.as_ptr().add(i).cast()); let mask = _mm512_set1_epi8(0xf); let clo = _mm512_and_si512(c, mask); let chi = _mm512_and_si512(_mm512_srli_epi16(c, 4), mask); - let lut = _mm512_loadu_si512(lut.as_ptr().add(i * 16).cast()); + let lut = _mm512_loadu_si512(lut.as_ptr().add(i).cast()); let res_lo = _mm512_shuffle_epi8(lut, clo); accu_0 = _mm512_add_epi16(accu_0, res_lo); accu_1 = _mm512_add_epi16(accu_1, _mm512_srli_epi16(res_lo, 8)); @@ -116,14 +187,14 @@ mod fast_scan { i += 4; } - if i + 2 <= width as usize { - let c = _mm256_loadu_si256(codes.as_ptr().add(i * 16).cast()); + if i + 2 <= n { + let c = _mm256_loadu_si256(code.as_ptr().add(i).cast()); let mask = _mm256_set1_epi8(0xf); let clo = _mm256_and_si256(c, mask); let chi = _mm256_and_si256(_mm256_srli_epi16(c, 4), mask); - let lut = _mm256_loadu_si256(lut.as_ptr().add(i * 16).cast()); + let lut = _mm256_loadu_si256(lut.as_ptr().add(i).cast()); let res_lo = _mm512_zextsi256_si512(_mm256_shuffle_epi8(lut, clo)); accu_0 = _mm512_add_epi16(accu_0, res_lo); accu_1 = _mm512_add_epi16(accu_1, _mm512_srli_epi16(res_lo, 8)); @@ -133,14 +204,14 @@ mod fast_scan { i += 2; } - if i < width as usize { - let c = _mm_loadu_si128(codes.as_ptr().add(i * 16).cast()); + if i < n { + let c = _mm_loadu_si128(code.as_ptr().add(i).cast()); let mask = _mm_set1_epi8(0xf); let clo = _mm_and_si128(c, mask); let chi = _mm_and_si128(_mm_srli_epi16(c, 4), mask); - let lut = _mm_loadu_si128(lut.as_ptr().add(i * 16).cast()); + let lut = _mm_loadu_si128(lut.as_ptr().add(i).cast()); let res_lo = _mm512_zextsi128_si512(_mm_shuffle_epi8(lut, clo)); accu_0 = _mm512_add_epi16(accu_0, res_lo); accu_1 = _mm512_add_epi16(accu_1, _mm512_srli_epi16(res_lo, 8)); @@ -150,7 +221,7 @@ mod fast_scan { i += 1; } - debug_assert_eq!(i, width as usize); + debug_assert_eq!(i, n); let mut result = [0_u16; 32]; @@ -178,14 +249,15 @@ mod fast_scan { return; } for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { - for width in 90..110 { - let codes = (0..16 * width).map(|_| rand::random()).collect::>(); - let lut = (0..16 * width).map(|_| rand::random()).collect::>(); + for n in 90..110 { + let code = (0..n) + .map(|_| [rand::random(), rand::random()]) + .collect::>(); + let lut = (0..n) + .map(|_| [rand::random(), rand::random()]) + .collect::>(); unsafe { - assert_eq!( - fast_scan_v4(width, &codes, &lut), - fast_scan_fallback(width, &codes, &lut) - ); + assert_eq!(fast_scan_v4(&code, &lut), fast_scan_fallback(&code, &lut)); } } } @@ -194,10 +266,10 @@ mod fast_scan { #[inline] #[cfg(target_arch = "x86_64")] #[crate::target_cpu(enable = "v3")] - fn fast_scan_v3(width: u32, codes: &[u8], lut: &[u8]) -> [u16; 32] { + fn fast_scan_v3(code: &[[u64; 2]], lut: &[[u64; 2]]) -> [u16; 32] { // bounds checking is not enforced by compiler, so check it manually - assert_eq!(codes.len(), width as usize * 16); - assert_eq!(lut.len(), width as usize * 16); + assert_eq!(code.len(), lut.len()); + let n = code.len(); unsafe { use std::arch::x86_64::*; @@ -218,14 +290,14 @@ mod fast_scan { let mut accu_3 = _mm256_setzero_si256(); let mut i = 0_usize; - while i + 2 <= width as usize { - let c = _mm256_loadu_si256(codes.as_ptr().add(i * 16).cast()); + while i + 2 <= n { 
+ let c = _mm256_loadu_si256(code.as_ptr().add(i).cast()); let mask = _mm256_set1_epi8(0xf); let clo = _mm256_and_si256(c, mask); let chi = _mm256_and_si256(_mm256_srli_epi16(c, 4), mask); - let lut = _mm256_loadu_si256(lut.as_ptr().add(i * 16).cast()); + let lut = _mm256_loadu_si256(lut.as_ptr().add(i).cast()); let res_lo = _mm256_shuffle_epi8(lut, clo); accu_0 = _mm256_add_epi16(accu_0, res_lo); accu_1 = _mm256_add_epi16(accu_1, _mm256_srli_epi16(res_lo, 8)); @@ -235,14 +307,14 @@ mod fast_scan { i += 2; } - if i < width as usize { - let c = _mm_loadu_si128(codes.as_ptr().add(i * 16).cast()); + if i < n { + let c = _mm_loadu_si128(code.as_ptr().add(i).cast()); let mask = _mm_set1_epi8(0xf); let clo = _mm_and_si128(c, mask); let chi = _mm_and_si128(_mm_srli_epi16(c, 4), mask); - let lut = _mm_loadu_si128(lut.as_ptr().add(i * 16).cast()); + let lut = _mm_loadu_si128(lut.as_ptr().add(i).cast()); let res_lo = _mm256_zextsi128_si256(_mm_shuffle_epi8(lut, clo)); accu_0 = _mm256_add_epi16(accu_0, res_lo); accu_1 = _mm256_add_epi16(accu_1, _mm256_srli_epi16(res_lo, 8)); @@ -252,7 +324,7 @@ mod fast_scan { i += 1; } - debug_assert_eq!(i, width as usize); + debug_assert_eq!(i, n); let mut result = [0_u16; 32]; @@ -280,14 +352,15 @@ mod fast_scan { return; } for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { - for width in 90..110 { - let codes = (0..16 * width).map(|_| rand::random()).collect::>(); - let lut = (0..16 * width).map(|_| rand::random()).collect::>(); + for n in 90..110 { + let code = (0..n) + .map(|_| [rand::random(), rand::random()]) + .collect::>(); + let lut = (0..n) + .map(|_| [rand::random(), rand::random()]) + .collect::>(); unsafe { - assert_eq!( - fast_scan_v3(width, &codes, &lut), - fast_scan_fallback(width, &codes, &lut) - ); + assert_eq!(fast_scan_v3(&code, &lut), fast_scan_fallback(&code, &lut)); } } } @@ -295,10 +368,10 @@ mod fast_scan { #[cfg(target_arch = "x86_64")] #[crate::target_cpu(enable = "v2")] - fn fast_scan_v2(width: u32, codes: &[u8], lut: &[u8]) -> [u16; 32] { + fn fast_scan_v2(code: &[[u64; 2]], lut: &[[u64; 2]]) -> [u16; 32] { // bounds checking is not enforced by compiler, so check it manually - assert_eq!(codes.len(), width as usize * 16); - assert_eq!(lut.len(), width as usize * 16); + assert_eq!(code.len(), lut.len()); + let n = code.len(); unsafe { use std::arch::x86_64::*; @@ -309,14 +382,14 @@ mod fast_scan { let mut accu_3 = _mm_setzero_si128(); let mut i = 0_usize; - while i < width as usize { - let c = _mm_loadu_si128(codes.as_ptr().add(i * 16).cast()); + while i < n { + let c = _mm_loadu_si128(code.as_ptr().add(i).cast()); let mask = _mm_set1_epi8(0xf); let clo = _mm_and_si128(c, mask); let chi = _mm_and_si128(_mm_srli_epi16(c, 4), mask); - let lut = _mm_loadu_si128(lut.as_ptr().add(i * 16).cast()); + let lut = _mm_loadu_si128(lut.as_ptr().add(i).cast()); let res_lo = _mm_shuffle_epi8(lut, clo); accu_0 = _mm_add_epi16(accu_0, res_lo); accu_1 = _mm_add_epi16(accu_1, _mm_srli_epi16(res_lo, 8)); @@ -326,7 +399,7 @@ mod fast_scan { i += 1; } - debug_assert_eq!(i, width as usize); + debug_assert_eq!(i, n); let mut result = [0_u16; 32]; @@ -350,14 +423,15 @@ mod fast_scan { return; } for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { - for width in 90..110 { - let codes = (0..16 * width).map(|_| rand::random()).collect::>(); - let lut = (0..16 * width).map(|_| rand::random()).collect::>(); + for n in 90..110 { + let code = (0..n) + .map(|_| [rand::random(), rand::random()]) + .collect::>(); + let lut = (0..n) + .map(|_| [rand::random(), 
rand::random()]) + .collect::>(); unsafe { - assert_eq!( - fast_scan_v2(width, &codes, &lut), - fast_scan_fallback(width, &codes, &lut) - ); + assert_eq!(fast_scan_v2(&code, &lut), fast_scan_fallback(&code, &lut)); } } } @@ -365,10 +439,10 @@ mod fast_scan { #[cfg(target_arch = "aarch64")] #[crate::target_cpu(enable = "v8.3a")] - fn fast_scan_v8_3a(width: u32, codes: &[u8], lut: &[u8]) -> [u16; 32] { + fn fast_scan_v8_3a(code: &[[u64; 2]], lut: &[[u64; 2]]) -> [u16; 32] { // bounds checking is not enforced by compiler, so check it manually - assert_eq!(codes.len(), width as usize * 16); - assert_eq!(lut.len(), width as usize * 16); + assert_eq!(code.len(), lut.len()); + let n = code.len(); unsafe { use std::arch::aarch64::*; @@ -379,14 +453,14 @@ mod fast_scan { let mut accu_3 = vdupq_n_u16(0); let mut i = 0_usize; - while i < width as usize { - let c = vld1q_u8(codes.as_ptr().add(i * 16).cast()); + while i < n { + let c = vld1q_u8(code.as_ptr().add(i).cast()); let mask = vdupq_n_u8(0xf); let clo = vandq_u8(c, mask); let chi = vandq_u8(vshrq_n_u8(c, 4), mask); - let lut = vld1q_u8(lut.as_ptr().add(i * 16).cast()); + let lut = vld1q_u8(lut.as_ptr().add(i).cast()); let res_lo = vreinterpretq_u16_u8(vqtbl1q_u8(lut, clo)); accu_0 = vaddq_u16(accu_0, res_lo); accu_1 = vaddq_u16(accu_1, vshrq_n_u16(res_lo, 8)); @@ -396,7 +470,7 @@ mod fast_scan { i += 1; } - debug_assert_eq!(i, width as usize); + debug_assert_eq!(i, n); let mut result = [0_u16; 32]; @@ -420,13 +494,17 @@ mod fast_scan { return; } for _ in 0..if cfg!(not(miri)) { 256 } else { 1 } { - for width in 90..110 { - let codes = (0..16 * width).map(|_| rand::random()).collect::>(); - let lut = (0..16 * width).map(|_| rand::random()).collect::>(); + for n in 90..110 { + let code = (0..n) + .map(|_| [rand::random(), rand::random()]) + .collect::>(); + let lut = (0..n) + .map(|_| [rand::random(), rand::random()]) + .collect::>(); unsafe { assert_eq!( - fast_scan_v8_3a(width, &codes, &lut), - fast_scan_fallback(width, &codes, &lut) + fast_scan_v8_3a(&code, &lut), + fast_scan_fallback(&code, &lut) ); } } @@ -434,32 +512,24 @@ mod fast_scan { } #[crate::multiversion(@"v4", @"v3", @"v2", @"v8.3a")] - pub fn fast_scan(width: u32, codes: &[u8], lut: &[u8]) -> [u16; 32] { - let width = width as usize; - - assert_eq!(codes.len(), width * 16); - assert_eq!(lut.len(), width * 16); + pub fn fast_scan(code: &[[u64; 2]], lut: &[[u64; 2]]) -> [u16; 32] { + assert_eq!(code.len(), lut.len()); + let n = code.len(); - use std::array::from_fn; - use std::ops::BitAnd; - - fn load(slice: &[T]) -> [T; N] { - from_fn(|i| slice[i]) - } fn unary(op: impl Fn(T) -> U, a: [T; N]) -> [U; N] { - from_fn(|i| op(a[i])) + std::array::from_fn(|i| op(a[i])) } fn binary(op: impl Fn(T, T) -> T, a: [T; N], b: [T; N]) -> [T; N] { - from_fn(|i| op(a[i], b[i])) + std::array::from_fn(|i| op(a[i], b[i])) } fn shuffle(a: [T; N], b: [u8; N]) -> [T; N] { - from_fn(|i| a[b[i] as usize]) + std::array::from_fn(|i| a[b[i] as usize]) } fn cast(x: [u8; 16]) -> [u16; 8] { - from_fn(|i| u16::from_le_bytes([x[i << 1 | 0], x[i << 1 | 1]])) + std::array::from_fn(|i| u16::from_le_bytes([x[i << 1 | 0], x[i << 1 | 1]])) } fn setr(x: [[T; 8]; 4]) -> [T; 32] { - from_fn(|i| x[i >> 3][i & 7]) + std::array::from_fn(|i| x[i >> 3][i & 7]) } let mut a_0 = [0u16; 8]; @@ -467,14 +537,14 @@ mod fast_scan { let mut a_2 = [0u16; 8]; let mut a_3 = [0u16; 8]; - for i in 0..width { - let c = load(&codes[16 * i..]); + for i in 0..n { + let c = unsafe { std::mem::transmute::<[u64; 2], [u8; 16]>(code[i]) }; let mask = 
[0xfu8; 16]; - let clo = binary(u8::bitand, c, mask); - let chi = binary(u8::bitand, unary(|x| x >> 4, c), mask); + let clo = binary(std::ops::BitAnd::bitand, c, mask); + let chi = binary(std::ops::BitAnd::bitand, unary(|x| x >> 4, c), mask); - let lut = load(&lut[16 * i..]); + let lut = unsafe { std::mem::transmute::<[u64; 2], [u8; 16]>(lut[i]) }; let res_lo = cast(shuffle(lut, clo)); a_0 = binary(u16::wrapping_add, a_0, res_lo); a_1 = binary(u16::wrapping_add, a_1, unary(|x| x >> 8, res_lo)); @@ -491,6 +561,6 @@ mod fast_scan { } #[inline(always)] -pub fn fast_scan(width: u32, codes: &[u8], lut: &[u8]) -> [u16; 32] { - fast_scan::fast_scan(width, codes, lut) +pub fn fast_scan(code: &[[u64; 2]], lut: &[[u64; 2]]) -> [u16; 32] { + fast_scan::fast_scan(code, lut) } diff --git a/src/vchordrq/algorithm/build.rs b/src/algorithm/build.rs similarity index 78% rename from src/vchordrq/algorithm/build.rs rename to src/algorithm/build.rs index 4213e59..a0810a7 100644 --- a/src/vchordrq/algorithm/build.rs +++ b/src/algorithm/build.rs @@ -1,25 +1,24 @@ -use crate::vchordrq::algorithm::rabitq; -use crate::vchordrq::algorithm::tuples::*; -use crate::vchordrq::index::am_options::Opfamily; -use crate::vchordrq::types::DistanceKind; -use crate::vchordrq::types::VchordrqBuildOptions; -use crate::vchordrq::types::VchordrqExternalBuildOptions; -use crate::vchordrq::types::VchordrqIndexingOptions; -use crate::vchordrq::types::VchordrqInternalBuildOptions; -use crate::vchordrq::types::VectorOptions; -use algorithm::{Page, PageGuard, RelationWrite}; +use crate::algorithm::RelationWrite; +use crate::algorithm::operator::{Operator, Vector}; +use crate::algorithm::tape::*; +use crate::algorithm::tuples::*; +use crate::index::am_options::Opfamily; +use crate::types::VchordrqBuildOptions; +use crate::types::VchordrqExternalBuildOptions; +use crate::types::VchordrqIndexingOptions; +use crate::types::VchordrqInternalBuildOptions; +use crate::types::VectorOptions; use rand::Rng; -use rkyv::ser::serializers::AllocSerializer; use simd::Floating; -use std::marker::PhantomData; use std::num::NonZeroU64; use std::sync::Arc; use vector::VectorBorrowed; +use vector::VectorOwned; -pub trait HeapRelation { +pub trait HeapRelation { fn traverse(&self, progress: bool, callback: F) where - F: FnMut((NonZeroU64, V)); + F: FnMut((NonZeroU64, O::Vector)); fn opfamily(&self) -> Opfamily; } @@ -27,7 +26,7 @@ pub trait Reporter { fn tuples_total(&mut self, tuples_total: u64); } -pub fn build, R: Reporter>( +pub fn build, R: Reporter>( vector_options: VectorOptions, vchordrq_options: VchordrqIndexingOptions, heap_relation: T, @@ -35,8 +34,7 @@ pub fn build, R: Reporter>( mut reporter: R, ) { let dims = vector_options.dims; - let is_residual = - vchordrq_options.residual_quantization && vector_options.d == DistanceKind::L2; + let is_residual = vchordrq_options.residual_quantization && O::SUPPORTS_RESIDUAL; let structures = match vchordrq_options.build { VchordrqBuildOptions::External(external_build) => Structure::extern_build( vector_options.clone(), @@ -58,11 +56,11 @@ pub fn build, R: Reporter>( let vector = vector.as_borrowed(); assert_eq!(dims, vector.dims(), "invalid vector dimensions"); if number_of_samples < max_number_of_samples { - samples.push(V::build_to_vecf32(vector)); + samples.push(O::Vector::build_to_vecf32(vector)); number_of_samples += 1; } else { let index = rand.gen_range(0..max_number_of_samples) as usize; - samples[index] = V::build_to_vecf32(vector); + samples[index] = O::Vector::build_to_vecf32(vector); } 
tuples_total += 1; }); @@ -72,24 +70,32 @@ pub fn build, R: Reporter>( Structure::internal_build(vector_options.clone(), internal_build.clone(), samples) } }; - let mut meta = Tape::create(&relation, false); + let mut meta = TapeWriter::<_, _, MetaTuple>::create(|| relation.extend(false)); assert_eq!(meta.first(), 0); - let mut vectors = Tape::, _>::create(&relation, true); - let mut pointer_of_means = Vec::>::new(); + let freepage = TapeWriter::<_, _, FreepageTuple>::create(|| relation.extend(false)); + let mut vectors = TapeWriter::<_, _, VectorTuple>::create(|| relation.extend(true)); + let mut pointer_of_means = Vec::>::new(); for i in 0..structures.len() { let mut level = Vec::new(); for j in 0..structures[i].len() { - let vector = V::build_from_vecf32(&structures[i].means[j]); - let (metadata, slices) = V::vector_split(vector.as_borrowed()); - let mut chain = Err(metadata); + let vector = O::Vector::build_from_vecf32(&structures[i].means[j]); + let (metadata, slices) = O::Vector::vector_split(vector.as_borrowed()); + let mut chain = Ok(metadata); for i in (0..slices.len()).rev() { - chain = Ok(vectors.push(&VectorTuple { - payload: None, - slice: slices[i].to_vec(), - chain, + chain = Err(vectors.push(match chain { + Ok(metadata) => VectorTuple::_0 { + payload: None, + elements: slices[i].to_vec(), + metadata, + }, + Err(pointer) => VectorTuple::_1 { + payload: None, + elements: slices[i].to_vec(), + pointer, + }, })); } - level.push(chain.ok().unwrap()); + level.push(chain.err().unwrap()); } pointer_of_means.push(level); } @@ -98,10 +104,14 @@ pub fn build, R: Reporter>( let mut level = Vec::new(); for j in 0..structures[i].len() { if i == 0 { - let tape = Tape::::create(&relation, false); - level.push(tape.first()); + let tape = TapeWriter::<_, _, H0Tuple>::create(|| relation.extend(false)); + let mut jump = TapeWriter::<_, _, JumpTuple>::create(|| relation.extend(false)); + jump.push(JumpTuple { + first: tape.first(), + }); + level.push(jump.first()); } else { - let mut tape = Tape::::create(&relation, false); + let mut tape = H1TapeWriter::<_, _>::create(|| relation.extend(false)); let h2_mean = &structures[i].means[j]; let h2_children = &structures[i].children[j]; for child in h2_children.iter().copied() { @@ -111,28 +121,30 @@ pub fn build, R: Reporter>( } else { rabitq::code(dims, h1_mean) }; - tape.push(&Height1Tuple { + tape.push(H1Branch { mean: pointer_of_means[i - 1][child as usize], - first: pointer_of_firsts[i - 1][child as usize], dis_u_2: code.dis_u_2, factor_ppc: code.factor_ppc, factor_ip: code.factor_ip, factor_err: code.factor_err, - t: code.t(), + signs: code.signs, + first: pointer_of_firsts[i - 1][child as usize], }); } + let tape = tape.into_inner(); level.push(tape.first()); } } pointer_of_firsts.push(level); } - meta.push(&MetaTuple { + meta.push(MetaTuple { dims, height_of_root: structures.len() as u32, is_residual, vectors_first: vectors.first(), - mean: pointer_of_means.last().unwrap()[0], - first: pointer_of_firsts.last().unwrap()[0], + root_mean: pointer_of_means.last().unwrap()[0], + root_first: pointer_of_firsts.last().unwrap()[0], + freepage_first: freepage.first(), }); } @@ -342,7 +354,7 @@ impl Structure { ) -> (Vec>, Vec>) { labels .iter() - .filter(|(_, &(h, _))| h == height) + .filter(|(_, (h, _))| *h == height) .map(|(id, _)| { ( vectors[id].clone(), @@ -359,50 +371,3 @@ impl Structure { result } } - -struct Tape<'a: 'b, 'b, T, R: 'b + RelationWrite> { - relation: &'a R, - head: R::WriteGuard<'b>, - first: u32, - tracking_freespace: bool, - 
_phantom: PhantomData T>,
-}
-
-impl<'a: 'b, 'b, T, R: 'b + RelationWrite> Tape<'a, 'b, T, R> {
- fn create(relation: &'a R, tracking_freespace: bool) -> Self {
- let mut head = relation.extend(tracking_freespace);
- head.get_opaque_mut().skip = head.id();
- let first = head.id();
- Self {
- relation,
- head,
- first,
- tracking_freespace,
- _phantom: PhantomData,
- }
- }
- fn first(&self) -> u32 {
- self.first
- }
-}
-
-impl<'a: 'b, 'b, T, R: 'b + RelationWrite> Tape<'a, 'b, T, R>
-where
- T: rkyv::Serialize>,
-{
- fn push(&mut self, x: &T) -> (u32, u16) {
- let bytes = rkyv::to_bytes(x).expect("failed to serialize");
- if let Some(i) = self.head.alloc(&bytes) {
- (self.head.id(), i)
- } else {
- let next = self.relation.extend(self.tracking_freespace);
- self.head.get_opaque_mut().next = next.id();
- self.head = next;
- if let Some(i) = self.head.alloc(&bytes) {
- (self.head.id(), i)
- } else {
- panic!("tuple is too large to fit in a fresh page")
- }
- }
- }
-}
diff --git a/src/algorithm/freepages.rs b/src/algorithm/freepages.rs
new file mode 100644
index 0000000..8984d3c
--- /dev/null
+++ b/src/algorithm/freepages.rs
@@ -0,0 +1,61 @@
+use crate::algorithm::tuples::*;
+use crate::algorithm::*;
+use crate::utils::pipe::Pipe;
+use std::cmp::Reverse;
+
+pub fn mark(relation: impl RelationWrite, freepage_first: u32, pages: &[u32]) {
+ let mut pages = pages.to_vec();
+ pages.sort_by_key(|x| Reverse(*x));
+ pages.dedup();
+ let first = freepage_first;
+ assert!(first != u32::MAX);
+ let (mut current, mut offset) = (first, 0_u32);
+ while !pages.is_empty() {
+ let locals = {
+ let mut local = Vec::new();
+ while let Some(target) = pages.pop_if(|x| (offset..offset + 32768).contains(x)) {
+ local.push(target - offset);
+ }
+ local
+ };
+ let mut freespace_guard = relation.write(current, false);
+ if freespace_guard.len() == 0 {
+ freespace_guard.alloc(&serialize(&FreepageTuple {}));
+ }
+ let mut freespace_tuple = freespace_guard
+ .get_mut(1)
+ .expect("data corruption")
+ .pipe(write_tuple::<FreepageTuple>);
+ for local in locals {
+ freespace_tuple.mark(local as _);
+ }
+ if freespace_guard.get_opaque().next == u32::MAX {
+ let extend = relation.extend(false);
+ freespace_guard.get_opaque_mut().next = extend.id();
+ }
+ (current, offset) = (freespace_guard.get_opaque().next, offset + 32768);
+ }
+}
+
+pub fn fetch(relation: impl RelationWrite, freepage_first: u32) -> Option<u32> {
+ let first = freepage_first;
+ assert!(first != u32::MAX);
+ let (mut current, mut offset) = (first, 0_u32);
+ loop {
+ let mut freespace_guard = relation.write(current, false);
+ if freespace_guard.len() == 0 {
+ return None;
+ }
+ let mut freespace_tuple = freespace_guard
+ .get_mut(1)
+ .expect("data corruption")
+ .pipe(write_tuple::<FreepageTuple>);
+ if let Some(local) = freespace_tuple.fetch() {
+ return Some(local as u32 + offset);
+ }
+ if freespace_guard.get_opaque().next == u32::MAX {
+ return None;
+ }
+ (current, offset) = (freespace_guard.get_opaque().next, offset + 32768);
+ }
+}
diff --git a/src/algorithm/insert.rs b/src/algorithm/insert.rs
new file mode 100644
index 0000000..f2b8cfb
--- /dev/null
+++ b/src/algorithm/insert.rs
@@ -0,0 +1,190 @@
+use crate::algorithm::operator::*;
+use crate::algorithm::tape::read_h1_tape;
+use crate::algorithm::tuples::*;
+use crate::algorithm::vectors::{self};
+use crate::algorithm::{Page, PageGuard, RelationWrite};
+use crate::utils::pipe::Pipe;
+use always_equal::AlwaysEqual;
+use distance::Distance;
+use std::cmp::Reverse;
+use std::collections::BinaryHeap;
+use std::num::NonZeroU64;
+use
vector::VectorBorrowed; +use vector::VectorOwned; + +pub fn insert( + relation: impl RelationWrite + Clone, + payload: NonZeroU64, + vector: O::Vector, +) { + let vector = O::Vector::random_projection(vector.as_borrowed()); + let meta_guard = relation.read(0); + let meta_tuple = meta_guard.get(1).unwrap().pipe(read_tuple::); + let dims = meta_tuple.dims(); + let is_residual = meta_tuple.is_residual(); + let height_of_root = meta_tuple.height_of_root(); + assert_eq!(dims, vector.as_borrowed().dims(), "unmatched dimensions"); + let root_mean = meta_tuple.root_mean(); + let root_first = meta_tuple.root_first(); + let vectors_first = meta_tuple.vectors_first(); + drop(meta_guard); + + let default_lut_block = if !is_residual { + Some(O::Vector::compute_lut_block(vector.as_borrowed())) + } else { + None + }; + + let mean = vectors::vector_append::( + relation.clone(), + vectors_first, + vector.as_borrowed(), + payload, + ); + + type State = (u32, Option<::Vector>); + let mut state: State = { + let mean = root_mean; + if is_residual { + let residual_u = vectors::vector_access_1::( + relation.clone(), + mean, + LAccess::new( + O::Vector::elements_and_metadata(vector.as_borrowed()), + O::ResidualAccessor::default(), + ), + ); + (root_first, Some(residual_u)) + } else { + (root_first, None) + } + }; + let step = |state: State| { + let mut results = Vec::new(); + { + let (first, residual) = state; + let lut = if let Some(residual) = residual { + &O::Vector::compute_lut_block(residual.as_borrowed()) + } else { + default_lut_block.as_ref().unwrap() + }; + read_h1_tape( + relation.clone(), + first, + || { + RAccess::new( + (&lut.4, (lut.0, lut.1, lut.2, lut.3, 1.9f32)), + O::Distance::block_accessor(), + ) + }, + |lowerbound, mean, first| { + results.push((Reverse(lowerbound), AlwaysEqual(mean), AlwaysEqual(first))); + }, + ); + } + let mut heap = BinaryHeap::from(results); + let mut cache = BinaryHeap::<(Reverse, _, _)>::new(); + { + while !heap.is_empty() && heap.peek().map(|x| x.0) > cache.peek().map(|x| x.0) { + let (_, AlwaysEqual(mean), AlwaysEqual(first)) = heap.pop().unwrap(); + if is_residual { + let (dis_u, residual_u) = vectors::vector_access_1::( + relation.clone(), + mean, + LAccess::new( + O::Vector::elements_and_metadata(vector.as_borrowed()), + ( + O::DistanceAccessor::default(), + O::ResidualAccessor::default(), + ), + ), + ); + cache.push(( + Reverse(dis_u), + AlwaysEqual(first), + AlwaysEqual(Some(residual_u)), + )); + } else { + let dis_u = vectors::vector_access_1::( + relation.clone(), + mean, + LAccess::new( + O::Vector::elements_and_metadata(vector.as_borrowed()), + O::DistanceAccessor::default(), + ), + ); + cache.push((Reverse(dis_u), AlwaysEqual(first), AlwaysEqual(None))); + } + } + let (_, AlwaysEqual(first), AlwaysEqual(mean)) = cache.pop().unwrap(); + (first, mean) + } + }; + for _ in (1..height_of_root).rev() { + state = step(state); + } + + let (first, residual) = state; + let code = if let Some(residual) = residual { + O::Vector::code(residual.as_borrowed()) + } else { + O::Vector::code(vector.as_borrowed()) + }; + let bytes = serialize(&H0Tuple::_0 { + mean, + dis_u_2: code.dis_u_2, + factor_ppc: code.factor_ppc, + factor_ip: code.factor_ip, + factor_err: code.factor_err, + payload: Some(payload), + elements: rabitq::pack_to_u64(&code.signs), + }); + + let jump_guard = relation.read(first); + let jump_tuple = jump_guard + .get(1) + .expect("data corruption") + .pipe(read_tuple::); + + let first = jump_tuple.first(); + + assert!(first != u32::MAX); + let mut current = 
first; + loop { + let read = relation.read(current); + if read.get_opaque().next == u32::MAX { + drop(read); + let mut write = relation.write(current, false); + if write.get_opaque().next == u32::MAX { + if write.alloc(&bytes).is_some() { + return; + } + let mut extend = relation.extend(false); + write.get_opaque_mut().next = extend.id(); + drop(write); + let fresh = extend.id(); + if extend.alloc(&bytes).is_some() { + drop(extend); + let mut past = relation.write(first, false); + past.get_opaque_mut().skip = std::cmp::max(past.get_opaque_mut().skip, fresh); + drop(past); + return; + } else { + panic!("a tuple cannot even be fit in a fresh page"); + } + } else { + if current == first && write.get_opaque().skip != first { + current = write.get_opaque().skip; + } else { + current = write.get_opaque().next; + } + } + } else { + if current == first && read.get_opaque().skip != first { + current = read.get_opaque().skip; + } else { + current = read.get_opaque().next; + } + } + } +} diff --git a/crates/algorithm/src/lib.rs b/src/algorithm/mod.rs similarity index 85% rename from crates/algorithm/src/lib.rs rename to src/algorithm/mod.rs index 8bee016..7a9c3ff 100644 --- a/crates/algorithm/src/lib.rs +++ b/src/algorithm/mod.rs @@ -1,3 +1,14 @@ +pub mod build; +pub mod freepages; +pub mod insert; +pub mod operator; +pub mod prewarm; +pub mod scan; +pub mod tape; +pub mod tuples; +pub mod vacuum; +pub mod vectors; + use std::ops::{Deref, DerefMut}; #[repr(C, align(8))] @@ -15,8 +26,8 @@ pub trait Page: Sized { fn get_mut(&mut self, i: u16) -> Option<&mut [u8]>; fn alloc(&mut self, data: &[u8]) -> Option; fn free(&mut self, i: u16); - fn reconstruct(&mut self, removes: &[u16]); fn freespace(&self) -> u16; + fn clear(&mut self); } pub trait PageGuard { diff --git a/src/algorithm/operator.rs b/src/algorithm/operator.rs new file mode 100644 index 0000000..9506d7a --- /dev/null +++ b/src/algorithm/operator.rs @@ -0,0 +1,628 @@ +use crate::types::{DistanceKind, OwnedVector}; +use distance::Distance; +use half::f16; +use simd::Floating; +use std::fmt::Debug; +use std::marker::PhantomData; +use vector::vect::VectOwned; +use vector::{VectorBorrowed, VectorOwned}; +use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout}; + +pub trait Accessor2 { + type Output; + fn push(&mut self, input: &[E0], target: &[E1]); + fn finish(self, input: M0, target: M1) -> Self::Output; +} + +impl Accessor2 for () { + type Output = (); + + fn push(&mut self, _: &[E0], _: &[E1]) {} + + fn finish(self, _: M0, _: M1) -> Self::Output {} +} + +impl> Accessor2 for (A,) { + type Output = (A::Output,); + + fn push(&mut self, input: &[E0], target: &[E1]) { + self.0.push(input, target); + } + + fn finish(self, input: M0, target: M1) -> Self::Output { + (self.0.finish(input, target),) + } +} + +impl, B: Accessor2> + Accessor2 for (A, B) +{ + type Output = (A::Output, B::Output); + + fn push(&mut self, input: &[E0], target: &[E1]) { + self.0.push(input, target); + self.1.push(input, target); + } + + fn finish(self, input: M0, target: M1) -> Self::Output { + (self.0.finish(input, target), self.1.finish(input, target)) + } +} + +#[derive(Debug)] +pub struct Sum(f32, PhantomData O>); + +impl Default for Sum { + fn default() -> Self { + Self(0.0, PhantomData) + } +} + +impl Accessor2 for Sum, L2>> { + type Output = Distance; + + fn push(&mut self, target: &[f32], input: &[f32]) { + self.0 += f32::reduce_sum_of_d2(target, input) + } + + fn finish(self, (): (), (): ()) -> Self::Output { + Distance::from_f32(self.0) + } +} + +impl Accessor2 
for Sum, Dot>> { + type Output = Distance; + + fn push(&mut self, target: &[f32], input: &[f32]) { + self.0 += f32::reduce_sum_of_xy(target, input) + } + + fn finish(self, (): (), (): ()) -> Self::Output { + Distance::from_f32(-self.0) + } +} + +impl Accessor2 for Sum, L2>> { + type Output = Distance; + + fn push(&mut self, target: &[f16], input: &[f16]) { + self.0 += f16::reduce_sum_of_d2(target, input) + } + + fn finish(self, (): (), (): ()) -> Self::Output { + Distance::from_f32(self.0) + } +} + +impl Accessor2 for Sum, Dot>> { + type Output = Distance; + + fn push(&mut self, target: &[f16], input: &[f16]) { + self.0 += f16::reduce_sum_of_xy(target, input) + } + + fn finish(self, (): (), (): ()) -> Self::Output { + Distance::from_f32(-self.0) + } +} + +#[derive(Debug, Clone)] +pub struct Diff(Vec<::Element>); + +impl Default for Diff { + fn default() -> Self { + Self(Vec::new()) + } +} + +impl Accessor2 for Diff, L2>> { + type Output = VectOwned; + + fn push(&mut self, target: &[f32], input: &[f32]) { + self.0.extend(f32::vector_sub(target, input)); + } + + fn finish(self, (): (), (): ()) -> Self::Output { + VectOwned::new(self.0) + } +} + +impl Accessor2 for Diff, Dot>> { + type Output = VectOwned; + + fn push(&mut self, target: &[f32], input: &[f32]) { + self.0.extend(f32::vector_sub(target, input)); + } + + fn finish(self, (): (), (): ()) -> Self::Output { + VectOwned::new(self.0) + } +} + +impl Accessor2 for Diff, L2>> { + type Output = VectOwned; + + fn push(&mut self, target: &[f16], input: &[f16]) { + self.0.extend(f16::vector_sub(target, input)); + } + + fn finish(self, (): (), (): ()) -> Self::Output { + VectOwned::new(self.0) + } +} + +impl Accessor2 for Diff, Dot>> { + type Output = VectOwned; + + fn push(&mut self, target: &[f16], input: &[f16]) { + self.0.extend(f16::vector_sub(target, input)); + } + + fn finish(self, (): (), (): ()) -> Self::Output { + VectOwned::new(self.0) + } +} + +#[derive(Debug)] +pub struct Block([u16; 32], PhantomData D>); + +impl Default for Block { + fn default() -> Self { + Self([0u16; 32], PhantomData) + } +} + +impl + Accessor2< + [u64; 2], + [u64; 2], + (&[f32; 32], &[f32; 32], &[f32; 32], &[f32; 32]), + (f32, f32, f32, f32, f32), + > for Block +{ + type Output = [Distance; 32]; + + fn push(&mut self, input: &[[u64; 2]], target: &[[u64; 2]]) { + let t = simd::fast_scan::fast_scan(input, target); + for i in 0..32 { + self.0[i] += t[i]; + } + } + + fn finish( + self, + (dis_u_2, factor_ppc, factor_ip, factor_err): ( + &[f32; 32], + &[f32; 32], + &[f32; 32], + &[f32; 32], + ), + (dis_v_2, b, k, qvector_sum, epsilon): (f32, f32, f32, f32, f32), + ) -> Self::Output { + std::array::from_fn(|i| { + let rough = dis_u_2[i] + + dis_v_2 + + b * factor_ppc[i] + + ((2.0 * self.0[i] as f32) - qvector_sum) * factor_ip[i] * k; + let err = factor_err[i] * dis_v_2.sqrt(); + Distance::from_f32(rough - epsilon * err) + }) + } +} + +impl + Accessor2< + [u64; 2], + [u64; 2], + (&[f32; 32], &[f32; 32], &[f32; 32], &[f32; 32]), + (f32, f32, f32, f32, f32), + > for Block +{ + type Output = [Distance; 32]; + + fn push(&mut self, input: &[[u64; 2]], target: &[[u64; 2]]) { + let t = simd::fast_scan::fast_scan(input, target); + for i in 0..32 { + self.0[i] += t[i]; + } + } + + fn finish( + self, + (_, factor_ppc, factor_ip, factor_err): (&[f32; 32], &[f32; 32], &[f32; 32], &[f32; 32]), + (dis_v_2, b, k, qvector_sum, epsilon): (f32, f32, f32, f32, f32), + ) -> Self::Output { + std::array::from_fn(|i| { + let rough = 0.5 * b * factor_ppc[i] + + 0.5 * ((2.0 * self.0[i] as 
f32) - qvector_sum) * factor_ip[i] * k; + let err = 0.5 * factor_err[i] * dis_v_2.sqrt(); + Distance::from_f32(rough - epsilon * err) + }) + } +} + +pub trait Accessor1 { + type Output; + fn push(&mut self, input: &[E]); + fn finish(self, input: M) -> Self::Output; +} + +impl Accessor1 for () { + type Output = (); + + fn push(&mut self, _: &[E]) {} + + fn finish(self, _: M) -> Self::Output {} +} + +impl Accessor1 for (A,) +where + A: Accessor1, +{ + type Output = (A::Output,); + + fn push(&mut self, input: &[E]) { + self.0.push(input); + } + + fn finish(self, input: M) -> Self::Output { + (self.0.finish(input),) + } +} + +impl Accessor1 for (A, B) +where + A: Accessor1, + B: Accessor1, +{ + type Output = (A::Output, B::Output); + + fn push(&mut self, input: &[E]) { + self.0.push(input); + self.1.push(input); + } + + fn finish(self, input: M) -> Self::Output { + (self.0.finish(input), self.1.finish(input)) + } +} + +pub struct LAccess<'a, E, M, A> { + elements: &'a [E], + metadata: M, + accessor: A, +} + +impl<'a, E, M, A> LAccess<'a, E, M, A> { + pub fn new((elements, metadata): (&'a [E], M), accessor: A) -> Self { + Self { + elements, + metadata, + accessor, + } + } +} + +impl> Accessor1 for LAccess<'_, E0, M0, A> { + type Output = A::Output; + + fn push(&mut self, rhs: &[E1]) { + let (lhs, elements) = self.elements.split_at(rhs.len()); + self.accessor.push(lhs, rhs); + self.elements = elements; + } + + fn finish(self, rhs: M1) -> Self::Output { + self.accessor.finish(self.metadata, rhs) + } +} + +pub struct RAccess<'a, E, M, A> { + elements: &'a [E], + metadata: M, + accessor: A, +} + +impl<'a, E, M, A> RAccess<'a, E, M, A> { + #[allow(dead_code)] + pub fn new((elements, metadata): (&'a [E], M), accessor: A) -> Self { + Self { + elements, + metadata, + accessor, + } + } +} + +impl> Accessor1 for RAccess<'_, E1, M1, A> { + type Output = A::Output; + + fn push(&mut self, lhs: &[E0]) { + let (rhs, elements) = self.elements.split_at(lhs.len()); + self.accessor.push(lhs, rhs); + self.elements = elements; + } + + fn finish(self, lhs: M0) -> Self::Output { + self.accessor.finish(lhs, self.metadata) + } +} + +pub trait Vector: VectorOwned { + type Element: Debug + Copy + FromBytes + IntoBytes + Immutable + KnownLayout; + type Metadata: Debug + Copy + FromBytes + IntoBytes + Immutable + KnownLayout; + + fn vector_split(vector: Self::Borrowed<'_>) -> (Self::Metadata, Vec<&[Self::Element]>); + fn elements_and_metadata(vector: Self::Borrowed<'_>) -> (&[Self::Element], Self::Metadata); + fn from_owned(vector: OwnedVector) -> Self; + + fn random_projection(vector: Self::Borrowed<'_>) -> Self; + + fn compute_lut_block(vector: Self::Borrowed<'_>) -> (f32, f32, f32, f32, Vec<[u64; 2]>); + + fn compute_lut( + vector: Self::Borrowed<'_>, + ) -> ( + (f32, f32, f32, f32, Vec<[u64; 2]>), + (f32, f32, f32, f32, (Vec, Vec, Vec, Vec)), + ); + + fn code(vector: Self::Borrowed<'_>) -> rabitq::Code; + + fn build_to_vecf32(vector: Self::Borrowed<'_>) -> Vec; + + fn build_from_vecf32(x: &[f32]) -> Self; +} + +impl Vector for VectOwned { + type Metadata = (); + + type Element = f32; + + fn vector_split(vector: Self::Borrowed<'_>) -> ((), Vec<&[f32]>) { + let vector = vector.slice(); + ((), match vector.len() { + 0..=960 => vec![vector], + 961..=1280 => vec![&vector[..640], &vector[640..]], + 1281.. 
=> vector.chunks(1920).collect(), + }) + } + + fn elements_and_metadata(vector: Self::Borrowed<'_>) -> (&[Self::Element], Self::Metadata) { + (vector.slice(), ()) + } + + fn from_owned(vector: OwnedVector) -> Self { + match vector { + OwnedVector::Vecf32(x) => x, + _ => unreachable!(), + } + } + + fn random_projection(vector: Self::Borrowed<'_>) -> Self { + Self::new(crate::projection::project(vector.slice())) + } + + fn compute_lut_block(vector: Self::Borrowed<'_>) -> (f32, f32, f32, f32, Vec<[u64; 2]>) { + rabitq::block::preprocess(vector.slice()) + } + + fn compute_lut( + vector: Self::Borrowed<'_>, + ) -> ( + (f32, f32, f32, f32, Vec<[u64; 2]>), + (f32, f32, f32, f32, (Vec, Vec, Vec, Vec)), + ) { + rabitq::compute_lut(vector.slice()) + } + + fn code(vector: Self::Borrowed<'_>) -> rabitq::Code { + rabitq::code(vector.dims(), vector.slice()) + } + + fn build_to_vecf32(vector: Self::Borrowed<'_>) -> Vec { + vector.slice().to_vec() + } + + fn build_from_vecf32(x: &[f32]) -> Self { + Self::new(x.to_vec()) + } +} + +impl Vector for VectOwned { + type Metadata = (); + + type Element = f16; + + fn vector_split(vector: Self::Borrowed<'_>) -> ((), Vec<&[f16]>) { + let vector = vector.slice(); + ((), match vector.len() { + 0..=1920 => vec![vector], + 1921..=2560 => vec![&vector[..1280], &vector[1280..]], + 2561.. => vector.chunks(3840).collect(), + }) + } + + fn elements_and_metadata(vector: Self::Borrowed<'_>) -> (&[Self::Element], Self::Metadata) { + (vector.slice(), ()) + } + + fn from_owned(vector: OwnedVector) -> Self { + match vector { + OwnedVector::Vecf16(x) => x, + _ => unreachable!(), + } + } + + fn random_projection(vector: Self::Borrowed<'_>) -> Self { + Self::new(f16::vector_from_f32(&crate::projection::project( + &f16::vector_to_f32(vector.slice()), + ))) + } + + fn compute_lut_block(vector: Self::Borrowed<'_>) -> (f32, f32, f32, f32, Vec<[u64; 2]>) { + rabitq::block::preprocess(&f16::vector_to_f32(vector.slice())) + } + + fn compute_lut( + vector: Self::Borrowed<'_>, + ) -> ( + (f32, f32, f32, f32, Vec<[u64; 2]>), + (f32, f32, f32, f32, (Vec, Vec, Vec, Vec)), + ) { + rabitq::compute_lut(&f16::vector_to_f32(vector.slice())) + } + + fn code(vector: Self::Borrowed<'_>) -> rabitq::Code { + rabitq::code(vector.dims(), &f16::vector_to_f32(vector.slice())) + } + + fn build_to_vecf32(vector: Self::Borrowed<'_>) -> Vec { + f16::vector_to_f32(vector.slice()) + } + + fn build_from_vecf32(x: &[f32]) -> Self { + Self::new(f16::vector_from_f32(x)) + } +} + +pub trait OperatorDistance: 'static + Debug + Copy { + const KIND: DistanceKind; + + fn compute_lowerbound_binary( + lut: &(f32, f32, f32, f32, (Vec, Vec, Vec, Vec)), + code: (f32, f32, f32, f32, &[u64]), + epsilon: f32, + ) -> Distance; + + type BlockAccessor: for<'a> Accessor2< + [u64; 2], + [u64; 2], + (&'a [f32; 32], &'a [f32; 32], &'a [f32; 32], &'a [f32; 32]), + (f32, f32, f32, f32, f32), + Output = [Distance; 32], + > + Default; + + fn block_accessor() -> Self::BlockAccessor { + Default::default() + } +} + +#[derive(Debug, Clone, Copy)] +pub struct L2; + +impl OperatorDistance for L2 { + const KIND: DistanceKind = DistanceKind::L2; + + fn compute_lowerbound_binary( + lut: &(f32, f32, f32, f32, (Vec, Vec, Vec, Vec)), + code: (f32, f32, f32, f32, &[u64]), + epsilon: f32, + ) -> Distance { + rabitq::binary::process_lowerbound_l2(lut, code, epsilon) + } + + type BlockAccessor = Block; +} + +#[derive(Debug, Clone, Copy)] +pub struct Dot; + +impl OperatorDistance for Dot { + const KIND: DistanceKind = DistanceKind::Dot; + + fn 
compute_lowerbound_binary( + lut: &(f32, f32, f32, f32, (Vec, Vec, Vec, Vec)), + code: (f32, f32, f32, f32, &[u64]), + epsilon: f32, + ) -> Distance { + rabitq::binary::process_lowerbound_dot(lut, code, epsilon) + } + + type BlockAccessor = Block; +} + +pub trait Operator: 'static + Debug + Copy { + type Vector: Vector; + + type Distance: OperatorDistance; + + type DistanceAccessor: Default + + Accessor2< + ::Element, + ::Element, + ::Metadata, + ::Metadata, + Output = Distance, + >; + + const SUPPORTS_RESIDUAL: bool; + + type ResidualAccessor: Default + + Accessor2< + ::Element, + ::Element, + ::Metadata, + ::Metadata, + Output = Self::Vector, + >; +} + +#[derive(Debug)] +pub struct Op(PhantomData<(fn(V) -> V, fn(D) -> D)>); + +impl Clone for Op { + fn clone(&self) -> Self { + *self + } +} + +impl Copy for Op {} + +impl Operator for Op, L2> { + type Vector = VectOwned; + + type Distance = L2; + + type DistanceAccessor = Sum, L2>>; + + const SUPPORTS_RESIDUAL: bool = true; + + type ResidualAccessor = Diff, L2>>; +} + +impl Operator for Op, Dot> { + type Vector = VectOwned; + + type Distance = Dot; + + type DistanceAccessor = Sum, Dot>>; + + const SUPPORTS_RESIDUAL: bool = false; + + type ResidualAccessor = Diff, Dot>>; +} + +impl Operator for Op, L2> { + type Vector = VectOwned; + + type Distance = L2; + + type DistanceAccessor = Sum, L2>>; + + const SUPPORTS_RESIDUAL: bool = true; + + type ResidualAccessor = Diff, L2>>; +} + +impl Operator for Op, Dot> { + type Vector = VectOwned; + + type Distance = Dot; + + type DistanceAccessor = Sum, Dot>>; + + const SUPPORTS_RESIDUAL: bool = false; + + type ResidualAccessor = Diff, Dot>>; +} diff --git a/src/algorithm/prewarm.rs b/src/algorithm/prewarm.rs new file mode 100644 index 0000000..9373f4a --- /dev/null +++ b/src/algorithm/prewarm.rs @@ -0,0 +1,99 @@ +use crate::algorithm::operator::Operator; +use crate::algorithm::tuples::*; +use crate::algorithm::vectors; +use crate::algorithm::{Page, RelationRead}; +use crate::utils::pipe::Pipe; +use std::fmt::Write; + +pub fn prewarm(relation: impl RelationRead + Clone, height: i32) -> String { + let meta_guard = relation.read(0); + let meta_tuple = meta_guard.get(1).unwrap().pipe(read_tuple::); + let height_of_root = meta_tuple.height_of_root(); + let root_mean = meta_tuple.root_mean(); + let root_first = meta_tuple.root_first(); + drop(meta_guard); + + let mut message = String::new(); + writeln!(message, "height of root: {}", height_of_root).unwrap(); + let prewarm_max_height = if height < 0 { 0 } else { height as u32 }; + if prewarm_max_height > height_of_root { + return message; + } + type State = Vec; + let mut state: State = { + let mut results = Vec::new(); + let counter = 1_usize; + { + vectors::vector_access_1::(relation.clone(), root_mean, ()); + results.push(root_first); + } + writeln!(message, "number of tuples: {}", results.len()).unwrap(); + writeln!(message, "number of pages: {}", counter).unwrap(); + results + }; + let mut step = |state: State| { + let mut counter = 0_usize; + let mut results = Vec::new(); + for list in state { + let mut current = list; + while current != u32::MAX { + counter += 1; + pgrx::check_for_interrupts!(); + let h1_guard = relation.read(current); + for i in 1..=h1_guard.len() { + let h1_tuple = h1_guard + .get(i) + .expect("data corruption") + .pipe(read_tuple::); + match h1_tuple { + H1TupleReader::_0(h1_tuple) => { + for mean in h1_tuple.mean().iter().copied() { + vectors::vector_access_1::(relation.clone(), mean, ()); + } + for first in 
h1_tuple.first().iter().copied() { + results.push(first); + } + } + H1TupleReader::_1(_) => (), + } + } + current = h1_guard.get_opaque().next; + } + } + writeln!(message, "number of tuples: {}", results.len()).unwrap(); + writeln!(message, "number of pages: {}", counter).unwrap(); + results + }; + for _ in (std::cmp::max(1, prewarm_max_height)..height_of_root).rev() { + state = step(state); + } + if prewarm_max_height == 0 { + let mut counter = 0_usize; + let mut results = Vec::new(); + for list in state { + let jump_guard = relation.read(list); + let jump_tuple = jump_guard + .get(1) + .expect("data corruption") + .pipe(read_tuple::); + let first = jump_tuple.first(); + let mut current = first; + while current != u32::MAX { + counter += 1; + pgrx::check_for_interrupts!(); + let h0_guard = relation.read(current); + for i in 1..=h0_guard.len() { + let _h0_tuple = h0_guard + .get(i) + .expect("data corruption") + .pipe(read_tuple::); + results.push(()); + } + current = h0_guard.get_opaque().next; + } + } + writeln!(message, "number of tuples: {}", results.len()).unwrap(); + writeln!(message, "number of pages: {}", counter).unwrap(); + } + message +} diff --git a/src/algorithm/scan.rs b/src/algorithm/scan.rs new file mode 100644 index 0000000..fee6df3 --- /dev/null +++ b/src/algorithm/scan.rs @@ -0,0 +1,171 @@ +use crate::algorithm::operator::*; +use crate::algorithm::tape::read_h0_tape; +use crate::algorithm::tape::read_h1_tape; +use crate::algorithm::tuples::*; +use crate::algorithm::vectors; +use crate::algorithm::{Page, RelationRead}; +use crate::utils::pipe::Pipe; +use always_equal::AlwaysEqual; +use distance::Distance; +use std::cmp::Reverse; +use std::collections::BinaryHeap; +use std::num::NonZeroU64; +use vector::VectorBorrowed; +use vector::VectorOwned; + +pub fn scan( + relation: impl RelationRead + Clone, + vector: O::Vector, + probes: Vec, + epsilon: f32, +) -> impl Iterator { + let vector = O::Vector::random_projection(vector.as_borrowed()); + let meta_guard = relation.read(0); + let meta_tuple = meta_guard.get(1).unwrap().pipe(read_tuple::); + let dims = meta_tuple.dims(); + let is_residual = meta_tuple.is_residual(); + let height_of_root = meta_tuple.height_of_root(); + assert_eq!(dims, vector.as_borrowed().dims(), "unmatched dimensions"); + assert_eq!(height_of_root as usize, 1 + probes.len(), "invalid probes"); + let root_mean = meta_tuple.root_mean(); + let root_first = meta_tuple.root_first(); + drop(meta_guard); + + let default_lut = if !is_residual { + Some(O::Vector::compute_lut(vector.as_borrowed())) + } else { + None + }; + + type State = Vec<(u32, Option<::Vector>)>; + let mut state: State = vec![{ + let mean = root_mean; + if is_residual { + let residual_u = vectors::vector_access_1::( + relation.clone(), + mean, + LAccess::new( + O::Vector::elements_and_metadata(vector.as_borrowed()), + O::ResidualAccessor::default(), + ), + ); + (root_first, Some(residual_u)) + } else { + (root_first, None) + } + }]; + let step = |state: State, probes| { + let mut results = Vec::new(); + for (first, residual) in state { + let lut = if let Some(residual) = residual { + &O::Vector::compute_lut_block(residual.as_borrowed()) + } else { + default_lut.as_ref().map(|x| &x.0).unwrap() + }; + read_h1_tape( + relation.clone(), + first, + || { + RAccess::new( + (&lut.4, (lut.0, lut.1, lut.2, lut.3, epsilon)), + O::Distance::block_accessor(), + ) + }, + |lowerbound, mean, first| { + results.push((Reverse(lowerbound), AlwaysEqual(mean), AlwaysEqual(first))); + }, + ); + } + let mut heap = 
BinaryHeap::from(results); + let mut cache = BinaryHeap::<(Reverse, _, _)>::new(); + std::iter::from_fn(|| { + while !heap.is_empty() && heap.peek().map(|x| x.0) > cache.peek().map(|x| x.0) { + let (_, AlwaysEqual(mean), AlwaysEqual(first)) = heap.pop().unwrap(); + if is_residual { + let (dis_u, residual_u) = vectors::vector_access_1::( + relation.clone(), + mean, + LAccess::new( + O::Vector::elements_and_metadata(vector.as_borrowed()), + ( + O::DistanceAccessor::default(), + O::ResidualAccessor::default(), + ), + ), + ); + cache.push(( + Reverse(dis_u), + AlwaysEqual(first), + AlwaysEqual(Some(residual_u)), + )); + } else { + let dis_u = vectors::vector_access_1::( + relation.clone(), + mean, + LAccess::new( + O::Vector::elements_and_metadata(vector.as_borrowed()), + O::DistanceAccessor::default(), + ), + ); + cache.push((Reverse(dis_u), AlwaysEqual(first), AlwaysEqual(None))); + } + } + let (_, AlwaysEqual(first), AlwaysEqual(mean)) = cache.pop()?; + Some((first, mean)) + }) + .take(probes as usize) + .collect() + }; + for i in (1..height_of_root).rev() { + state = step(state, probes[i as usize - 1]); + } + + let mut results = Vec::new(); + for (first, residual) in state { + let lut = if let Some(residual) = residual.as_ref().map(|x| x.as_borrowed()) { + &O::Vector::compute_lut(residual) + } else { + default_lut.as_ref().unwrap() + }; + let jump_guard = relation.read(first); + let jump_tuple = jump_guard + .get(1) + .expect("data corruption") + .pipe(read_tuple::); + let first = jump_tuple.first(); + read_h0_tape( + relation.clone(), + first, + || { + RAccess::new( + (&lut.0.4, (lut.0.0, lut.0.1, lut.0.2, lut.0.3, epsilon)), + O::Distance::block_accessor(), + ) + }, + |code| O::Distance::compute_lowerbound_binary(&lut.1, code, epsilon), + |lowerbound, mean, payload| { + results.push((Reverse(lowerbound), AlwaysEqual(mean), AlwaysEqual(payload))); + }, + ); + } + let mut heap = BinaryHeap::from(results); + let mut cache = BinaryHeap::<(Reverse, _)>::new(); + std::iter::from_fn(move || { + while !heap.is_empty() && heap.peek().map(|x| x.0) > cache.peek().map(|x| x.0) { + let (_, AlwaysEqual(mean), AlwaysEqual(pay_u)) = heap.pop().unwrap(); + if let Some(dis_u) = vectors::vector_access_0::( + relation.clone(), + mean, + pay_u, + LAccess::new( + O::Vector::elements_and_metadata(vector.as_borrowed()), + O::DistanceAccessor::default(), + ), + ) { + cache.push((Reverse(dis_u), AlwaysEqual(pay_u))); + }; + } + let (Reverse(dis_u), AlwaysEqual(pay_u)) = cache.pop()?; + Some((dis_u, pay_u)) + }) +} diff --git a/src/algorithm/tape.rs b/src/algorithm/tape.rs new file mode 100644 index 0000000..4ee722a --- /dev/null +++ b/src/algorithm/tape.rs @@ -0,0 +1,349 @@ +use super::RelationRead; +use super::operator::Accessor1; +use crate::algorithm::Page; +use crate::algorithm::PageGuard; +use crate::algorithm::tuples::*; +use crate::utils::pipe::Pipe; +use distance::Distance; +use simd::fast_scan::any_pack; +use simd::fast_scan::padding_pack; +use std::marker::PhantomData; +use std::num::NonZeroU64; +use std::ops::DerefMut; + +pub struct TapeWriter { + head: G, + first: u32, + extend: E, + _phantom: PhantomData T>, +} + +impl TapeWriter +where + G: PageGuard + DerefMut, + G::Target: Page, + E: Fn() -> G, +{ + pub fn create(extend: E) -> Self { + let mut head = extend(); + head.get_opaque_mut().skip = head.id(); + let first = head.id(); + Self { + head, + first, + extend, + _phantom: PhantomData, + } + } + pub fn first(&self) -> u32 { + self.first + } + fn freespace(&self) -> u16 { + self.head.freespace() + 
} + fn tape_move(&mut self) { + if self.head.len() == 0 { + panic!("tuple is too large to fit in a fresh page"); + } + let next = (self.extend)(); + self.head.get_opaque_mut().next = next.id(); + self.head = next; + } +} + +impl TapeWriter +where + G: PageGuard + DerefMut, + G::Target: Page, + E: Fn() -> G, + T: Tuple, +{ + pub fn push(&mut self, x: T) -> IndexPointer { + let bytes = serialize(&x); + if let Some(i) = self.head.alloc(&bytes) { + pair_to_pointer((self.head.id(), i)) + } else { + let next = (self.extend)(); + self.head.get_opaque_mut().next = next.id(); + self.head = next; + if let Some(i) = self.head.alloc(&bytes) { + pair_to_pointer((self.head.id(), i)) + } else { + panic!("tuple is too large to fit in a fresh page") + } + } + } + fn tape_put(&mut self, x: T) -> IndexPointer { + let bytes = serialize(&x); + if let Some(i) = self.head.alloc(&bytes) { + pair_to_pointer((self.head.id(), i)) + } else { + panic!("tuple is too large to fit in this page") + } + } +} + +pub struct H1Branch { + pub mean: IndexPointer, + pub dis_u_2: f32, + pub factor_ppc: f32, + pub factor_ip: f32, + pub factor_err: f32, + pub signs: Vec, + pub first: u32, +} + +pub struct H1TapeWriter { + tape: TapeWriter, + branches: Vec, +} + +impl H1TapeWriter +where + G: PageGuard + DerefMut, + G::Target: Page, + E: Fn() -> G, +{ + pub fn create(extend: E) -> Self { + Self { + tape: TapeWriter::create(extend), + branches: Vec::new(), + } + } + pub fn push(&mut self, branch: H1Branch) { + self.branches.push(branch); + if self.branches.len() == 32 { + let chunk = std::array::from_fn::<_, 32, _>(|_| self.branches.pop().unwrap()); + let mut remain = padding_pack(chunk.iter().map(|x| rabitq::pack_to_u4(&x.signs))); + loop { + let freespace = self.tape.freespace(); + match (H1Tuple::fit_0(freespace), H1Tuple::fit_1(freespace)) { + (Some(w), _) if w >= remain.len() => { + self.tape.tape_put(H1Tuple::_0 { + mean: chunk.each_ref().map(|x| x.mean), + dis_u_2: chunk.each_ref().map(|x| x.dis_u_2), + factor_ppc: chunk.each_ref().map(|x| x.factor_ppc), + factor_ip: chunk.each_ref().map(|x| x.factor_ip), + factor_err: chunk.each_ref().map(|x| x.factor_err), + first: chunk.each_ref().map(|x| x.first), + len: chunk.len() as _, + elements: remain, + }); + break; + } + (_, Some(w)) => { + let (left, right) = remain.split_at(std::cmp::min(w, remain.len())); + self.tape.tape_put(H1Tuple::_1 { + elements: left.to_vec(), + }); + remain = right.to_vec(); + } + (_, None) => self.tape.tape_move(), + } + } + } + } + pub fn into_inner(mut self) -> TapeWriter { + let chunk = self.branches; + if !chunk.is_empty() { + let mut remain = padding_pack(chunk.iter().map(|x| rabitq::pack_to_u4(&x.signs))); + loop { + let freespace = self.tape.freespace(); + match (H1Tuple::fit_0(freespace), H1Tuple::fit_1(freespace)) { + (Some(w), _) if w >= remain.len() => { + self.tape.push(H1Tuple::_0 { + mean: any_pack(chunk.iter().map(|x| x.mean)), + dis_u_2: any_pack(chunk.iter().map(|x| x.dis_u_2)), + factor_ppc: any_pack(chunk.iter().map(|x| x.factor_ppc)), + factor_ip: any_pack(chunk.iter().map(|x| x.factor_ip)), + factor_err: any_pack(chunk.iter().map(|x| x.factor_err)), + first: any_pack(chunk.iter().map(|x| x.first)), + len: chunk.len() as _, + elements: remain, + }); + break; + } + (_, Some(w)) => { + let (left, right) = remain.split_at(std::cmp::min(w, remain.len())); + self.tape.tape_put(H1Tuple::_1 { + elements: left.to_vec(), + }); + remain = right.to_vec(); + } + (_, None) => self.tape.tape_move(), + } + } + } + self.tape + } +} + +pub struct 
H0BranchWriter { + pub mean: IndexPointer, + pub dis_u_2: f32, + pub factor_ppc: f32, + pub factor_ip: f32, + pub factor_err: f32, + pub signs: Vec, + pub payload: NonZeroU64, +} + +pub struct H0Tape { + tape: TapeWriter, + branches: Vec, +} + +impl H0Tape +where + G: PageGuard + DerefMut, + G::Target: Page, + E: Fn() -> G, +{ + pub fn create(extend: E) -> Self { + Self { + tape: TapeWriter::create(extend), + branches: Vec::new(), + } + } + pub fn push(&mut self, branch: H0BranchWriter) { + self.branches.push(branch); + if self.branches.len() == 32 { + let chunk = std::array::from_fn::<_, 32, _>(|_| self.branches.pop().unwrap()); + let mut remain = padding_pack(chunk.iter().map(|x| rabitq::pack_to_u4(&x.signs))); + loop { + let freespace = self.tape.freespace(); + match (H0Tuple::fit_1(freespace), H0Tuple::fit_2(freespace)) { + (Some(w), _) if w >= remain.len() => { + self.tape.push(H0Tuple::_1 { + mean: chunk.each_ref().map(|x| x.mean), + dis_u_2: chunk.each_ref().map(|x| x.dis_u_2), + factor_ppc: chunk.each_ref().map(|x| x.factor_ppc), + factor_ip: chunk.each_ref().map(|x| x.factor_ip), + factor_err: chunk.each_ref().map(|x| x.factor_err), + payload: chunk.each_ref().map(|x| Some(x.payload)), + elements: remain, + }); + break; + } + (_, Some(w)) => { + let (left, right) = remain.split_at(std::cmp::min(w, remain.len())); + self.tape.tape_put(H0Tuple::_2 { + elements: left.to_vec(), + }); + remain = right.to_vec(); + } + (_, None) => self.tape.tape_move(), + } + } + } + } + pub fn into_inner(mut self) -> TapeWriter { + for x in self.branches { + self.tape.push(H0Tuple::_0 { + mean: x.mean, + dis_u_2: x.dis_u_2, + factor_ppc: x.factor_ppc, + factor_ip: x.factor_ip, + factor_err: x.factor_err, + payload: Some(x.payload), + elements: rabitq::pack_to_u64(&x.signs), + }); + } + self.tape + } +} + +pub fn read_h1_tape( + relation: impl RelationRead, + first: u32, + compute_block: impl Fn() -> A + Copy, + mut callback: impl FnMut(Distance, IndexPointer, u32), +) where + A: for<'a> Accessor1< + [u64; 2], + (&'a [f32; 32], &'a [f32; 32], &'a [f32; 32], &'a [f32; 32]), + Output = [Distance; 32], + >, +{ + assert!(first != u32::MAX); + let mut current = first; + let mut computing = None; + while current != u32::MAX { + let h1_guard = relation.read(current); + for i in 1..=h1_guard.len() { + let h1_tuple = h1_guard + .get(i) + .expect("data corruption") + .pipe(read_tuple::); + match h1_tuple { + H1TupleReader::_0(h1_tuple) => { + let mut compute = computing.take().unwrap_or_else(compute_block); + compute.push(h1_tuple.elements()); + let lowerbounds = compute.finish(h1_tuple.metadata()); + for i in 0..h1_tuple.len() { + callback( + lowerbounds[i as usize], + h1_tuple.mean()[i as usize], + h1_tuple.first()[i as usize], + ); + } + } + H1TupleReader::_1(h1_tuple) => { + let computing = computing.get_or_insert_with(compute_block); + computing.push(h1_tuple.elements()); + } + } + } + current = h1_guard.get_opaque().next; + } +} + +pub fn read_h0_tape( + relation: impl RelationRead, + first: u32, + compute_block: impl Fn() -> A + Copy, + compute_binary: impl Fn((f32, f32, f32, f32, &[u64])) -> Distance, + mut callback: impl FnMut(Distance, IndexPointer, NonZeroU64), +) where + A: for<'a> Accessor1< + [u64; 2], + (&'a [f32; 32], &'a [f32; 32], &'a [f32; 32], &'a [f32; 32]), + Output = [Distance; 32], + >, +{ + assert!(first != u32::MAX); + let mut current = first; + let mut computing = None; + while current != u32::MAX { + let h0_guard = relation.read(current); + for i in 1..=h0_guard.len() { + let h0_tuple 
= h0_guard + .get(i) + .expect("data corruption") + .pipe(read_tuple::); + match h0_tuple { + H0TupleReader::_0(h0_tuple) => { + let lowerbound = compute_binary(h0_tuple.code()); + if let Some(payload) = h0_tuple.payload() { + callback(lowerbound, h0_tuple.mean(), payload); + } + } + H0TupleReader::_1(h0_tuple) => { + let mut compute = computing.take().unwrap_or_else(compute_block); + compute.push(h0_tuple.elements()); + let lowerbounds = compute.finish(h0_tuple.metadata()); + for j in 0..32 { + if let Some(payload) = h0_tuple.payload()[j] { + callback(lowerbounds[j], h0_tuple.mean()[j], payload); + } + } + } + H0TupleReader::_2(h0_tuple) => { + let computing = computing.get_or_insert_with(compute_block); + computing.push(h0_tuple.elements()); + } + } + } + current = h0_guard.get_opaque().next; + } +} diff --git a/src/algorithm/tuples.rs b/src/algorithm/tuples.rs new file mode 100644 index 0000000..6ecef02 --- /dev/null +++ b/src/algorithm/tuples.rs @@ -0,0 +1,1205 @@ +use crate::algorithm::operator::Vector; +use std::num::{NonZeroU8, NonZeroU64}; +use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout}; +use zerocopy_derive::{FromBytes, Immutable, IntoBytes, KnownLayout}; + +pub const ALIGN: usize = 8; +pub type Tag = u64; + +pub trait Tuple: 'static { + type Reader<'a>: TupleReader<'a, Tuple = Self>; + fn serialize(&self) -> Vec; +} + +pub trait WithWriter: Tuple { + type Writer<'a>: TupleWriter<'a, Tuple = Self>; +} + +pub trait TupleReader<'a>: Copy { + type Tuple: Tuple; + fn deserialize_ref(source: &'a [u8]) -> Self; +} + +pub trait TupleWriter<'a> { + type Tuple: Tuple; + fn deserialize_mut(source: &'a mut [u8]) -> Self; +} + +pub fn serialize(tuple: &T) -> Vec { + Tuple::serialize(tuple) +} + +pub fn read_tuple(source: &[u8]) -> T::Reader<'_> { + TupleReader::deserialize_ref(source) +} + +pub fn write_tuple(source: &mut [u8]) -> T::Writer<'_> { + TupleWriter::deserialize_mut(source) +} + +// meta tuple + +#[repr(C, align(8))] +#[derive(Debug, Clone, PartialEq, FromBytes, IntoBytes, Immutable, KnownLayout)] +struct MetaTupleHeader { + dims: u32, + height_of_root: u32, + is_residual: Bool, + _padding_0: [ZeroU8; 3], + vectors_first: u32, + // raw vector + root_mean: IndexPointer, + // for meta tuple, it's pointers to next level + root_first: u32, + freepage_first: u32, +} + +pub struct MetaTuple { + pub dims: u32, + pub height_of_root: u32, + pub is_residual: bool, + pub vectors_first: u32, + pub root_mean: IndexPointer, + pub root_first: u32, + pub freepage_first: u32, +} + +impl Tuple for MetaTuple { + type Reader<'a> = MetaTupleReader<'a>; + + fn serialize(&self) -> Vec { + MetaTupleHeader { + dims: self.dims, + height_of_root: self.height_of_root, + is_residual: self.is_residual.into(), + _padding_0: Default::default(), + vectors_first: self.vectors_first, + root_mean: self.root_mean, + root_first: self.root_first, + freepage_first: self.freepage_first, + } + .as_bytes() + .to_vec() + } +} + +#[derive(Debug, Clone, Copy)] +pub struct MetaTupleReader<'a> { + header: &'a MetaTupleHeader, +} + +impl<'a> TupleReader<'a> for MetaTupleReader<'a> { + type Tuple = MetaTuple; + fn deserialize_ref(source: &'a [u8]) -> Self { + let checker = RefChecker::new(source); + let header = checker.prefix(0); + Self { header } + } +} + +impl MetaTupleReader<'_> { + pub fn dims(self) -> u32 { + self.header.dims + } + pub fn height_of_root(self) -> u32 { + self.header.height_of_root + } + pub fn is_residual(self) -> bool { + self.header.is_residual.into() + } + pub fn vectors_first(self) -> u32 { 
+ self.header.vectors_first + } + pub fn root_mean(self) -> IndexPointer { + self.header.root_mean + } + pub fn root_first(self) -> u32 { + self.header.root_first + } + pub fn freepage_first(self) -> u32 { + self.header.freepage_first + } +} + +// freepage tuple + +#[repr(C, align(8))] +#[derive(Debug, Clone, Copy, PartialEq, FromBytes, IntoBytes, Immutable, KnownLayout)] +struct FreepageTupleHeader { + a: [u32; 1], + b: [u32; 32], + c: [u32; 32 * 32], + _padding_0: [ZeroU8; 4], +} + +const _: () = assert!(size_of::() == 4232); + +#[derive(Debug, Clone, PartialEq)] +pub struct FreepageTuple {} + +impl Tuple for FreepageTuple { + type Reader<'a> = FreepageTupleReader<'a>; + + fn serialize(&self) -> Vec { + FreepageTupleHeader { + a: std::array::from_fn(|_| 0), + b: std::array::from_fn(|_| 0), + c: std::array::from_fn(|_| 0), + _padding_0: Default::default(), + } + .as_bytes() + .to_vec() + } +} + +impl WithWriter for FreepageTuple { + type Writer<'a> = FreepageTupleWriter<'a>; +} + +#[derive(Debug, Clone, Copy)] +pub struct FreepageTupleReader<'a> { + #[allow(dead_code)] + header: &'a FreepageTupleHeader, +} + +impl<'a> TupleReader<'a> for FreepageTupleReader<'a> { + type Tuple = FreepageTuple; + + fn deserialize_ref(source: &'a [u8]) -> Self { + let checker = RefChecker::new(source); + let header = checker.prefix(0); + Self { header } + } +} + +pub struct FreepageTupleWriter<'a> { + header: &'a mut FreepageTupleHeader, +} + +impl<'a> TupleWriter<'a> for FreepageTupleWriter<'a> { + type Tuple = FreepageTuple; + + fn deserialize_mut(source: &'a mut [u8]) -> Self { + let mut checker = MutChecker::new(source); + let header = checker.prefix(0); + Self { header } + } +} + +impl FreepageTupleWriter<'_> { + pub fn mark(&mut self, i: usize) { + let c_i = i; + self.header.c[c_i / 32] |= 1 << (c_i % 32); + let b_i = i / 32; + self.header.b[b_i / 32] |= 1 << (b_i % 32); + let a_i = i / 32 / 32; + self.header.a[a_i / 32] |= 1 << (a_i % 32); + } + pub fn fetch(&mut self) -> Option { + if self.header.a[0].trailing_ones() == 32 { + return None; + } + let a_i = self.header.a[0].trailing_zeros() as usize; + let b_i = self.header.b[a_i].trailing_zeros() as usize + a_i * 32; + let c_i = self.header.c[b_i].trailing_zeros() as usize + b_i * 32; + self.header.c[c_i / 32] &= !(1 << (c_i % 32)); + if self.header.c[b_i] == 0 { + self.header.b[b_i / 32] &= !(1 << (b_i % 32)); + if self.header.b[a_i] == 0 { + self.header.a[a_i / 32] &= !(1 << (a_i % 32)); + } + } + Some(c_i) + } +} + +// vector tuple + +#[repr(C, align(8))] +#[derive(Debug, Clone, Copy, PartialEq, FromBytes, IntoBytes, Immutable, KnownLayout)] +struct VectorTupleHeader0 { + payload: Option, + metadata_s: usize, + elements_s: usize, + elements_e: usize, +} + +#[repr(C, align(8))] +#[derive(Debug, Clone, Copy, PartialEq, FromBytes, IntoBytes, Immutable, KnownLayout)] +struct VectorTupleHeader1 { + payload: Option, + pointer: IndexPointer, + elements_s: usize, + elements_e: usize, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum VectorTuple { + _0 { + payload: Option, + metadata: V::Metadata, + elements: Vec, + }, + _1 { + payload: Option, + pointer: IndexPointer, + elements: Vec, + }, +} + +impl Tuple for VectorTuple { + type Reader<'a> = VectorTupleReader<'a, V>; + + fn serialize(&self) -> Vec { + let mut buffer = Vec::::new(); + match self { + VectorTuple::_0 { + payload, + metadata, + elements, + } => { + buffer.extend((0 as Tag).to_ne_bytes()); + buffer.extend(std::iter::repeat(0).take(size_of::())); + while buffer.len() % ALIGN != 0 { + 
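+ // zero-pad up to the 8-byte ALIGN boundary so the metadata and element slices written
+ // next stay properly aligned for the zerocopy-based tuple readers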
buffer.push(0); + } + let metadata_s = buffer.len(); + buffer.extend(metadata.as_bytes()); + while buffer.len() % ALIGN != 0 { + buffer.push(0); + } + let elements_s = buffer.len(); + buffer.extend(elements.as_bytes()); + let elements_e = buffer.len(); + buffer[size_of::()..][..size_of::()].copy_from_slice( + VectorTupleHeader0 { + payload: *payload, + metadata_s, + elements_s, + elements_e, + } + .as_bytes(), + ); + } + VectorTuple::_1 { + payload, + pointer, + elements, + } => { + buffer.extend((1 as Tag).to_ne_bytes()); + buffer.extend(std::iter::repeat(0).take(size_of::())); + while buffer.len() % ALIGN != 0 { + buffer.push(0); + } + let elements_s = buffer.len(); + buffer.extend(elements.as_bytes()); + let elements_e = buffer.len(); + buffer[size_of::()..][..size_of::()].copy_from_slice( + VectorTupleHeader1 { + payload: *payload, + pointer: *pointer, + elements_s, + elements_e, + } + .as_bytes(), + ); + } + } + buffer + } +} + +#[derive(Clone)] +pub struct VectorTupleReader0<'a, V: Vector> { + header: &'a VectorTupleHeader0, + metadata: &'a V::Metadata, + elements: &'a [V::Element], +} + +impl Copy for VectorTupleReader0<'_, V> {} + +#[derive(Clone)] +pub struct VectorTupleReader1<'a, V: Vector> { + header: &'a VectorTupleHeader1, + elements: &'a [V::Element], +} + +impl Copy for VectorTupleReader1<'_, V> {} + +#[derive(Clone)] +pub enum VectorTupleReader<'a, V: Vector> { + _0(VectorTupleReader0<'a, V>), + _1(VectorTupleReader1<'a, V>), +} + +impl Copy for VectorTupleReader<'_, V> {} + +impl<'a, V: Vector> TupleReader<'a> for VectorTupleReader<'a, V> { + type Tuple = VectorTuple; + + fn deserialize_ref(source: &'a [u8]) -> Self { + let tag = Tag::from_ne_bytes(std::array::from_fn(|i| source[i])); + match tag { + 0 => { + let checker = RefChecker::new(source); + let header: &VectorTupleHeader0 = checker.prefix(size_of::()); + let metadata = checker.prefix(header.metadata_s); + let elements = checker.bytes(header.elements_s, header.elements_e); + Self::_0(VectorTupleReader0 { + header, + elements, + metadata, + }) + } + 1 => { + let checker = RefChecker::new(source); + let header: &VectorTupleHeader1 = checker.prefix(size_of::()); + let elements = checker.bytes(header.elements_s, header.elements_e); + Self::_1(VectorTupleReader1 { header, elements }) + } + _ => panic!("bad bytes"), + } + } +} + +impl<'a, V: Vector> VectorTupleReader<'a, V> { + pub fn payload(self) -> Option { + match self { + VectorTupleReader::_0(this) => this.header.payload, + VectorTupleReader::_1(this) => this.header.payload, + } + } + pub fn elements(self) -> &'a [::Element] { + match self { + VectorTupleReader::_0(this) => this.elements, + VectorTupleReader::_1(this) => this.elements, + } + } + pub fn metadata_or_pointer(self) -> Result { + match self { + VectorTupleReader::_0(this) => Ok(*this.metadata), + VectorTupleReader::_1(this) => Err(this.header.pointer), + } + } +} + +// height1tuple + +#[repr(C, align(8))] +#[derive(Debug, Clone, PartialEq, FromBytes, IntoBytes, Immutable, KnownLayout)] +struct H1TupleHeader0 { + mean: [IndexPointer; 32], + dis_u_2: [f32; 32], + factor_ppc: [f32; 32], + factor_ip: [f32; 32], + factor_err: [f32; 32], + first: [u32; 32], + len: u32, + _padding_0: [ZeroU8; 4], + elements_s: usize, + elements_e: usize, +} + +#[repr(C, align(8))] +#[derive(Debug, Clone, PartialEq, FromBytes, IntoBytes, Immutable, KnownLayout)] +struct H1TupleHeader1 { + elements_s: usize, + elements_e: usize, +} + +#[allow(clippy::large_enum_variant)] +#[derive(Debug, Clone, PartialEq)] +pub enum H1Tuple { + 
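+ // `_0` closes a block of up to 32 internal-level (h1) branches: per-branch metadata arrays
+ // plus the remaining packed 4-bit codes; `_1` carries overflow chunks of the packed codes
+ // emitted before the `_0` tuple when they do not all fit alongside it on one page
+ // (see `H1TapeWriter` and `read_h1_tape`).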
_0 { + mean: [IndexPointer; 32], + dis_u_2: [f32; 32], + factor_ppc: [f32; 32], + factor_ip: [f32; 32], + factor_err: [f32; 32], + first: [u32; 32], + len: u32, + elements: Vec<[u64; 2]>, + }, + _1 { + elements: Vec<[u64; 2]>, + }, +} + +impl H1Tuple { + pub fn fit_0(freespace: u16) -> Option { + let mut freespace = freespace as isize; + freespace -= size_of::() as isize; + freespace -= size_of::() as isize; + if freespace >= 0 { + Some(freespace as usize / size_of::<[u64; 2]>()) + } else { + None + } + } + pub fn fit_1(freespace: u16) -> Option { + let mut freespace = freespace as isize; + freespace -= size_of::() as isize; + freespace -= size_of::() as isize; + if freespace >= 0 { + Some(freespace as usize / size_of::<[u64; 2]>()) + } else { + None + } + } +} + +impl Tuple for H1Tuple { + type Reader<'a> = H1TupleReader<'a>; + + fn serialize(&self) -> Vec { + let mut buffer = Vec::::new(); + match self { + Self::_0 { + mean, + dis_u_2, + factor_ppc, + factor_ip, + factor_err, + first, + len, + elements, + } => { + buffer.extend((0 as Tag).to_ne_bytes()); + buffer.extend(std::iter::repeat(0).take(size_of::())); + while buffer.len() % ALIGN != 0 { + buffer.push(0); + } + let elements_s = buffer.len(); + buffer.extend(elements.as_bytes()); + let elements_e = buffer.len(); + buffer[size_of::()..][..size_of::()].copy_from_slice( + H1TupleHeader0 { + mean: *mean, + dis_u_2: *dis_u_2, + factor_ppc: *factor_ppc, + factor_ip: *factor_ip, + factor_err: *factor_err, + first: *first, + len: *len, + _padding_0: Default::default(), + elements_s, + elements_e, + } + .as_bytes(), + ); + } + Self::_1 { elements } => { + buffer.extend((1 as Tag).to_ne_bytes()); + buffer.extend(std::iter::repeat(0).take(size_of::())); + while buffer.len() % ALIGN != 0 { + buffer.push(0); + } + let elements_s = buffer.len(); + buffer.extend(elements.as_bytes()); + let elements_e = buffer.len(); + buffer[size_of::()..][..size_of::()].copy_from_slice( + H1TupleHeader1 { + elements_s, + elements_e, + } + .as_bytes(), + ); + } + } + buffer + } +} + +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum H1TupleReader<'a> { + _0(H1TupleReader0<'a>), + _1(H1TupleReader1<'a>), +} + +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct H1TupleReader0<'a> { + header: &'a H1TupleHeader0, + elements: &'a [[u64; 2]], +} + +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct H1TupleReader1<'a> { + header: &'a H1TupleHeader1, + elements: &'a [[u64; 2]], +} + +impl<'a> TupleReader<'a> for H1TupleReader<'a> { + type Tuple = H1Tuple; + + fn deserialize_ref(source: &'a [u8]) -> Self { + let tag = Tag::from_ne_bytes(std::array::from_fn(|i| source[i])); + match tag { + 0 => { + let checker = RefChecker::new(source); + let header: &H1TupleHeader0 = checker.prefix(size_of::()); + let elements = checker.bytes(header.elements_s, header.elements_e); + Self::_0(H1TupleReader0 { header, elements }) + } + 1 => { + let checker = RefChecker::new(source); + let header: &H1TupleHeader1 = checker.prefix(size_of::()); + let elements = checker.bytes(header.elements_s, header.elements_e); + Self::_1(H1TupleReader1 { header, elements }) + } + _ => panic!("bad bytes"), + } + } +} + +impl<'a> H1TupleReader0<'a> { + pub fn len(self) -> u32 { + self.header.len + } + pub fn mean(self) -> &'a [IndexPointer] { + &self.header.mean[..self.header.len as usize] + } + pub fn metadata(self) -> (&'a [f32; 32], &'a [f32; 32], &'a [f32; 32], &'a [f32; 32]) { + ( + &self.header.dis_u_2, + &self.header.factor_ppc, + &self.header.factor_ip, + &self.header.factor_err, + ) + } + pub 
fn first(self) -> &'a [u32] { + &self.header.first[..self.header.len as usize] + } + pub fn elements(&self) -> &'a [[u64; 2]] { + self.elements + } +} + +impl<'a> H1TupleReader1<'a> { + pub fn elements(&self) -> &'a [[u64; 2]] { + self.elements + } +} + +#[repr(C, align(8))] +#[derive(Debug, Clone, PartialEq, FromBytes, IntoBytes, Immutable, KnownLayout)] +struct JumpTupleHeader { + first: u32, + _padding_0: [ZeroU8; 4], +} + +#[derive(Debug, Clone, PartialEq)] +pub struct JumpTuple { + pub first: u32, +} + +impl Tuple for JumpTuple { + type Reader<'a> = JumpTupleReader<'a>; + + fn serialize(&self) -> Vec { + JumpTupleHeader { + first: self.first, + _padding_0: Default::default(), + } + .as_bytes() + .to_vec() + } +} + +impl WithWriter for JumpTuple { + type Writer<'a> = JumpTupleWriter<'a>; +} + +#[derive(Debug, Clone, Copy)] +pub struct JumpTupleReader<'a> { + header: &'a JumpTupleHeader, +} + +impl<'a> TupleReader<'a> for JumpTupleReader<'a> { + type Tuple = JumpTuple; + + fn deserialize_ref(source: &'a [u8]) -> Self { + let checker = RefChecker::new(source); + let header: &JumpTupleHeader = checker.prefix(0); + Self { header } + } +} + +impl JumpTupleReader<'_> { + pub fn first(self) -> u32 { + self.header.first + } +} + +#[derive(Debug)] +pub struct JumpTupleWriter<'a> { + header: &'a mut JumpTupleHeader, +} + +impl<'a> TupleWriter<'a> for JumpTupleWriter<'a> { + type Tuple = JumpTuple; + + fn deserialize_mut(source: &'a mut [u8]) -> Self { + let mut checker = MutChecker::new(source); + let header: &mut JumpTupleHeader = checker.prefix(0); + Self { header } + } +} + +impl JumpTupleWriter<'_> { + pub fn first(&mut self) -> &mut u32 { + &mut self.header.first + } +} + +#[repr(C, align(8))] +#[derive(Debug, Clone, PartialEq, FromBytes, IntoBytes, Immutable, KnownLayout)] +struct H0TupleHeader0 { + mean: IndexPointer, + dis_u_2: f32, + factor_ppc: f32, + factor_ip: f32, + factor_err: f32, + payload: Option, + elements_s: usize, + elements_e: usize, +} + +#[repr(C, align(8))] +#[derive(Debug, Clone, PartialEq, FromBytes, IntoBytes, Immutable, KnownLayout)] +struct H0TupleHeader1 { + mean: [IndexPointer; 32], + dis_u_2: [f32; 32], + factor_ppc: [f32; 32], + factor_ip: [f32; 32], + factor_err: [f32; 32], + payload: [Option; 32], + elements_s: usize, + elements_e: usize, +} + +#[repr(C, align(8))] +#[derive(Debug, Clone, PartialEq, FromBytes, IntoBytes, Immutable, KnownLayout)] +struct H0TupleHeader2 { + elements_s: usize, + elements_e: usize, +} + +#[allow(clippy::large_enum_variant)] +#[derive(Debug, Clone, PartialEq)] +pub enum H0Tuple { + _0 { + mean: IndexPointer, + dis_u_2: f32, + factor_ppc: f32, + factor_ip: f32, + factor_err: f32, + payload: Option, + elements: Vec, + }, + _1 { + mean: [IndexPointer; 32], + dis_u_2: [f32; 32], + factor_ppc: [f32; 32], + factor_ip: [f32; 32], + factor_err: [f32; 32], + payload: [Option; 32], + elements: Vec<[u64; 2]>, + }, + _2 { + elements: Vec<[u64; 2]>, + }, +} + +impl H0Tuple { + pub fn fit_1(freespace: u16) -> Option { + let mut freespace = freespace as isize; + freespace -= size_of::() as isize; + freespace -= size_of::() as isize; + if freespace >= 0 { + Some(freespace as usize / size_of::<[u64; 2]>()) + } else { + None + } + } + pub fn fit_2(freespace: u16) -> Option { + let mut freespace = freespace as isize; + freespace -= size_of::() as isize; + freespace -= size_of::() as isize; + if freespace >= 0 { + Some(freespace as usize / size_of::<[u64; 2]>()) + } else { + None + } + } +} + +impl Tuple for H0Tuple { + type Reader<'a> = 
H0TupleReader<'a>; + + fn serialize(&self) -> Vec { + let mut buffer = Vec::::new(); + match self { + H0Tuple::_0 { + mean, + dis_u_2, + factor_ppc, + factor_ip, + factor_err, + payload, + elements, + } => { + buffer.extend((0 as Tag).to_ne_bytes()); + buffer.extend(std::iter::repeat(0).take(size_of::())); + while buffer.len() % ALIGN != 0 { + buffer.push(0); + } + let elements_s = buffer.len(); + buffer.extend(elements.as_bytes()); + let elements_e = buffer.len(); + buffer[size_of::()..][..size_of::()].copy_from_slice( + H0TupleHeader0 { + mean: *mean, + dis_u_2: *dis_u_2, + factor_ppc: *factor_ppc, + factor_ip: *factor_ip, + factor_err: *factor_err, + payload: *payload, + elements_s, + elements_e, + } + .as_bytes(), + ); + } + H0Tuple::_1 { + mean, + dis_u_2, + factor_ppc, + factor_ip, + factor_err, + payload, + elements, + } => { + buffer.extend((1 as Tag).to_ne_bytes()); + buffer.extend(std::iter::repeat(0).take(size_of::())); + while buffer.len() % ALIGN != 0 { + buffer.push(0); + } + let elements_s = buffer.len(); + buffer.extend(elements.as_bytes()); + let elements_e = buffer.len(); + buffer[size_of::()..][..size_of::()].copy_from_slice( + H0TupleHeader1 { + mean: *mean, + dis_u_2: *dis_u_2, + factor_ppc: *factor_ppc, + factor_ip: *factor_ip, + factor_err: *factor_err, + payload: *payload, + elements_s, + elements_e, + } + .as_bytes(), + ); + } + Self::_2 { elements } => { + buffer.extend((2 as Tag).to_ne_bytes()); + buffer.extend(std::iter::repeat(0).take(size_of::())); + while buffer.len() % ALIGN != 0 { + buffer.push(0); + } + let elements_s = buffer.len(); + buffer.extend(elements.as_bytes()); + let elements_e = buffer.len(); + buffer[size_of::()..][..size_of::()].copy_from_slice( + H0TupleHeader2 { + elements_s, + elements_e, + } + .as_bytes(), + ); + } + } + buffer + } +} + +impl WithWriter for H0Tuple { + type Writer<'a> = H0TupleWriter<'a>; +} + +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum H0TupleReader<'a> { + _0(H0TupleReader0<'a>), + _1(H0TupleReader1<'a>), + _2(H0TupleReader2<'a>), +} + +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct H0TupleReader0<'a> { + header: &'a H0TupleHeader0, + elements: &'a [u64], +} + +impl<'a> H0TupleReader0<'a> { + pub fn mean(self) -> IndexPointer { + self.header.mean + } + pub fn code(self) -> (f32, f32, f32, f32, &'a [u64]) { + ( + self.header.dis_u_2, + self.header.factor_ppc, + self.header.factor_ip, + self.header.factor_err, + self.elements, + ) + } + pub fn payload(self) -> Option { + self.header.payload + } +} + +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct H0TupleReader1<'a> { + header: &'a H0TupleHeader1, + elements: &'a [[u64; 2]], +} + +impl<'a> H0TupleReader1<'a> { + pub fn mean(self) -> &'a [IndexPointer; 32] { + &self.header.mean + } + pub fn metadata(self) -> (&'a [f32; 32], &'a [f32; 32], &'a [f32; 32], &'a [f32; 32]) { + ( + &self.header.dis_u_2, + &self.header.factor_ppc, + &self.header.factor_ip, + &self.header.factor_err, + ) + } + pub fn elements(self) -> &'a [[u64; 2]] { + self.elements + } + pub fn payload(self) -> &'a [Option; 32] { + &self.header.payload + } +} + +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct H0TupleReader2<'a> { + header: &'a H0TupleHeader2, + elements: &'a [[u64; 2]], +} + +impl<'a> H0TupleReader2<'a> { + pub fn elements(self) -> &'a [[u64; 2]] { + self.elements + } +} + +impl<'a> TupleReader<'a> for H0TupleReader<'a> { + type Tuple = H0Tuple; + + fn deserialize_ref(source: &'a [u8]) -> Self { + let tag = Tag::from_ne_bytes(std::array::from_fn(|i| source[i])); + match 
tag { + 0 => { + let checker = RefChecker::new(source); + let header: &H0TupleHeader0 = checker.prefix(size_of::()); + let elements = checker.bytes(header.elements_s, header.elements_e); + Self::_0(H0TupleReader0 { header, elements }) + } + 1 => { + let checker = RefChecker::new(source); + let header: &H0TupleHeader1 = checker.prefix(size_of::()); + let elements = checker.bytes(header.elements_s, header.elements_e); + Self::_1(H0TupleReader1 { header, elements }) + } + 2 => { + let checker = RefChecker::new(source); + let header: &H0TupleHeader2 = checker.prefix(size_of::()); + let elements = checker.bytes(header.elements_s, header.elements_e); + Self::_2(H0TupleReader2 { header, elements }) + } + _ => panic!("bad bytes"), + } + } +} + +#[derive(Debug)] +pub enum H0TupleWriter<'a> { + _0(H0TupleWriter0<'a>), + _1(H0TupleWriter1<'a>), + #[allow(dead_code)] + _2(H0TupleWriter2<'a>), +} + +#[derive(Debug)] +pub struct H0TupleWriter0<'a> { + header: &'a mut H0TupleHeader0, + #[allow(dead_code)] + elements: &'a mut [u64], +} + +#[derive(Debug)] +pub struct H0TupleWriter1<'a> { + header: &'a mut H0TupleHeader1, + #[allow(dead_code)] + elements: &'a mut [[u64; 2]], +} + +#[derive(Debug)] +pub struct H0TupleWriter2<'a> { + #[allow(dead_code)] + header: &'a mut H0TupleHeader2, + #[allow(dead_code)] + elements: &'a mut [[u64; 2]], +} + +impl<'a> TupleWriter<'a> for H0TupleWriter<'a> { + type Tuple = H0Tuple; + + fn deserialize_mut(source: &'a mut [u8]) -> Self { + let tag = Tag::from_ne_bytes(std::array::from_fn(|i| source[i])); + match tag { + 0 => { + let mut checker = MutChecker::new(source); + let header: &mut H0TupleHeader0 = checker.prefix(size_of::()); + let elements = checker.bytes(header.elements_s, header.elements_e); + Self::_0(H0TupleWriter0 { header, elements }) + } + 1 => { + let mut checker = MutChecker::new(source); + let header: &mut H0TupleHeader1 = checker.prefix(size_of::()); + let elements = checker.bytes(header.elements_s, header.elements_e); + Self::_1(H0TupleWriter1 { header, elements }) + } + 2 => { + let mut checker = MutChecker::new(source); + let header: &mut H0TupleHeader2 = checker.prefix(size_of::()); + let elements = checker.bytes(header.elements_s, header.elements_e); + Self::_2(H0TupleWriter2 { header, elements }) + } + _ => panic!("bad bytes"), + } + } +} + +impl H0TupleWriter0<'_> { + pub fn payload(&mut self) -> &mut Option { + &mut self.header.payload + } +} + +impl H0TupleWriter1<'_> { + pub fn payload(&mut self) -> &mut [Option; 32] { + &mut self.header.payload + } +} + +pub const fn pointer_to_pair(pointer: IndexPointer) -> (u32, u16) { + let value = pointer.0; + (((value >> 16) & 0xffffffff) as u32, (value & 0xffff) as u16) +} + +pub const fn pair_to_pointer(pair: (u32, u16)) -> IndexPointer { + let mut value = 0; + value |= (pair.0 as u64) << 16; + value |= pair.1 as u64; + IndexPointer(value) +} + +#[allow(dead_code)] +const fn soundness_check(a: (u32, u16)) { + let b = pair_to_pointer(a); + let c = pointer_to_pair(b); + assert!(a.0 == c.0); + assert!(a.1 == c.1); +} + +const _: () = soundness_check((111, 222)); + +#[repr(transparent)] +#[derive( + Debug, + Default, + Clone, + Copy, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + IntoBytes, + FromBytes, + Immutable, + KnownLayout, +)] +pub struct ZeroU8(Option); + +#[repr(transparent)] +#[derive( + Debug, Default, Clone, Copy, PartialEq, Eq, Hash, IntoBytes, FromBytes, Immutable, KnownLayout, +)] +pub struct IndexPointer(pub u64); + +#[repr(transparent)] +#[derive( + Debug, + Clone, + Copy, + PartialEq, 
+ Eq, + PartialOrd, + Ord, + Hash, + IntoBytes, + FromBytes, + Immutable, + KnownLayout, +)] +pub struct Bool(u8); + +impl Bool { + pub const FALSE: Self = Self(0); + pub const TRUE: Self = Self(1); +} + +impl From for bool { + fn from(value: Bool) -> Self { + value != Bool::FALSE + } +} + +impl From for Bool { + fn from(value: bool) -> Self { + if value { Self::TRUE } else { Self::FALSE } + } +} + +pub struct RefChecker<'a> { + bytes: &'a [u8], +} + +impl<'a> RefChecker<'a> { + pub fn new(bytes: &'a [u8]) -> Self { + Self { bytes } + } + pub fn prefix( + &self, + s: usize, + ) -> &'a T { + let start = s; + let end = s + size_of::(); + let bytes = &self.bytes[start..end]; + FromBytes::ref_from_bytes(bytes).expect("bad bytes") + } + pub fn bytes( + &self, + s: usize, + e: usize, + ) -> &'a T { + let start = s; + let end = e; + let bytes = &self.bytes[start..end]; + FromBytes::ref_from_bytes(bytes).expect("bad bytes") + } +} + +pub struct MutChecker<'a> { + flag: Vec, + bytes: &'a mut [u8], +} + +impl<'a> MutChecker<'a> { + pub fn new(bytes: &'a mut [u8]) -> Self { + Self { + flag: vec![0u64; bytes.len().div_ceil(64)], + bytes, + } + } + pub fn prefix( + &mut self, + s: usize, + ) -> &'a mut T { + let start = s; + let end = s + size_of::(); + if !(start <= end && end <= self.bytes.len()) { + panic!("bad bytes"); + } + for i in start..end { + if (self.flag[i / 64] & (1 << (i % 64))) != 0 { + panic!("bad bytes"); + } else { + self.flag[i / 64] |= 1 << (i % 64); + } + } + let bytes = unsafe { + std::slice::from_raw_parts_mut(self.bytes.as_mut_ptr().add(start), end - start) + }; + FromBytes::mut_from_bytes(bytes).expect("bad bytes") + } + pub fn bytes( + &mut self, + s: usize, + e: usize, + ) -> &'a mut T { + let start = s; + let end = e; + if !(start <= end && end <= self.bytes.len()) { + panic!("bad bytes"); + } + for i in start..end { + if (self.flag[i / 64] & (1 << (i % 64))) != 0 { + panic!("bad bytes"); + } else { + self.flag[i / 64] |= 1 << (i % 64); + } + } + let bytes = unsafe { + std::slice::from_raw_parts_mut(self.bytes.as_mut_ptr().add(start), end - start) + }; + FromBytes::mut_from_bytes(bytes).expect("bad bytes") + } +} + +// this test only passes if `MIRIFLAGS="-Zmiri-tree-borrows"` is set +#[test] +fn aliasing_test() { + #[repr(C, align(8))] + #[derive(Debug, Clone, PartialEq, FromBytes, IntoBytes, Immutable, KnownLayout)] + struct ExampleHeader { + elements_s: usize, + elements_e: usize, + } + let serialized = { + let elements = (0u32..1111).collect::>(); + let mut buffer = Vec::::new(); + buffer.extend(std::iter::repeat(0).take(size_of::())); + while buffer.len() % ALIGN != 0 { + buffer.push(0); + } + let elements_s = buffer.len(); + buffer.extend(elements.as_bytes()); + let elements_e = buffer.len(); + buffer[..size_of::()].copy_from_slice( + ExampleHeader { + elements_s, + elements_e, + } + .as_bytes(), + ); + buffer + }; + let mut source = vec![0u64; serialized.len().next_multiple_of(8)]; + source.as_mut_bytes()[..serialized.len()].copy_from_slice(&serialized); + let deserialized = { + let mut checker = MutChecker::new(source.as_mut_bytes()); + let header: &mut ExampleHeader = checker.prefix(0); + let elements: &mut [u32] = checker.bytes(header.elements_s, header.elements_e); + (header, elements) + }; + assert_eq!( + deserialized.1, + (0u32..1111).collect::>().as_slice() + ); +} diff --git a/src/algorithm/vacuum.rs b/src/algorithm/vacuum.rs new file mode 100644 index 0000000..1736625 --- /dev/null +++ b/src/algorithm/vacuum.rs @@ -0,0 +1,311 @@ +use 
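// Not part of the patch: a minimal sketch of the invariant `MutChecker` above enforces.
// Every byte of a serialized tuple may be claimed by at most one typed view, so the
// `&mut` references it hands out can never alias; a second, overlapping claim panics
// before any aliasing reference is created. Assumes this sits next to the other tests
// in src/algorithm/tuples.rs, where the zerocopy traits are already in scope.
#[test]
#[should_panic(expected = "bad bytes")]
fn mut_checker_rejects_overlapping_claims() {
    let mut source = vec![0u64; 4];
    let mut checker = MutChecker::new(source.as_mut_bytes());
    let _first: &mut u64 = checker.prefix(0); // claims bytes [0, 8)
    let _second: &mut u64 = checker.prefix(4); // overlaps bytes [4, 12) and panics
}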
crate::algorithm::freepages; +use crate::algorithm::operator::Operator; +use crate::algorithm::tape::*; +use crate::algorithm::tuples::*; +use crate::algorithm::{Page, RelationWrite}; +use crate::utils::pipe::Pipe; +use simd::fast_scan::unpack; +use std::num::NonZeroU64; + +pub fn bulkdelete( + relation: impl RelationWrite, + delay: impl Fn(), + callback: impl Fn(NonZeroU64) -> bool, +) { + let meta_guard = relation.read(0); + let meta_tuple = meta_guard.get(1).unwrap().pipe(read_tuple::); + let height_of_root = meta_tuple.height_of_root(); + let root_first = meta_tuple.root_first(); + let vectors_first = meta_tuple.vectors_first(); + drop(meta_guard); + { + type State = Vec; + let mut state: State = vec![root_first]; + let step = |state: State| { + let mut results = Vec::new(); + for first in state { + let mut current = first; + while current != u32::MAX { + let h1_guard = relation.read(current); + for i in 1..=h1_guard.len() { + let h1_tuple = h1_guard + .get(i) + .expect("data corruption") + .pipe(read_tuple::); + match h1_tuple { + H1TupleReader::_0(h1_tuple) => { + for first in h1_tuple.first().iter().copied() { + results.push(first); + } + } + H1TupleReader::_1(_) => (), + } + } + current = h1_guard.get_opaque().next; + } + } + results + }; + for _ in (1..height_of_root).rev() { + state = step(state); + } + for first in state { + let jump_guard = relation.read(first); + let jump_tuple = jump_guard + .get(1) + .expect("data corruption") + .pipe(read_tuple::); + let first = jump_tuple.first(); + let mut current = first; + while current != u32::MAX { + delay(); + let read = relation.read(current); + let flag = 'flag: { + for i in 1..=read.len() { + let h0_tuple = read + .get(i) + .expect("data corruption") + .pipe(read_tuple::); + match h0_tuple { + H0TupleReader::_0(h0_tuple) => { + let p = h0_tuple.payload(); + if let Some(payload) = p { + if callback(payload) { + break 'flag true; + } + } + } + H0TupleReader::_1(h0_tuple) => { + let p = h0_tuple.payload(); + for j in 0..32 { + if let Some(payload) = p[j] { + if callback(payload) { + break 'flag true; + } + } + } + } + H0TupleReader::_2(_) => (), + } + } + false + }; + if flag { + drop(read); + let mut write = relation.write(current, false); + for i in 1..=write.len() { + let h0_tuple = write + .get_mut(i) + .expect("data corruption") + .pipe(write_tuple::); + match h0_tuple { + H0TupleWriter::_0(mut h0_tuple) => { + let p = h0_tuple.payload(); + if let Some(payload) = *p { + if callback(payload) { + *p = None; + } + } + } + H0TupleWriter::_1(mut h0_tuple) => { + let p = h0_tuple.payload(); + for j in 0..32 { + if let Some(payload) = p[j] { + if callback(payload) { + p[j] = None; + } + } + } + } + H0TupleWriter::_2(_) => (), + } + } + current = write.get_opaque().next; + } else { + current = read.get_opaque().next; + } + } + } + } + { + let first = vectors_first; + let mut current = first; + while current != u32::MAX { + delay(); + let read = relation.read(current); + let flag = 'flag: { + for i in 1..=read.len() { + if let Some(vector_bytes) = read.get(i) { + let vector_tuple = vector_bytes.pipe(read_tuple::>); + let p = vector_tuple.payload(); + if let Some(payload) = p { + if callback(payload) { + break 'flag true; + } + } + } + } + false + }; + if flag { + drop(read); + let mut write = relation.write(current, true); + for i in 1..=write.len() { + if let Some(vector_bytes) = write.get(i) { + let vector_tuple = vector_bytes.pipe(read_tuple::>); + let p = vector_tuple.payload(); + if let Some(payload) = p { + if callback(payload) { + 
write.free(i); + } + } + }; + } + current = write.get_opaque().next; + } else { + current = read.get_opaque().next; + } + } + } +} + +pub fn maintain(relation: impl RelationWrite + Clone, delay: impl Fn()) { + let meta_guard = relation.read(0); + let meta_tuple = meta_guard.get(1).unwrap().pipe(read_tuple::); + let dims = meta_tuple.dims(); + let height_of_root = meta_tuple.height_of_root(); + let root_first = meta_tuple.root_first(); + let freepage_first = meta_tuple.freepage_first(); + drop(meta_guard); + + let firsts = { + type State = Vec; + let mut state: State = vec![root_first]; + let step = |state: State| { + let mut results = Vec::new(); + for first in state { + let mut current = first; + while current != u32::MAX { + delay(); + let h1_guard = relation.read(current); + for i in 1..=h1_guard.len() { + let h1_tuple = h1_guard + .get(i) + .expect("data corruption") + .pipe(read_tuple::); + match h1_tuple { + H1TupleReader::_0(h1_tuple) => { + for first in h1_tuple.first().iter().copied() { + results.push(first); + } + } + H1TupleReader::_1(_) => (), + } + } + current = h1_guard.get_opaque().next; + } + } + results + }; + for _ in (1..height_of_root).rev() { + state = step(state); + } + state + }; + + for first in firsts { + let mut jump_guard = relation.write(first, false); + let mut jump_tuple = jump_guard + .get_mut(1) + .expect("data corruption") + .pipe(write_tuple::); + + let mut tape = H0Tape::<_, _>::create(|| { + if let Some(id) = freepages::fetch(relation.clone(), freepage_first) { + let mut write = relation.write(id, false); + write.clear(); + write + } else { + relation.extend(false) + } + }); + + let mut trace = Vec::new(); + + let first = *jump_tuple.first(); + let mut current = first; + let mut computing = None; + while current != u32::MAX { + delay(); + trace.push(current); + let h0_guard = relation.read(current); + for i in 1..=h0_guard.len() { + let h0_tuple = h0_guard + .get(i) + .expect("data corruption") + .pipe(read_tuple::); + match h0_tuple { + H0TupleReader::_0(h0_tuple) => { + if let Some(payload) = h0_tuple.payload() { + tape.push(H0BranchWriter { + mean: h0_tuple.mean(), + dis_u_2: h0_tuple.code().0, + factor_ppc: h0_tuple.code().1, + factor_ip: h0_tuple.code().2, + factor_err: h0_tuple.code().3, + signs: h0_tuple + .code() + .4 + .iter() + .flat_map(|x| { + std::array::from_fn::<_, 64, _>(|i| *x & (1 << i) != 0) + }) + .take(dims as _) + .collect::>(), + payload, + }); + } + } + H0TupleReader::_1(h0_tuple) => { + let computing = &mut computing.take().unwrap_or_else(Vec::new); + computing.extend_from_slice(h0_tuple.elements()); + let unpacked = unpack(computing); + for j in 0..32 { + if let Some(payload) = h0_tuple.payload()[j] { + tape.push(H0BranchWriter { + mean: h0_tuple.mean()[j], + dis_u_2: h0_tuple.metadata().0[j], + factor_ppc: h0_tuple.metadata().1[j], + factor_ip: h0_tuple.metadata().2[j], + factor_err: h0_tuple.metadata().3[j], + signs: unpacked[j] + .iter() + .flat_map(|&x| { + [x & 1 != 0, x & 2 != 0, x & 4 != 0, x & 8 != 0] + }) + .collect(), + payload, + }); + } + } + } + H0TupleReader::_2(h0_tuple) => { + let computing = computing.get_or_insert_with(Vec::new); + computing.extend_from_slice(h0_tuple.elements()); + } + } + } + current = h0_guard.get_opaque().next; + drop(h0_guard); + } + + let tape = tape.into_inner(); + let new = tape.first(); + drop(tape); + + *jump_tuple.first() = new; + drop(jump_guard); + + freepages::mark(relation.clone(), freepage_first, &trace); + } +} diff --git a/src/algorithm/vectors.rs b/src/algorithm/vectors.rs 
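// Not part of the patch: a sketch of the traversal skeleton that `bulkdelete` and
// `maintain` above both follow. Pages on one level form a singly linked list through
// the opaque area's `next` field, terminated by `u32::MAX`; this hypothetical helper
// just collects the block numbers of such a chain.
use crate::algorithm::{Page, RelationRead};

fn chain_blocks(relation: impl RelationRead, first: u32) -> Vec<u32> {
    let mut blocks = Vec::new();
    let mut current = first;
    while current != u32::MAX {
        blocks.push(current);
        let guard = relation.read(current);
        current = guard.get_opaque().next;
    }
    blocks
}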
new file mode 100644 index 0000000..d71499b --- /dev/null +++ b/src/algorithm/vectors.rs @@ -0,0 +1,133 @@ +use crate::algorithm::operator::*; +use crate::algorithm::tuples::*; +use crate::algorithm::{Page, PageGuard, RelationRead, RelationWrite}; +use crate::utils::pipe::Pipe; +use std::num::NonZeroU64; +use vector::VectorOwned; + +pub fn vector_access_1< + O: Operator, + A: Accessor1<::Element, ::Metadata>, +>( + relation: impl RelationRead, + mean: IndexPointer, + accessor: A, +) -> A::Output { + let mut cursor = Err(mean); + let mut result = accessor; + while let Err(mean) = cursor.map_err(pointer_to_pair) { + let vector_guard = relation.read(mean.0); + let vector_tuple = vector_guard + .get(mean.1) + .expect("data corruption") + .pipe(read_tuple::>); + if vector_tuple.payload().is_some() { + panic!("data corruption"); + } + result.push(vector_tuple.elements()); + cursor = vector_tuple.metadata_or_pointer(); + } + result.finish(cursor.expect("data corruption")) +} + +pub fn vector_access_0< + O: Operator, + A: Accessor1<::Element, ::Metadata>, +>( + relation: impl RelationRead, + mean: IndexPointer, + payload: NonZeroU64, + accessor: A, +) -> Option { + let mut cursor = Err(mean); + let mut result = accessor; + while let Err(mean) = cursor.map_err(pointer_to_pair) { + let vector_guard = relation.read(mean.0); + let vector_tuple = vector_guard + .get(mean.1)? + .pipe(read_tuple::>); + if vector_tuple.payload().is_none() { + panic!("data corruption"); + } + if vector_tuple.payload() != Some(payload) { + return None; + } + result.push(vector_tuple.elements()); + cursor = vector_tuple.metadata_or_pointer(); + } + Some(result.finish(cursor.ok()?)) +} + +pub fn vector_append( + relation: impl RelationWrite + Clone, + vectors_first: u32, + vector: ::Borrowed<'_>, + payload: NonZeroU64, +) -> IndexPointer { + fn append(relation: impl RelationWrite, first: u32, bytes: &[u8]) -> IndexPointer { + if let Some(mut write) = relation.search(bytes.len()) { + let i = write.alloc(bytes).unwrap(); + return pair_to_pointer((write.id(), i)); + } + assert!(first != u32::MAX); + let mut current = first; + loop { + let read = relation.read(current); + if read.freespace() as usize >= bytes.len() || read.get_opaque().next == u32::MAX { + drop(read); + let mut write = relation.write(current, true); + if let Some(i) = write.alloc(bytes) { + return pair_to_pointer((current, i)); + } + if write.get_opaque().next == u32::MAX { + let mut extend = relation.extend(true); + write.get_opaque_mut().next = extend.id(); + drop(write); + if let Some(i) = extend.alloc(bytes) { + let result = (extend.id(), i); + drop(extend); + let mut past = relation.write(first, true); + let skip = &mut past.get_opaque_mut().skip; + assert!(*skip != u32::MAX); + *skip = std::cmp::max(*skip, result.0); + return pair_to_pointer(result); + } else { + panic!("a tuple cannot even be fit in a fresh page"); + } + } + if current == first && write.get_opaque().skip != first { + current = write.get_opaque().skip; + } else { + current = write.get_opaque().next; + } + } else { + if current == first && read.get_opaque().skip != first { + current = read.get_opaque().skip; + } else { + current = read.get_opaque().next; + } + } + } + } + let (metadata, slices) = O::Vector::vector_split(vector); + let mut chain = Ok(metadata); + for i in (0..slices.len()).rev() { + chain = Err(append( + relation.clone(), + vectors_first, + &serialize::>(&match chain { + Ok(metadata) => VectorTuple::_0 { + elements: slices[i].to_vec(), + payload: Some(payload), + metadata, + 
}, + Err(pointer) => VectorTuple::_1 { + elements: slices[i].to_vec(), + payload: Some(payload), + pointer, + }, + }), + )); + } + chain.err().unwrap() +} diff --git a/src/vchordrq/gucs/executing.rs b/src/gucs/executing.rs similarity index 100% rename from src/vchordrq/gucs/executing.rs rename to src/gucs/executing.rs diff --git a/src/vchordrq/gucs/mod.rs b/src/gucs/mod.rs similarity index 100% rename from src/vchordrq/gucs/mod.rs rename to src/gucs/mod.rs diff --git a/src/vchordrq/gucs/prewarm.rs b/src/gucs/prewarm.rs similarity index 100% rename from src/vchordrq/gucs/prewarm.rs rename to src/gucs/prewarm.rs diff --git a/src/vchordrq/index/am.rs b/src/index/am.rs similarity index 67% rename from src/vchordrq/index/am.rs rename to src/index/am.rs index 3d19a49..5db07d2 100644 --- a/src/vchordrq/index/am.rs +++ b/src/index/am.rs @@ -1,12 +1,13 @@ +use crate::algorithm; +use crate::algorithm::build::{HeapRelation, Reporter}; +use crate::algorithm::operator::{Dot, L2, Op}; +use crate::algorithm::operator::{Operator, Vector}; +use crate::index::am_options::{Opfamily, Reloption}; +use crate::index::am_scan::Scanner; +use crate::index::utils::{ctid_to_pointer, pointer_to_ctid}; +use crate::index::{am_options, am_scan}; use crate::postgres::PostgresRelation; -use crate::vchordrq::algorithm; -use crate::vchordrq::algorithm::build::{HeapRelation, Reporter}; -use crate::vchordrq::algorithm::tuples::Vector; -use crate::vchordrq::index::am_options::{Opfamily, Reloption}; -use crate::vchordrq::index::am_scan::Scanner; -use crate::vchordrq::index::utils::{ctid_to_pointer, pointer_to_ctid}; -use crate::vchordrq::index::{am_options, am_scan}; -use crate::vchordrq::types::VectorKind; +use crate::types::{DistanceKind, VectorKind}; use half::f16; use pgrx::datum::Internal; use pgrx::pg_sys::Datum; @@ -167,17 +168,17 @@ pub unsafe extern "C" fn ambuild( index_info: *mut pgrx::pg_sys::IndexInfo, opfamily: Opfamily, } - impl HeapRelation for Heap { + impl HeapRelation for Heap { fn traverse(&self, progress: bool, callback: F) where - F: FnMut((NonZeroU64, V)), + F: FnMut((NonZeroU64, O::Vector)), { pub struct State<'a, F> { pub this: &'a Heap, pub callback: F, } #[pgrx::pg_guard] - unsafe extern "C" fn call( + unsafe extern "C" fn call( _index: pgrx::pg_sys::Relation, ctid: pgrx::pg_sys::ItemPointer, values: *mut Datum, @@ -185,14 +186,14 @@ pub unsafe extern "C" fn ambuild( _tuple_is_alive: bool, state: *mut core::ffi::c_void, ) where - F: FnMut((NonZeroU64, V)), + F: FnMut((NonZeroU64, O::Vector)), { let state = unsafe { &mut *state.cast::>() }; let opfamily = state.this.opfamily; let vector = unsafe { opfamily.datum_to_vector(*values.add(0), *is_null.add(0)) }; let pointer = unsafe { ctid_to_pointer(ctid.read()) }; if let Some(vector) = vector { - (state.callback)((pointer, V::from_owned(vector))); + (state.callback)((pointer, O::Vector::from_owned(vector))); } } let table_am = unsafe { &*(*self.heap).rd_tableam }; @@ -210,7 +211,7 @@ pub unsafe extern "C" fn ambuild( progress, 0, pgrx::pg_sys::InvalidBlockNumber, - Some(call::), + Some(call::), (&mut state) as *mut State as *mut _, std::ptr::null_mut(), ); @@ -243,21 +244,43 @@ pub unsafe extern "C" fn ambuild( }; let mut reporter = PgReporter {}; let index_relation = unsafe { PostgresRelation::new(index) }; - match opfamily.vector_kind() { - VectorKind::Vecf32 => algorithm::build::build::, Heap, _>( - vector_options, - vchordrq_options, - heap_relation.clone(), - index_relation.clone(), - reporter.clone(), - ), - VectorKind::Vecf16 => 
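// Not part of the patch: a self-contained sketch (hypothetical SketchTuple type, usize
// indexes standing in for IndexPointer) of how `vector_append` threads its chain. The
// slices are written back to front: the first tuple written carries the metadata, every
// later tuple points at the one written before it, and the value left over at the end is
// the head pointer that `vector_access_0`/`vector_access_1` later follow.
enum SketchTuple {
    Tail { elements: Vec<f32>, metadata: f32 },
    Link { elements: Vec<f32>, next: usize },
}

fn build_chain(slices: &[&[f32]], metadata: f32) -> (usize, Vec<SketchTuple>) {
    let mut store = Vec::new();
    // `Ok` still carries the metadata, `Err` carries the "pointer" of the tuple written
    // in the previous iteration -- the same convention the real code uses.
    let mut chain: Result<f32, usize> = Ok(metadata);
    for slice in slices.iter().rev() {
        let tuple = match chain {
            Ok(metadata) => SketchTuple::Tail { elements: slice.to_vec(), metadata },
            Err(next) => SketchTuple::Link { elements: slice.to_vec(), next },
        };
        store.push(tuple);
        chain = Err(store.len() - 1);
    }
    (chain.unwrap_err(), store) // head of the chain; assumes at least one slice
}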
algorithm::build::build::, Heap, _>( - vector_options, - vchordrq_options, - heap_relation.clone(), - index_relation.clone(), - reporter.clone(), - ), + match (opfamily.vector_kind(), opfamily.distance_kind()) { + (VectorKind::Vecf32, DistanceKind::L2) => { + algorithm::build::build::, L2>, Heap, _>( + vector_options, + vchordrq_options, + heap_relation.clone(), + index_relation.clone(), + reporter.clone(), + ) + } + (VectorKind::Vecf32, DistanceKind::Dot) => { + algorithm::build::build::, Dot>, Heap, _>( + vector_options, + vchordrq_options, + heap_relation.clone(), + index_relation.clone(), + reporter.clone(), + ) + } + (VectorKind::Vecf16, DistanceKind::L2) => { + algorithm::build::build::, L2>, Heap, _>( + vector_options, + vchordrq_options, + heap_relation.clone(), + index_relation.clone(), + reporter.clone(), + ) + } + (VectorKind::Vecf16, DistanceKind::Dot) => { + algorithm::build::build::, Dot>, Heap, _>( + vector_options, + vchordrq_options, + heap_relation.clone(), + index_relation.clone(), + reporter.clone(), + ) + } } if let Some(leader) = unsafe { VchordrqLeader::enter(heap, index, (*index_info).ii_Concurrent) } { @@ -290,35 +313,61 @@ pub unsafe extern "C" fn ambuild( let mut indtuples = 0; reporter.tuples_done(indtuples); let relation = unsafe { PostgresRelation::new(index) }; - match opfamily.vector_kind() { - VectorKind::Vecf32 => { - HeapRelation::>::traverse( + match (opfamily.vector_kind(), opfamily.distance_kind()) { + (VectorKind::Vecf32, DistanceKind::L2) => { + HeapRelation::, L2>>::traverse( + &heap_relation, + true, + |(pointer, vector)| { + algorithm::insert::insert::, L2>>( + relation.clone(), + pointer, + vector, + ); + indtuples += 1; + reporter.tuples_done(indtuples); + }, + ); + } + (VectorKind::Vecf32, DistanceKind::Dot) => { + HeapRelation::, Dot>>::traverse( &heap_relation, true, |(pointer, vector)| { - algorithm::insert::insert::>( + algorithm::insert::insert::, Dot>>( relation.clone(), pointer, vector, - opfamily.distance_kind(), - true, ); indtuples += 1; reporter.tuples_done(indtuples); }, ); } - VectorKind::Vecf16 => { - HeapRelation::>::traverse( + (VectorKind::Vecf16, DistanceKind::L2) => { + HeapRelation::, L2>>::traverse( &heap_relation, true, |(pointer, vector)| { - algorithm::insert::insert::>( + algorithm::insert::insert::, L2>>( + relation.clone(), + pointer, + vector, + ); + indtuples += 1; + reporter.tuples_done(indtuples); + }, + ); + } + (VectorKind::Vecf16, DistanceKind::Dot) => { + HeapRelation::, Dot>>::traverse( + &heap_relation, + true, + |(pointer, vector)| { + algorithm::insert::insert::, Dot>>( relation.clone(), pointer, vector, - opfamily.distance_kind(), - true, ); indtuples += 1; reporter.tuples_done(indtuples); @@ -327,6 +376,28 @@ pub unsafe extern "C" fn ambuild( } } } + let relation = unsafe { PostgresRelation::new(index) }; + let delay = || { + pgrx::check_for_interrupts!(); + }; + match (opfamily.vector_kind(), opfamily.distance_kind()) { + (VectorKind::Vecf32, DistanceKind::L2) => { + type O = Op, L2>; + algorithm::vacuum::maintain::(relation, delay); + } + (VectorKind::Vecf32, DistanceKind::Dot) => { + type O = Op, Dot>; + algorithm::vacuum::maintain::(relation, delay); + } + (VectorKind::Vecf16, DistanceKind::L2) => { + type O = Op, L2>; + algorithm::vacuum::maintain::(relation, delay); + } + (VectorKind::Vecf16, DistanceKind::Dot) => { + type O = Op, Dot>; + algorithm::vacuum::maintain::(relation, delay); + } + } unsafe { pgrx::pgbox::PgBox::::alloc0().into_pg() } } @@ -572,17 +643,17 @@ unsafe fn parallel_build( 
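// Not part of the patch: a hypothetical macro showing the dispatch shape the match arms
// above all share -- pick the concrete `Op<VectOwned<_>, _>` for the runtime
// (VectorKind, DistanceKind) pair and instantiate a generic body once per combination.
// The patch spells the four arms out by hand at every entry point.
macro_rules! dispatch {
    ($opfamily:expr, |$o:ident| $body:expr) => {
        match ($opfamily.vector_kind(), $opfamily.distance_kind()) {
            (VectorKind::Vecf32, DistanceKind::L2) => {
                type $o = Op<VectOwned<f32>, L2>;
                $body
            }
            (VectorKind::Vecf32, DistanceKind::Dot) => {
                type $o = Op<VectOwned<f32>, Dot>;
                $body
            }
            (VectorKind::Vecf16, DistanceKind::L2) => {
                type $o = Op<VectOwned<f16>, L2>;
                $body
            }
            (VectorKind::Vecf16, DistanceKind::Dot) => {
                type $o = Op<VectOwned<f16>, Dot>;
                $body
            }
        }
    };
}
// e.g. dispatch!(opfamily, |O| algorithm::vacuum::maintain::<O>(relation, delay))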
opfamily: Opfamily, scan: *mut pgrx::pg_sys::TableScanDescData, } - impl HeapRelation for Heap { + impl HeapRelation for Heap { fn traverse(&self, progress: bool, callback: F) where - F: FnMut((NonZeroU64, V)), + F: FnMut((NonZeroU64, O::Vector)), { pub struct State<'a, F> { pub this: &'a Heap, pub callback: F, } #[pgrx::pg_guard] - unsafe extern "C" fn call( + unsafe extern "C" fn call( _index: pgrx::pg_sys::Relation, ctid: pgrx::pg_sys::ItemPointer, values: *mut Datum, @@ -590,14 +661,14 @@ unsafe fn parallel_build( _tuple_is_alive: bool, state: *mut core::ffi::c_void, ) where - F: FnMut((NonZeroU64, V)), + F: FnMut((NonZeroU64, O::Vector)), { let state = unsafe { &mut *state.cast::>() }; let opfamily = state.this.opfamily; let vector = unsafe { opfamily.datum_to_vector(*values.add(0), *is_null.add(0)) }; let pointer = unsafe { ctid_to_pointer(ctid.read()) }; if let Some(vector) = vector { - (state.callback)((pointer, V::from_owned(vector))); + (state.callback)((pointer, O::Vector::from_owned(vector))); } } let table_am = unsafe { &*(*self.heap).rd_tableam }; @@ -615,7 +686,7 @@ unsafe fn parallel_build( progress, 0, pgrx::pg_sys::InvalidBlockNumber, - Some(call::), + Some(call::), (&mut state) as *mut State as *mut _, self.scan, ); @@ -638,52 +709,106 @@ unsafe fn parallel_build( opfamily, scan, }; - match opfamily.vector_kind() { - VectorKind::Vecf32 => { - HeapRelation::>::traverse(&heap_relation, true, |(pointer, vector)| { - algorithm::insert::insert::>( - index_relation.clone(), - pointer, - vector, - opfamily.distance_kind(), - true, - ); - unsafe { - let indtuples; - { - pgrx::pg_sys::SpinLockAcquire(&raw mut (*vchordrqshared).mutex); - (*vchordrqshared).indtuples += 1; - indtuples = (*vchordrqshared).indtuples; - pgrx::pg_sys::SpinLockRelease(&raw mut (*vchordrqshared).mutex); + match (opfamily.vector_kind(), opfamily.distance_kind()) { + (VectorKind::Vecf32, DistanceKind::L2) => { + HeapRelation::, L2>>::traverse( + &heap_relation, + true, + |(pointer, vector)| { + algorithm::insert::insert::, L2>>( + index_relation.clone(), + pointer, + vector, + ); + unsafe { + let indtuples; + { + pgrx::pg_sys::SpinLockAcquire(&raw mut (*vchordrqshared).mutex); + (*vchordrqshared).indtuples += 1; + indtuples = (*vchordrqshared).indtuples; + pgrx::pg_sys::SpinLockRelease(&raw mut (*vchordrqshared).mutex); + } + if let Some(reporter) = reporter.as_mut() { + reporter.tuples_done(indtuples); + } } - if let Some(reporter) = reporter.as_mut() { - reporter.tuples_done(indtuples); + }, + ); + } + (VectorKind::Vecf32, DistanceKind::Dot) => { + HeapRelation::, Dot>>::traverse( + &heap_relation, + true, + |(pointer, vector)| { + algorithm::insert::insert::, Dot>>( + index_relation.clone(), + pointer, + vector, + ); + unsafe { + let indtuples; + { + pgrx::pg_sys::SpinLockAcquire(&raw mut (*vchordrqshared).mutex); + (*vchordrqshared).indtuples += 1; + indtuples = (*vchordrqshared).indtuples; + pgrx::pg_sys::SpinLockRelease(&raw mut (*vchordrqshared).mutex); + } + if let Some(reporter) = reporter.as_mut() { + reporter.tuples_done(indtuples); + } } - } - }); + }, + ); } - VectorKind::Vecf16 => { - HeapRelation::>::traverse(&heap_relation, true, |(pointer, vector)| { - algorithm::insert::insert::>( - index_relation.clone(), - pointer, - vector, - opfamily.distance_kind(), - true, - ); - unsafe { - let indtuples; - { - pgrx::pg_sys::SpinLockAcquire(&raw mut (*vchordrqshared).mutex); - (*vchordrqshared).indtuples += 1; - indtuples = (*vchordrqshared).indtuples; - pgrx::pg_sys::SpinLockRelease(&raw mut 
(*vchordrqshared).mutex); + (VectorKind::Vecf16, DistanceKind::L2) => { + HeapRelation::, L2>>::traverse( + &heap_relation, + true, + |(pointer, vector)| { + algorithm::insert::insert::, L2>>( + index_relation.clone(), + pointer, + vector, + ); + unsafe { + let indtuples; + { + pgrx::pg_sys::SpinLockAcquire(&raw mut (*vchordrqshared).mutex); + (*vchordrqshared).indtuples += 1; + indtuples = (*vchordrqshared).indtuples; + pgrx::pg_sys::SpinLockRelease(&raw mut (*vchordrqshared).mutex); + } + if let Some(reporter) = reporter.as_mut() { + reporter.tuples_done(indtuples); + } } - if let Some(reporter) = reporter.as_mut() { - reporter.tuples_done(indtuples); + }, + ); + } + (VectorKind::Vecf16, DistanceKind::Dot) => { + HeapRelation::, Dot>>::traverse( + &heap_relation, + true, + |(pointer, vector)| { + algorithm::insert::insert::, Dot>>( + index_relation.clone(), + pointer, + vector, + ); + unsafe { + let indtuples; + { + pgrx::pg_sys::SpinLockAcquire(&raw mut (*vchordrqshared).mutex); + (*vchordrqshared).indtuples += 1; + indtuples = (*vchordrqshared).indtuples; + pgrx::pg_sys::SpinLockRelease(&raw mut (*vchordrqshared).mutex); + } + if let Some(reporter) = reporter.as_mut() { + reporter.tuples_done(indtuples); + } } - } - }); + }, + ); } } unsafe { @@ -714,21 +839,35 @@ pub unsafe extern "C" fn aminsert( let vector = unsafe { opfamily.datum_to_vector(*values.add(0), *is_null.add(0)) }; if let Some(vector) = vector { let pointer = ctid_to_pointer(unsafe { heap_tid.read() }); - match opfamily.vector_kind() { - VectorKind::Vecf32 => algorithm::insert::insert::>( - unsafe { PostgresRelation::new(index) }, - pointer, - VectOwned::::from_owned(vector), - opfamily.distance_kind(), - false, - ), - VectorKind::Vecf16 => algorithm::insert::insert::>( - unsafe { PostgresRelation::new(index) }, - pointer, - VectOwned::::from_owned(vector), - opfamily.distance_kind(), - false, - ), + match (opfamily.vector_kind(), opfamily.distance_kind()) { + (VectorKind::Vecf32, DistanceKind::L2) => { + algorithm::insert::insert::, L2>>( + unsafe { PostgresRelation::new(index) }, + pointer, + VectOwned::::from_owned(vector), + ) + } + (VectorKind::Vecf32, DistanceKind::Dot) => { + algorithm::insert::insert::, Dot>>( + unsafe { PostgresRelation::new(index) }, + pointer, + VectOwned::::from_owned(vector), + ) + } + (VectorKind::Vecf16, DistanceKind::L2) => { + algorithm::insert::insert::, L2>>( + unsafe { PostgresRelation::new(index) }, + pointer, + VectOwned::::from_owned(vector), + ) + } + (VectorKind::Vecf16, DistanceKind::Dot) => { + algorithm::insert::insert::, Dot>>( + unsafe { PostgresRelation::new(index) }, + pointer, + VectOwned::::from_owned(vector), + ) + } } } false @@ -750,21 +889,35 @@ pub unsafe extern "C" fn aminsert( let vector = unsafe { opfamily.datum_to_vector(*values.add(0), *is_null.add(0)) }; if let Some(vector) = vector { let pointer = ctid_to_pointer(unsafe { heap_tid.read() }); - match opfamily.vector_kind() { - VectorKind::Vecf32 => algorithm::insert::insert::>( - unsafe { PostgresRelation::new(index) }, - pointer, - VectOwned::::from_owned(vector), - opfamily.distance_kind(), - false, - ), - VectorKind::Vecf16 => algorithm::insert::insert::>( - unsafe { PostgresRelation::new(index) }, - pointer, - VectOwned::::from_owned(vector), - opfamily.distance_kind(), - false, - ), + match (opfamily.vector_kind(), opfamily.distance_kind()) { + (VectorKind::Vecf32, DistanceKind::L2) => { + algorithm::insert::insert::, L2>>( + unsafe { PostgresRelation::new(index) }, + pointer, + 
VectOwned::::from_owned(vector), + ) + } + (VectorKind::Vecf32, DistanceKind::Dot) => { + algorithm::insert::insert::, Dot>>( + unsafe { PostgresRelation::new(index) }, + pointer, + VectOwned::::from_owned(vector), + ) + } + (VectorKind::Vecf16, DistanceKind::L2) => { + algorithm::insert::insert::, L2>>( + unsafe { PostgresRelation::new(index) }, + pointer, + VectOwned::::from_owned(vector), + ) + } + (VectorKind::Vecf16, DistanceKind::Dot) => { + algorithm::insert::insert::, Dot>>( + unsafe { PostgresRelation::new(index) }, + pointer, + VectOwned::::from_owned(vector), + ) + } } } false @@ -893,29 +1046,58 @@ pub unsafe extern "C" fn ambulkdelete( let opfamily = unsafe { am_options::opfamily((*info).index) }; let callback = callback.unwrap(); let callback = |p: NonZeroU64| unsafe { callback(&mut pointer_to_ctid(p), callback_state) }; - match opfamily.vector_kind() { - VectorKind::Vecf32 => algorithm::vacuum::vacuum::>( - unsafe { PostgresRelation::new((*info).index) }, - || unsafe { - pgrx::pg_sys::vacuum_delay_point(); - }, - callback, - ), - VectorKind::Vecf16 => algorithm::vacuum::vacuum::>( - unsafe { PostgresRelation::new((*info).index) }, - || unsafe { - pgrx::pg_sys::vacuum_delay_point(); - }, - callback, - ), + let index = unsafe { PostgresRelation::new((*info).index) }; + let delay = || unsafe { + pgrx::pg_sys::vacuum_delay_point(); + }; + match (opfamily.vector_kind(), opfamily.distance_kind()) { + (VectorKind::Vecf32, DistanceKind::L2) => { + type O = Op, L2>; + algorithm::vacuum::bulkdelete::(index.clone(), delay, callback); + } + (VectorKind::Vecf32, DistanceKind::Dot) => { + type O = Op, Dot>; + algorithm::vacuum::bulkdelete::(index.clone(), delay, callback); + } + (VectorKind::Vecf16, DistanceKind::L2) => { + type O = Op, L2>; + algorithm::vacuum::bulkdelete::(index.clone(), delay, callback); + } + (VectorKind::Vecf16, DistanceKind::Dot) => { + type O = Op, Dot>; + algorithm::vacuum::bulkdelete::(index.clone(), delay, callback); + } } stats } #[pgrx::pg_guard] pub unsafe extern "C" fn amvacuumcleanup( - _info: *mut pgrx::pg_sys::IndexVacuumInfo, + info: *mut pgrx::pg_sys::IndexVacuumInfo, _stats: *mut pgrx::pg_sys::IndexBulkDeleteResult, ) -> *mut pgrx::pg_sys::IndexBulkDeleteResult { + let opfamily = unsafe { am_options::opfamily((*info).index) }; + let index = unsafe { PostgresRelation::new((*info).index) }; + let delay = || unsafe { + pgrx::pg_sys::vacuum_delay_point(); + }; + match (opfamily.vector_kind(), opfamily.distance_kind()) { + (VectorKind::Vecf32, DistanceKind::L2) => { + type O = Op, L2>; + algorithm::vacuum::maintain::(index, delay); + } + (VectorKind::Vecf32, DistanceKind::Dot) => { + type O = Op, Dot>; + algorithm::vacuum::maintain::(index, delay); + } + (VectorKind::Vecf16, DistanceKind::L2) => { + type O = Op, L2>; + algorithm::vacuum::maintain::(index, delay); + } + (VectorKind::Vecf16, DistanceKind::Dot) => { + type O = Op, Dot>; + algorithm::vacuum::maintain::(index, delay); + } + } std::ptr::null_mut() } diff --git a/src/vchordrq/index/am_options.rs b/src/index/am_options.rs similarity index 97% rename from src/vchordrq/index/am_options.rs rename to src/index/am_options.rs index 5c730ed..34c76a9 100644 --- a/src/vchordrq/index/am_options.rs +++ b/src/index/am_options.rs @@ -3,9 +3,9 @@ use crate::datatype::memory_pgvector_halfvec::PgvectorHalfvecOutput; use crate::datatype::memory_pgvector_vector::PgvectorVectorInput; use crate::datatype::memory_pgvector_vector::PgvectorVectorOutput; use crate::datatype::typmod::Typmod; -use 
crate::vchordrq::types::{BorrowedVector, OwnedVector}; -use crate::vchordrq::types::{DistanceKind, VectorKind}; -use crate::vchordrq::types::{VchordrqIndexingOptions, VectorOptions}; +use crate::types::{BorrowedVector, OwnedVector}; +use crate::types::{DistanceKind, VectorKind}; +use crate::types::{VchordrqIndexingOptions, VectorOptions}; use distance::Distance; use pgrx::datum::FromDatum; use pgrx::heap_tuple::PgHeapTuple; diff --git a/src/vchordrq/index/am_scan.rs b/src/index/am_scan.rs similarity index 64% rename from src/vchordrq/index/am_scan.rs rename to src/index/am_scan.rs index 6e2d30d..83e62f3 100644 --- a/src/vchordrq/index/am_scan.rs +++ b/src/index/am_scan.rs @@ -1,12 +1,14 @@ use super::am_options::Opfamily; +use crate::algorithm::operator::Vector; +use crate::algorithm::operator::{Dot, L2, Op}; +use crate::algorithm::scan::scan; +use crate::gucs::executing::epsilon; +use crate::gucs::executing::max_scan_tuples; +use crate::gucs::executing::probes; use crate::postgres::PostgresRelation; -use crate::vchordrq::algorithm::scan::scan; -use crate::vchordrq::algorithm::tuples::Vector; -use crate::vchordrq::gucs::executing::epsilon; -use crate::vchordrq::gucs::executing::max_scan_tuples; -use crate::vchordrq::gucs::executing::probes; -use crate::vchordrq::types::OwnedVector; -use crate::vchordrq::types::VectorKind; +use crate::types::DistanceKind; +use crate::types::OwnedVector; +use crate::types::VectorKind; use distance::Distance; use half::f16; use std::num::NonZeroU64; @@ -74,12 +76,11 @@ pub fn scan_next(scanner: &mut Scanner, relation: PostgresRelation) -> Option<(N } = scanner { if let Some((vector, opfamily)) = vector.as_ref() { - match opfamily.vector_kind() { - VectorKind::Vecf32 => { - let vbase = scan::>( + match (opfamily.vector_kind(), opfamily.distance_kind()) { + (VectorKind::Vecf32, DistanceKind::L2) => { + let vbase = scan::, L2>>( relation, VectOwned::::from_owned(vector.clone()), - opfamily.distance_kind(), probes(), epsilon(), ); @@ -94,11 +95,46 @@ pub fn scan_next(scanner: &mut Scanner, relation: PostgresRelation) -> Option<(N opfamily: *opfamily, }; } - VectorKind::Vecf16 => { - let vbase = scan::>( + (VectorKind::Vecf32, DistanceKind::Dot) => { + let vbase = scan::, Dot>>( + relation, + VectOwned::::from_owned(vector.clone()), + probes(), + epsilon(), + ); + *scanner = Scanner::Vbase { + vbase: if let Some(max_scan_tuples) = max_scan_tuples() { + Box::new(vbase.take(max_scan_tuples as usize)) + } else { + Box::new(vbase) + }, + threshold: *threshold, + recheck: *recheck, + opfamily: *opfamily, + }; + } + (VectorKind::Vecf16, DistanceKind::L2) => { + let vbase = scan::, L2>>( + relation, + VectOwned::::from_owned(vector.clone()), + probes(), + epsilon(), + ); + *scanner = Scanner::Vbase { + vbase: if let Some(max_scan_tuples) = max_scan_tuples() { + Box::new(vbase.take(max_scan_tuples as usize)) + } else { + Box::new(vbase) + }, + threshold: *threshold, + recheck: *recheck, + opfamily: *opfamily, + }; + } + (VectorKind::Vecf16, DistanceKind::Dot) => { + let vbase = scan::, Dot>>( relation, VectOwned::::from_owned(vector.clone()), - opfamily.distance_kind(), probes(), epsilon(), ); diff --git a/src/vchordrq/index/functions.rs b/src/index/functions.rs similarity index 58% rename from src/vchordrq/index/functions.rs rename to src/index/functions.rs index 32e6d03..1f3b4e2 100644 --- a/src/vchordrq/index/functions.rs +++ b/src/index/functions.rs @@ -1,7 +1,9 @@ use super::am_options; +use crate::algorithm::operator::{Dot, L2, Op}; +use 
crate::algorithm::prewarm::prewarm; use crate::postgres::PostgresRelation; -use crate::vchordrq::algorithm::prewarm::prewarm; -use crate::vchordrq::types::VectorKind; +use crate::types::DistanceKind; +use crate::types::VectorKind; use half::f16; use pgrx::pg_sys::Oid; use pgrx_catalog::{PgAm, PgClass}; @@ -23,9 +25,19 @@ fn _vchordrq_prewarm(indexrelid: Oid, height: i32) -> String { let index = unsafe { pgrx::pg_sys::index_open(indexrelid, pgrx::pg_sys::ShareLock as _) }; let relation = unsafe { PostgresRelation::new(index) }; let opfamily = unsafe { am_options::opfamily(index) }; - let message = match opfamily.vector_kind() { - VectorKind::Vecf32 => prewarm::>(relation, height), - VectorKind::Vecf16 => prewarm::>(relation, height), + let message = match (opfamily.vector_kind(), opfamily.distance_kind()) { + (VectorKind::Vecf32, DistanceKind::L2) => { + prewarm::, L2>>(relation, height) + } + (VectorKind::Vecf32, DistanceKind::Dot) => { + prewarm::, Dot>>(relation, height) + } + (VectorKind::Vecf16, DistanceKind::L2) => { + prewarm::, L2>>(relation, height) + } + (VectorKind::Vecf16, DistanceKind::Dot) => { + prewarm::, Dot>>(relation, height) + } }; unsafe { pgrx::pg_sys::index_close(index, pgrx::pg_sys::ShareLock as _); diff --git a/src/vchordrq/index/mod.rs b/src/index/mod.rs similarity index 100% rename from src/vchordrq/index/mod.rs rename to src/index/mod.rs diff --git a/src/vchordrq/index/opclass.rs b/src/index/opclass.rs similarity index 100% rename from src/vchordrq/index/opclass.rs rename to src/index/opclass.rs diff --git a/src/index/utils.rs b/src/index/utils.rs new file mode 100644 index 0000000..18234ac --- /dev/null +++ b/src/index/utils.rs @@ -0,0 +1,34 @@ +use std::num::NonZeroU64; + +pub const fn pointer_to_ctid(pointer: NonZeroU64) -> pgrx::pg_sys::ItemPointerData { + let value = pointer.get(); + pgrx::pg_sys::ItemPointerData { + ip_blkid: pgrx::pg_sys::BlockIdData { + bi_hi: ((value >> 32) & 0xffff) as u16, + bi_lo: ((value >> 16) & 0xffff) as u16, + }, + ip_posid: (value & 0xffff) as u16, + } +} + +pub const fn ctid_to_pointer(ctid: pgrx::pg_sys::ItemPointerData) -> NonZeroU64 { + let mut value = 0; + value |= (ctid.ip_blkid.bi_hi as u64) << 32; + value |= (ctid.ip_blkid.bi_lo as u64) << 16; + value |= ctid.ip_posid as u64; + NonZeroU64::new(value).expect("invalid pointer") +} + +#[allow(dead_code)] +const fn soundness_check(a: pgrx::pg_sys::ItemPointerData) { + let b = ctid_to_pointer(a); + let c = pointer_to_ctid(b); + assert!(a.ip_blkid.bi_hi == c.ip_blkid.bi_hi); + assert!(a.ip_blkid.bi_lo == c.ip_blkid.bi_lo); + assert!(a.ip_posid == c.ip_posid); +} + +const _: () = soundness_check(pgrx::pg_sys::ItemPointerData { + ip_blkid: pgrx::pg_sys::BlockIdData { bi_hi: 1, bi_lo: 2 }, + ip_posid: 3, +}); diff --git a/src/lib.rs b/src/lib.rs index 018388d..187e3cd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,15 +1,18 @@ +#![feature(vec_pop_if)] #![allow(clippy::collapsible_else_if)] #![allow(clippy::infallible_destructuring_match)] #![allow(clippy::too_many_arguments)] #![allow(clippy::type_complexity)] +mod algorithm; mod datatype; +mod gucs; +mod index; mod postgres; mod projection; +mod types; mod upgrade; mod utils; -mod vchordrq; -mod vchordrqfscan; pgrx::pg_module_magic!(); pgrx::extension_sql_file!("./sql/bootstrap.sql", bootstrap); @@ -21,8 +24,8 @@ unsafe extern "C" fn _PG_init() { pgrx::error!("vchord must be loaded via shared_preload_libraries."); } unsafe { - vchordrq::init(); - vchordrqfscan::init(); + index::init(); + gucs::init(); #[cfg(any(feature = "pg13", 
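// Not part of the patch: a worked example of the packing in src/index/utils.rs. The 48
// significant bits of a heap ctid (block-id hi/lo words plus the 1-based offset number)
// are packed into one NonZeroU64, so the payload stored in index tuples round-trips to
// exactly the same ItemPointerData.
use std::num::NonZeroU64;

fn pack(bi_hi: u16, bi_lo: u16, ip_posid: u16) -> NonZeroU64 {
    let value = ((bi_hi as u64) << 32) | ((bi_lo as u64) << 16) | ip_posid as u64;
    NonZeroU64::new(value).expect("a live heap tuple never packs to zero")
}

#[test]
fn ctid_packing_round_trips() {
    let pointer = pack(1, 2, 3);
    assert_eq!((pointer.get() >> 32) & 0xffff, 1);
    assert_eq!((pointer.get() >> 16) & 0xffff, 2);
    assert_eq!(pointer.get() & 0xffff, 3);
}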
feature = "pg14"))] pgrx::pg_sys::EmitWarningsOnPlaceholders(c"vchord".as_ptr()); diff --git a/src/postgres.rs b/src/postgres.rs index 6ff91e2..f68d0fa 100644 --- a/src/postgres.rs +++ b/src/postgres.rs @@ -1,4 +1,4 @@ -use algorithm::{Opaque, Page, PageGuard, RelationRead, RelationWrite}; +use crate::algorithm::{Opaque, Page, PageGuard, RelationRead, RelationWrite}; use std::mem::{MaybeUninit, offset_of}; use std::ops::{Deref, DerefMut}; use std::ptr::NonNull; @@ -43,12 +43,6 @@ impl PostgresPage { this } #[allow(dead_code)] - unsafe fn assume_init_mut(this: &mut MaybeUninit) -> &mut Self { - let this = unsafe { MaybeUninit::assume_init_mut(this) }; - assert_eq!(offset_of!(Self, opaque), this.header.pd_special as usize); - this - } - #[allow(dead_code)] fn clone_into_boxed(&self) -> Box { let mut result = Box::new_uninit(); unsafe { @@ -56,6 +50,23 @@ impl PostgresPage { result.assume_init() } } + #[allow(dead_code)] + fn reconstruct(&mut self, removes: &[u16]) { + let mut removes = removes.to_vec(); + removes.sort(); + removes.dedup(); + let n = removes.len(); + if n > 0 { + assert!(removes[n - 1] <= self.len()); + unsafe { + pgrx::pg_sys::PageIndexMultiDelete( + (self as *mut Self).cast(), + removes.as_ptr().cast_mut(), + removes.len() as _, + ); + } + } + } } impl Page for PostgresPage { @@ -142,25 +153,23 @@ impl Page for PostgresPage { pgrx::pg_sys::PageIndexTupleDeleteNoCompact((self as *mut Self).cast(), i); } } - fn reconstruct(&mut self, removes: &[u16]) { - let mut removes = removes.to_vec(); - removes.sort(); - removes.dedup(); - let n = removes.len(); - if n > 0 { - assert!(removes[n - 1] <= self.len()); - unsafe { - pgrx::pg_sys::PageIndexMultiDelete( - (self as *mut Self).cast(), - removes.as_ptr().cast_mut(), - removes.len() as _, - ); - } - } - } fn freespace(&self) -> u16 { unsafe { pgrx::pg_sys::PageGetFreeSpace((self as *const Self).cast_mut().cast()) as u16 } } + fn clear(&mut self) { + unsafe { + pgrx::pg_sys::PageInit( + (self as *mut PostgresPage as pgrx::pg_sys::Page).cast(), + pgrx::pg_sys::BLCKSZ as usize, + size_of::(), + ); + (&raw mut self.opaque).write(Opaque { + next: u32::MAX, + skip: u32::MAX, + }); + } + assert_eq!(offset_of!(Self, opaque), self.header.pd_special as usize); + } } const _: () = assert!(align_of::() == pgrx::pg_sys::MAXIMUM_ALIGNOF as usize); diff --git a/src/sql/finalize.sql b/src/sql/finalize.sql index c00aab1..7bc36b6 100644 --- a/src/sql/finalize.sql +++ b/src/sql/finalize.sql @@ -136,18 +136,10 @@ IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '_vchordrq_amhan CREATE FUNCTION vchordrq_prewarm(regclass, integer default 0) RETURNS TEXT STRICT LANGUAGE c AS 'MODULE_PATHNAME', '_vchordrq_prewarm_wrapper'; -CREATE FUNCTION vchordrqfscan_amhandler(internal) RETURNS index_am_handler -IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '_vchordrqfscan_amhandler_wrapper'; - -CREATE FUNCTION vchordrqfscan_prewarm(regclass, integer default 0) RETURNS TEXT -STRICT LANGUAGE c AS 'MODULE_PATHNAME', '_vchordrqfscan_prewarm_wrapper'; - -- List of access methods CREATE ACCESS METHOD vchordrq TYPE INDEX HANDLER vchordrq_amhandler; -CREATE ACCESS METHOD Vchordrqfscan TYPE INDEX HANDLER Vchordrqfscan_amhandler; - -- List of operator families CREATE OPERATOR FAMILY vector_l2_ops USING vchordrq; @@ -157,10 +149,6 @@ CREATE OPERATOR FAMILY halfvec_l2_ops USING vchordrq; CREATE OPERATOR FAMILY halfvec_ip_ops USING vchordrq; CREATE OPERATOR FAMILY halfvec_cosine_ops USING vchordrq; -CREATE OPERATOR FAMILY vector_l2_ops USING 
Vchordrqfscan; -CREATE OPERATOR FAMILY vector_ip_ops USING Vchordrqfscan; -CREATE OPERATOR FAMILY vector_cosine_ops USING Vchordrqfscan; - -- List of operator classes CREATE OPERATOR CLASS vector_l2_ops @@ -198,21 +186,3 @@ CREATE OPERATOR CLASS halfvec_cosine_ops OPERATOR 1 <=> (halfvec, halfvec) FOR ORDER BY float_ops, OPERATOR 2 <<=>> (halfvec, sphere_halfvec) FOR SEARCH, FUNCTION 1 _vchordrq_support_halfvec_cosine_ops(); - -CREATE OPERATOR CLASS vector_l2_ops - FOR TYPE vector USING Vchordrqfscan FAMILY vector_l2_ops AS - OPERATOR 1 <-> (vector, vector) FOR ORDER BY float_ops, - OPERATOR 2 <<->> (vector, sphere_vector) FOR SEARCH, - FUNCTION 1 _Vchordrqfscan_support_vector_l2_ops(); - -CREATE OPERATOR CLASS vector_ip_ops - FOR TYPE vector USING Vchordrqfscan FAMILY vector_ip_ops AS - OPERATOR 1 <#> (vector, vector) FOR ORDER BY float_ops, - OPERATOR 2 <<#>> (vector, sphere_vector) FOR SEARCH, - FUNCTION 1 _Vchordrqfscan_support_vector_ip_ops(); - -CREATE OPERATOR CLASS vector_cosine_ops - FOR TYPE vector USING Vchordrqfscan FAMILY vector_cosine_ops AS - OPERATOR 1 <=> (vector, vector) FOR ORDER BY float_ops, - OPERATOR 2 <<=>> (vector, sphere_vector) FOR SEARCH, - FUNCTION 1 _Vchordrqfscan_support_vector_cosine_ops(); diff --git a/src/vchordrq/types.rs b/src/types.rs similarity index 100% rename from src/vchordrq/types.rs rename to src/types.rs diff --git a/src/utils/k_means.rs b/src/utils/k_means.rs index 7b44a24..b1808c9 100644 --- a/src/utils/k_means.rs +++ b/src/utils/k_means.rs @@ -3,6 +3,7 @@ use half::f16; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; use simd::Floating; +use simd::fast_scan::{any_pack, padding_pack}; pub fn k_means( parallelism: &P, @@ -81,39 +82,64 @@ fn rabitq_index( samples: &[Vec], centroids: &[Vec], ) -> Vec { - let mut a0 = Vec::new(); - let mut a1 = Vec::new(); - let mut a2 = Vec::new(); - let mut a3 = Vec::new(); - let mut a4 = Vec::new(); - for vectors in centroids.chunks(32) { - use simd::fast_scan::pack; - let x = std::array::from_fn::<_, 32, _>(|i| { - if let Some(vector) = vectors.get(i) { - rabitq::block::code(dims as _, vector) - } else { - rabitq::block::dummy_code(dims as _) - } + struct Branch { + dis_u_2: f32, + factor_ppc: f32, + factor_ip: f32, + factor_err: f32, + signs: Vec, + } + let branches = { + let mut branches = Vec::new(); + for centroid in centroids { + let code = rabitq::code(dims as _, centroid); + branches.push(Branch { + dis_u_2: code.dis_u_2, + factor_ppc: code.factor_ppc, + factor_ip: code.factor_ip, + factor_err: code.factor_err, + signs: code.signs, + }); + } + branches + }; + struct Block { + dis_u_2: [f32; 32], + factor_ppc: [f32; 32], + factor_ip: [f32; 32], + factor_err: [f32; 32], + elements: Vec<[u64; 2]>, + } + impl Block { + fn code(&self) -> (&[f32; 32], &[f32; 32], &[f32; 32], &[f32; 32], &[[u64; 2]]) { + ( + &self.dis_u_2, + &self.factor_ppc, + &self.factor_ip, + &self.factor_err, + &self.elements, + ) + } + } + let mut blocks = Vec::new(); + for chunk in branches.chunks(32) { + blocks.push(Block { + dis_u_2: any_pack(chunk.iter().map(|x| x.dis_u_2)), + factor_ppc: any_pack(chunk.iter().map(|x| x.factor_ppc)), + factor_ip: any_pack(chunk.iter().map(|x| x.factor_ip)), + factor_err: any_pack(chunk.iter().map(|x| x.factor_err)), + elements: padding_pack(chunk.iter().map(|x| rabitq::pack_to_u4(&x.signs))), }); - a0.push(x.each_ref().map(|x| x.dis_u_2)); - a1.push(x.each_ref().map(|x| x.factor_ppc)); - a2.push(x.each_ref().map(|x| x.factor_ip)); - a3.push(x.each_ref().map(|x| x.factor_err)); - 
a4.push(pack(dims.div_ceil(4) as _, x.map(|x| x.signs)).collect::>()); } parallelism .rayon_into_par_iter(0..n) .map(|i| { use distance::Distance; - let lut = rabitq::block::fscan_preprocess(&samples[i]); + let lut = rabitq::block::preprocess(&samples[i]); let mut result = (Distance::INFINITY, 0); for block in 0..c.div_ceil(32) { - let lowerbound = rabitq::block::fscan_process_lowerbound_l2( - dims as _, - &lut, - (&a0[block], &a1[block], &a2[block], &a3[block], &a4[block]), - 1.9, - ); + let lowerbound = + rabitq::block::process_lowerbound_l2(&lut, blocks[block].code(), 1.9); for j in block * 32..std::cmp::min(block * 32 + 32, c) { if lowerbound[j - block * 32] < result.0 { let dis = diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 1b07dc6..85a84e0 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,2 +1,3 @@ pub mod k_means; pub mod parallelism; +pub mod pipe; diff --git a/src/utils/pipe.rs b/src/utils/pipe.rs new file mode 100644 index 0000000..18cfc37 --- /dev/null +++ b/src/utils/pipe.rs @@ -0,0 +1,14 @@ +pub trait Pipe { + fn pipe(self, f: impl FnOnce(Self) -> T) -> T + where + Self: Sized; +} + +impl Pipe for S { + fn pipe(self, f: impl FnOnce(Self) -> T) -> T + where + Self: Sized, + { + f(self) + } +} diff --git a/src/vchordrq/algorithm/insert.rs b/src/vchordrq/algorithm/insert.rs deleted file mode 100644 index f625dca..0000000 --- a/src/vchordrq/algorithm/insert.rs +++ /dev/null @@ -1,221 +0,0 @@ -use crate::vchordrq::algorithm::rabitq; -use crate::vchordrq::algorithm::tuples::*; -use crate::vchordrq::algorithm::vectors; -use crate::vchordrq::types::DistanceKind; -use algorithm::{Page, PageGuard, RelationWrite}; -use always_equal::AlwaysEqual; -use distance::Distance; -use std::cmp::Reverse; -use std::collections::BinaryHeap; -use std::num::NonZeroU64; -use vector::VectorBorrowed; - -pub fn insert( - relation: impl RelationWrite + Clone, - payload: NonZeroU64, - vector: V, - distance_kind: DistanceKind, - in_building: bool, -) { - let vector = vector.as_borrowed(); - let meta_guard = relation.read(0); - let meta_tuple = meta_guard - .get(1) - .map(rkyv::check_archived_root::) - .expect("data corruption") - .expect("data corruption"); - let dims = meta_tuple.dims; - assert_eq!(dims, vector.dims(), "invalid vector dimensions"); - let vector = V::random_projection(vector); - let vector = vector.as_borrowed(); - let is_residual = meta_tuple.is_residual; - let default_lut = if !is_residual { - Some(V::rabitq_fscan_preprocess(vector)) - } else { - None - }; - let h0_vector = { - let (metadata, slices) = V::vector_split(vector); - let mut chain = Err(metadata); - for i in (0..slices.len()).rev() { - let tuple = rkyv::to_bytes::<_, 8192>(&VectorTuple:: { - slice: slices[i].to_vec(), - payload: Some(payload), - chain, - }) - .unwrap(); - chain = Ok(append( - relation.clone(), - meta_tuple.vectors_first, - &tuple, - true, - true, - true, - )); - } - chain.ok().unwrap() - }; - let h0_payload = payload; - let mut list = { - let Some((_, original)) = vectors::vector_dist::( - relation.clone(), - vector, - meta_tuple.mean, - None, - None, - is_residual, - ) else { - panic!("data corruption") - }; - (meta_tuple.first, original) - }; - let make_list = |list: (u32, Option)| { - let mut results = Vec::new(); - { - let lut = if is_residual { - &V::rabitq_fscan_preprocess( - V::residual(vector, list.1.as_ref().map(|x| x.as_borrowed()).unwrap()) - .as_borrowed(), - ) - } else { - default_lut.as_ref().unwrap() - }; - let mut current = list.0; - while current != u32::MAX { - let h1_guard = 
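// Not part of the patch: a tiny usage sketch of the new `Pipe` helper above. It only
// threads a value through a free function in method position, which is how the algorithm
// code spells tuple reads, e.g. `meta_guard.get(1).unwrap().pipe(read_tuple::<MetaTuple>)`
// (type parameter reconstructed here for illustration).
use crate::utils::pipe::Pipe;

fn double(x: i32) -> i32 {
    x * 2
}

fn pipe_example() -> i32 {
    21i32.pipe(double) // identical to double(21)
}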
relation.read(current); - for i in 1..=h1_guard.len() { - let h1_tuple = h1_guard - .get(i) - .map(rkyv::check_archived_root::) - .expect("data corruption") - .expect("data corruption"); - let lowerbounds = rabitq::process_lowerbound( - distance_kind, - dims, - lut, - ( - h1_tuple.dis_u_2, - h1_tuple.factor_ppc, - h1_tuple.factor_ip, - h1_tuple.factor_err, - &h1_tuple.t, - ), - 1.9, - ); - results.push(( - Reverse(lowerbounds), - AlwaysEqual(h1_tuple.mean), - AlwaysEqual(h1_tuple.first), - )); - } - current = h1_guard.get_opaque().next; - } - } - let mut heap = BinaryHeap::from(results); - let mut cache = BinaryHeap::<(Reverse, _, _)>::new(); - { - while !heap.is_empty() && heap.peek().map(|x| x.0) > cache.peek().map(|x| x.0) { - let (_, AlwaysEqual(mean), AlwaysEqual(first)) = heap.pop().unwrap(); - let Some((Some(dis_u), original)) = vectors::vector_dist::( - relation.clone(), - vector, - mean, - None, - Some(distance_kind), - is_residual, - ) else { - panic!("data corruption") - }; - cache.push((Reverse(dis_u), AlwaysEqual(first), AlwaysEqual(original))); - } - let (_, AlwaysEqual(first), AlwaysEqual(mean)) = cache.pop().unwrap(); - (first, mean) - } - }; - for _ in (1..meta_tuple.height_of_root).rev() { - list = make_list(list); - } - let code = if is_residual { - V::rabitq_code( - dims, - V::residual(vector, list.1.as_ref().map(|x| x.as_borrowed()).unwrap()).as_borrowed(), - ) - } else { - V::rabitq_code(dims, vector) - }; - let tuple = rkyv::to_bytes::<_, 8192>(&Height0Tuple { - mean: h0_vector, - payload: h0_payload, - dis_u_2: code.dis_u_2, - factor_ppc: code.factor_ppc, - factor_ip: code.factor_ip, - factor_err: code.factor_err, - t: code.t(), - }) - .unwrap(); - append( - relation.clone(), - list.0, - &tuple, - false, - in_building, - in_building, - ); -} - -fn append( - relation: impl RelationWrite, - first: u32, - tuple: &[u8], - tracking_freespace: bool, - skipping_traversal: bool, - updating_skip: bool, -) -> (u32, u16) { - if tracking_freespace { - if let Some(mut write) = relation.search(tuple.len()) { - let i = write.alloc(tuple).unwrap(); - return (write.id(), i); - } - } - assert!(first != u32::MAX); - let mut current = first; - loop { - let read = relation.read(current); - if read.freespace() as usize >= tuple.len() || read.get_opaque().next == u32::MAX { - drop(read); - let mut write = relation.write(current, tracking_freespace); - if let Some(i) = write.alloc(tuple) { - return (current, i); - } - if write.get_opaque().next == u32::MAX { - let mut extend = relation.extend(tracking_freespace); - write.get_opaque_mut().next = extend.id(); - drop(write); - if let Some(i) = extend.alloc(tuple) { - let result = (extend.id(), i); - drop(extend); - if updating_skip { - let mut past = relation.write(first, tracking_freespace); - let skip = &mut past.get_opaque_mut().skip; - assert!(*skip != u32::MAX); - *skip = std::cmp::max(*skip, result.0); - } - return result; - } else { - panic!("a tuple cannot even be fit in a fresh page"); - } - } - if skipping_traversal && current == first && write.get_opaque().skip != first { - current = write.get_opaque().skip; - } else { - current = write.get_opaque().next; - } - } else { - if skipping_traversal && current == first && read.get_opaque().skip != first { - current = read.get_opaque().skip; - } else { - current = read.get_opaque().next; - } - } - } -} diff --git a/src/vchordrq/algorithm/mod.rs b/src/vchordrq/algorithm/mod.rs deleted file mode 100644 index 88239a8..0000000 --- a/src/vchordrq/algorithm/mod.rs +++ /dev/null @@ -1,8 +0,0 @@ 
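// Not part of the patch: a simplified sketch of the skip-pointer idea used by both the
// new `vector_append` and the old `append` being deleted here. The first page of a chain
// records, in its opaque area, the highest block number the chain has ever been extended
// to, so an append can jump near the tail instead of walking every page from `first`.
struct SketchOpaque {
    next: u32,
    skip: u32,
}

fn next_to_visit(current: u32, first: u32, opaque: &SketchOpaque) -> u32 {
    if current == first && opaque.skip != first {
        opaque.skip // jump over the bulk of the chain
    } else {
        opaque.next // ordinary link-by-link traversal
    }
}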
-pub mod build; -pub mod insert; -pub mod prewarm; -pub mod rabitq; -pub mod scan; -pub mod tuples; -pub mod vacuum; -pub mod vectors; diff --git a/src/vchordrq/algorithm/prewarm.rs b/src/vchordrq/algorithm/prewarm.rs deleted file mode 100644 index 6a7dc25..0000000 --- a/src/vchordrq/algorithm/prewarm.rs +++ /dev/null @@ -1,82 +0,0 @@ -use crate::vchordrq::algorithm::tuples::*; -use crate::vchordrq::algorithm::vectors; -use algorithm::{Page, RelationRead}; -use std::fmt::Write; - -pub fn prewarm(relation: impl RelationRead + Clone, height: i32) -> String { - let mut message = String::new(); - let meta_guard = relation.read(0); - let meta_tuple = meta_guard - .get(1) - .map(rkyv::check_archived_root::) - .expect("data corruption") - .expect("data corruption"); - writeln!(message, "height of root: {}", meta_tuple.height_of_root).unwrap(); - let prewarm_max_height = if height < 0 { 0 } else { height as u32 }; - if prewarm_max_height > meta_tuple.height_of_root { - return message; - } - let mut lists = { - let mut results = Vec::new(); - let counter = 1_usize; - { - vectors::vector_warm::(relation.clone(), meta_tuple.mean); - results.push(meta_tuple.first); - } - writeln!(message, "number of tuples: {}", results.len()).unwrap(); - writeln!(message, "number of pages: {}", counter).unwrap(); - results - }; - let mut make_lists = |lists| { - let mut counter = 0_usize; - let mut results = Vec::new(); - for list in lists { - let mut current = list; - while current != u32::MAX { - counter += 1; - pgrx::check_for_interrupts!(); - let h1_guard = relation.read(current); - for i in 1..=h1_guard.len() { - let h1_tuple = h1_guard - .get(i) - .map(rkyv::check_archived_root::) - .expect("data corruption") - .expect("data corruption"); - vectors::vector_warm::(relation.clone(), h1_tuple.mean); - results.push(h1_tuple.first); - } - current = h1_guard.get_opaque().next; - } - } - writeln!(message, "number of tuples: {}", results.len()).unwrap(); - writeln!(message, "number of pages: {}", counter).unwrap(); - results - }; - for _ in (std::cmp::max(1, prewarm_max_height)..meta_tuple.height_of_root).rev() { - lists = make_lists(lists); - } - if prewarm_max_height == 0 { - let mut counter = 0_usize; - let mut results = Vec::new(); - for list in lists { - let mut current = list; - while current != u32::MAX { - counter += 1; - pgrx::check_for_interrupts!(); - let h0_guard = relation.read(current); - for i in 1..=h0_guard.len() { - let _h0_tuple = h0_guard - .get(i) - .map(rkyv::check_archived_root::) - .expect("data corruption") - .expect("data corruption"); - results.push(()); - } - current = h0_guard.get_opaque().next; - } - } - writeln!(message, "number of tuples: {}", results.len()).unwrap(); - writeln!(message, "number of pages: {}", counter).unwrap(); - } - message -} diff --git a/src/vchordrq/algorithm/rabitq.rs b/src/vchordrq/algorithm/rabitq.rs deleted file mode 100644 index 4f406e1..0000000 --- a/src/vchordrq/algorithm/rabitq.rs +++ /dev/null @@ -1,21 +0,0 @@ -use crate::vchordrq::types::DistanceKind; -use distance::Distance; - -pub use rabitq::binary::Code; -pub use rabitq::binary::Lut; -pub use rabitq::binary::code; -pub use rabitq::binary::preprocess; -pub use rabitq::binary::{process_lowerbound_dot, process_lowerbound_l2}; - -pub fn process_lowerbound( - distance_kind: DistanceKind, - dims: u32, - lut: &Lut, - code: (f32, f32, f32, f32, &[u64]), - epsilon: f32, -) -> Distance { - match distance_kind { - DistanceKind::L2 => process_lowerbound_l2(dims, lut, code, epsilon), - DistanceKind::Dot => 
process_lowerbound_dot(dims, lut, code, epsilon), - } -} diff --git a/src/vchordrq/algorithm/scan.rs b/src/vchordrq/algorithm/scan.rs deleted file mode 100644 index e6915f0..0000000 --- a/src/vchordrq/algorithm/scan.rs +++ /dev/null @@ -1,189 +0,0 @@ -use crate::vchordrq::algorithm::rabitq; -use crate::vchordrq::algorithm::tuples::*; -use crate::vchordrq::algorithm::vectors; -use crate::vchordrq::types::DistanceKind; -use algorithm::{Page, RelationRead}; -use always_equal::AlwaysEqual; -use distance::Distance; -use std::cmp::Reverse; -use std::collections::BinaryHeap; -use std::num::NonZeroU64; -use vector::VectorBorrowed; - -pub fn scan( - relation: impl RelationRead + Clone, - vector: V, - distance_kind: DistanceKind, - probes: Vec, - epsilon: f32, -) -> impl Iterator { - let vector = vector.as_borrowed(); - let meta_guard = relation.read(0); - let meta_tuple = meta_guard - .get(1) - .map(rkyv::check_archived_root::) - .expect("data corruption") - .expect("data corruption"); - let dims = meta_tuple.dims; - let height_of_root = meta_tuple.height_of_root; - assert_eq!(dims, vector.dims(), "invalid vector dimensions"); - assert_eq!(height_of_root as usize, 1 + probes.len(), "invalid probes"); - let vector = V::random_projection(vector); - let is_residual = meta_tuple.is_residual; - let default_lut = if !is_residual { - Some(V::rabitq_fscan_preprocess(vector.as_borrowed())) - } else { - None - }; - let mut lists: Vec<_> = vec![{ - let Some((_, original)) = vectors::vector_dist::( - relation.clone(), - vector.as_borrowed(), - meta_tuple.mean, - None, - None, - is_residual, - ) else { - panic!("data corruption") - }; - (meta_tuple.first, original) - }]; - let make_lists = |lists: Vec<(u32, Option)>, probes| { - let mut results = Vec::new(); - for list in lists { - let lut = if is_residual { - &V::rabitq_fscan_preprocess( - V::residual( - vector.as_borrowed(), - list.1.as_ref().map(|x| x.as_borrowed()).unwrap(), - ) - .as_borrowed(), - ) - } else { - default_lut.as_ref().unwrap() - }; - let mut current = list.0; - while current != u32::MAX { - let h1_guard = relation.read(current); - for i in 1..=h1_guard.len() { - let h1_tuple = h1_guard - .get(i) - .map(rkyv::check_archived_root::) - .expect("data corruption") - .expect("data corruption"); - let lowerbounds = rabitq::process_lowerbound( - distance_kind, - dims, - lut, - ( - h1_tuple.dis_u_2, - h1_tuple.factor_ppc, - h1_tuple.factor_ip, - h1_tuple.factor_err, - &h1_tuple.t, - ), - epsilon, - ); - results.push(( - Reverse(lowerbounds), - AlwaysEqual(h1_tuple.mean), - AlwaysEqual(h1_tuple.first), - )); - } - current = h1_guard.get_opaque().next; - } - } - let mut heap = BinaryHeap::from(results); - let mut cache = BinaryHeap::<(Reverse, _, _)>::new(); - std::iter::from_fn(|| { - while !heap.is_empty() && heap.peek().map(|x| x.0) > cache.peek().map(|x| x.0) { - let (_, AlwaysEqual(mean), AlwaysEqual(first)) = heap.pop().unwrap(); - let Some((Some(dis_u), original)) = vectors::vector_dist::( - relation.clone(), - vector.as_borrowed(), - mean, - None, - Some(distance_kind), - is_residual, - ) else { - panic!("data corruption") - }; - cache.push((Reverse(dis_u), AlwaysEqual(first), AlwaysEqual(original))); - } - let (_, AlwaysEqual(first), AlwaysEqual(mean)) = cache.pop()?; - Some((first, mean)) - }) - .take(probes as usize) - .collect() - }; - for i in (1..meta_tuple.height_of_root).rev() { - lists = make_lists(lists, probes[i as usize - 1]); - } - drop(meta_guard); - { - let mut results = Vec::new(); - for list in lists { - let lut = if 
is_residual { - &V::rabitq_fscan_preprocess( - V::residual( - vector.as_borrowed(), - list.1.as_ref().map(|x| x.as_borrowed()).unwrap(), - ) - .as_borrowed(), - ) - } else { - default_lut.as_ref().unwrap() - }; - let mut current = list.0; - while current != u32::MAX { - let h0_guard = relation.read(current); - for i in 1..=h0_guard.len() { - let h0_tuple = h0_guard - .get(i) - .map(rkyv::check_archived_root::) - .expect("data corruption") - .expect("data corruption"); - let lowerbounds = rabitq::process_lowerbound( - distance_kind, - dims, - lut, - ( - h0_tuple.dis_u_2, - h0_tuple.factor_ppc, - h0_tuple.factor_ip, - h0_tuple.factor_err, - &h0_tuple.t, - ), - epsilon, - ); - results.push(( - Reverse(lowerbounds), - AlwaysEqual(h0_tuple.mean), - AlwaysEqual(h0_tuple.payload), - )); - } - current = h0_guard.get_opaque().next; - } - } - let mut heap = BinaryHeap::from(results); - let mut cache = BinaryHeap::<(Reverse, _)>::new(); - std::iter::from_fn(move || { - while !heap.is_empty() && heap.peek().map(|x| x.0) > cache.peek().map(|x| x.0) { - let (_, AlwaysEqual(mean), AlwaysEqual(pay_u)) = heap.pop().unwrap(); - let Some((Some(dis_u), _)) = vectors::vector_dist::( - relation.clone(), - vector.as_borrowed(), - mean, - Some(pay_u), - Some(distance_kind), - false, - ) else { - continue; - }; - cache.push((Reverse(dis_u), AlwaysEqual(pay_u))); - } - let (Reverse(dis_u), AlwaysEqual(pay_u)) = cache.pop()?; - Some((dis_u, pay_u)) - }) - } -} diff --git a/src/vchordrq/algorithm/tuples.rs b/src/vchordrq/algorithm/tuples.rs deleted file mode 100644 index 23fe664..0000000 --- a/src/vchordrq/algorithm/tuples.rs +++ /dev/null @@ -1,270 +0,0 @@ -use std::num::NonZeroU64; - -use super::rabitq::{self, Code, Lut}; -use crate::vchordrq::types::DistanceKind; -use crate::vchordrq::types::OwnedVector; -use half::f16; -use rkyv::{Archive, ArchiveUnsized, CheckBytes, Deserialize, Serialize}; -use simd::Floating; -use vector::VectorOwned; -use vector::vect::VectOwned; - -pub trait Vector: VectorOwned { - type Metadata: Copy - + Serialize< - rkyv::ser::serializers::CompositeSerializer< - rkyv::ser::serializers::AlignedSerializer, - rkyv::ser::serializers::FallbackScratch< - rkyv::ser::serializers::HeapScratch<8192>, - rkyv::ser::serializers::AllocScratch, - >, - rkyv::ser::serializers::SharedSerializeMap, - >, - > + for<'a> CheckBytes>; - type Element: Copy - + Serialize< - rkyv::ser::serializers::CompositeSerializer< - rkyv::ser::serializers::AlignedSerializer, - rkyv::ser::serializers::FallbackScratch< - rkyv::ser::serializers::HeapScratch<8192>, - rkyv::ser::serializers::AllocScratch, - >, - rkyv::ser::serializers::SharedSerializeMap, - >, - > + for<'a> CheckBytes> - + Archive; - - fn metadata_from_archived( - archived: &::Archived, - ) -> Self::Metadata; - - fn vector_split(vector: Self::Borrowed<'_>) -> (Self::Metadata, Vec<&[Self::Element]>); - fn vector_merge(metadata: Self::Metadata, slice: &[Self::Element]) -> Self; - fn from_owned(vector: OwnedVector) -> Self; - - type DistanceAccumulator; - fn distance_begin(distance_kind: DistanceKind) -> Self::DistanceAccumulator; - fn distance_next( - accumulator: &mut Self::DistanceAccumulator, - left: &[Self::Element], - right: &[Self::Element], - ); - fn distance_end( - accumulator: Self::DistanceAccumulator, - left: Self::Metadata, - right: Self::Metadata, - ) -> f32; - - fn random_projection(vector: Self::Borrowed<'_>) -> Self; - - fn residual(vector: Self::Borrowed<'_>, center: Self::Borrowed<'_>) -> Self; - - fn rabitq_fscan_preprocess(vector: 
Self::Borrowed<'_>) -> Lut; - - fn rabitq_code(dims: u32, vector: Self::Borrowed<'_>) -> Code; - - fn build_to_vecf32(vector: Self::Borrowed<'_>) -> Vec; - - fn build_from_vecf32(x: &[f32]) -> Self; -} - -impl Vector for VectOwned { - type Metadata = (); - - type Element = f32; - - fn metadata_from_archived(_: &::Archived) -> Self::Metadata {} - - fn vector_split(vector: Self::Borrowed<'_>) -> ((), Vec<&[f32]>) { - let vector = vector.slice(); - ((), match vector.len() { - 0..=960 => vec![vector], - 961..=1280 => vec![&vector[..640], &vector[640..]], - 1281.. => vector.chunks(1920).collect(), - }) - } - - fn vector_merge((): Self::Metadata, slice: &[Self::Element]) -> Self { - VectOwned::new(slice.to_vec()) - } - - fn from_owned(vector: OwnedVector) -> Self { - match vector { - OwnedVector::Vecf32(x) => x, - _ => unreachable!(), - } - } - - type DistanceAccumulator = (DistanceKind, f32); - fn distance_begin(distance_kind: DistanceKind) -> Self::DistanceAccumulator { - (distance_kind, 0.0) - } - fn distance_next( - accumulator: &mut Self::DistanceAccumulator, - left: &[Self::Element], - right: &[Self::Element], - ) { - match accumulator.0 { - DistanceKind::L2 => accumulator.1 += f32::reduce_sum_of_d2(left, right), - DistanceKind::Dot => accumulator.1 += -f32::reduce_sum_of_xy(left, right), - } - } - fn distance_end( - accumulator: Self::DistanceAccumulator, - (): Self::Metadata, - (): Self::Metadata, - ) -> f32 { - accumulator.1 - } - - fn random_projection(vector: Self::Borrowed<'_>) -> Self { - Self::new(crate::projection::project(vector.slice())) - } - - fn residual(vector: Self::Borrowed<'_>, center: Self::Borrowed<'_>) -> Self { - Self::new(Floating::vector_sub(vector.slice(), center.slice())) - } - - fn rabitq_fscan_preprocess(vector: Self::Borrowed<'_>) -> Lut { - rabitq::preprocess(vector.slice()) - } - - fn rabitq_code(dims: u32, vector: Self::Borrowed<'_>) -> Code { - rabitq::code(dims, vector.slice()) - } - - fn build_to_vecf32(vector: Self::Borrowed<'_>) -> Vec { - vector.slice().to_vec() - } - - fn build_from_vecf32(x: &[f32]) -> Self { - Self::new(x.to_vec()) - } -} - -impl Vector for VectOwned { - type Metadata = (); - - type Element = f16; - - fn metadata_from_archived(_: &::Archived) -> Self::Metadata {} - - fn vector_split(vector: Self::Borrowed<'_>) -> ((), Vec<&[f16]>) { - let vector = vector.slice(); - ((), match vector.len() { - 0..=1920 => vec![vector], - 1921..=2560 => vec![&vector[..1280], &vector[1280..]], - 2561.. 
=> vector.chunks(3840).collect(), - }) - } - - fn vector_merge((): Self::Metadata, slice: &[Self::Element]) -> Self { - VectOwned::new(slice.to_vec()) - } - - fn from_owned(vector: OwnedVector) -> Self { - match vector { - OwnedVector::Vecf16(x) => x, - _ => unreachable!(), - } - } - - type DistanceAccumulator = (DistanceKind, f32); - fn distance_begin(distance_kind: DistanceKind) -> Self::DistanceAccumulator { - (distance_kind, 0.0) - } - fn distance_next( - accumulator: &mut Self::DistanceAccumulator, - left: &[Self::Element], - right: &[Self::Element], - ) { - match accumulator.0 { - DistanceKind::L2 => accumulator.1 += f16::reduce_sum_of_d2(left, right), - DistanceKind::Dot => accumulator.1 += -f16::reduce_sum_of_xy(left, right), - } - } - fn distance_end( - accumulator: Self::DistanceAccumulator, - (): Self::Metadata, - (): Self::Metadata, - ) -> f32 { - accumulator.1 - } - - fn random_projection(vector: Self::Borrowed<'_>) -> Self { - Self::new(f16::vector_from_f32(&crate::projection::project( - &f16::vector_to_f32(vector.slice()), - ))) - } - - fn residual(vector: Self::Borrowed<'_>, center: Self::Borrowed<'_>) -> Self { - Self::new(Floating::vector_sub(vector.slice(), center.slice())) - } - - fn rabitq_fscan_preprocess(vector: Self::Borrowed<'_>) -> Lut { - rabitq::preprocess(&f16::vector_to_f32(vector.slice())) - } - - fn rabitq_code(dims: u32, vector: Self::Borrowed<'_>) -> Code { - rabitq::code(dims, &f16::vector_to_f32(vector.slice())) - } - - fn build_to_vecf32(vector: Self::Borrowed<'_>) -> Vec { - f16::vector_to_f32(vector.slice()) - } - - fn build_from_vecf32(x: &[f32]) -> Self { - Self::new(f16::vector_from_f32(x)) - } -} - -#[derive(Clone, PartialEq, Archive, Serialize, Deserialize)] -#[archive(check_bytes)] -pub struct MetaTuple { - pub dims: u32, - pub height_of_root: u32, - pub is_residual: bool, - pub vectors_first: u32, - // raw vector - pub mean: (u32, u16), - // for meta tuple, it's pointers to next level - pub first: u32, -} - -#[derive(Clone, PartialEq, Archive, Serialize, Deserialize)] -#[archive(check_bytes)] -pub struct VectorTuple { - pub slice: Vec, - pub payload: Option, - pub chain: Result<(u32, u16), V::Metadata>, -} - -#[derive(Clone, PartialEq, Archive, Serialize, Deserialize)] -#[archive(check_bytes)] -pub struct Height1Tuple { - // raw vector - pub mean: (u32, u16), - // for height 1 tuple, it's pointers to next level - pub first: u32, - // RaBitQ algorithm - pub dis_u_2: f32, - pub factor_ppc: f32, - pub factor_ip: f32, - pub factor_err: f32, - pub t: Vec, -} - -#[derive(Clone, PartialEq, Archive, Serialize, Deserialize)] -#[archive(check_bytes)] -pub struct Height0Tuple { - // raw vector - pub mean: (u32, u16), - // for height 0 tuple, it's pointers to heap relation - pub payload: NonZeroU64, - // RaBitQ algorithm - pub dis_u_2: f32, - pub factor_ppc: f32, - pub factor_ip: f32, - pub factor_err: f32, - pub t: Vec, -} diff --git a/src/vchordrq/algorithm/vacuum.rs b/src/vchordrq/algorithm/vacuum.rs deleted file mode 100644 index ee97ca6..0000000 --- a/src/vchordrq/algorithm/vacuum.rs +++ /dev/null @@ -1,112 +0,0 @@ -use crate::vchordrq::algorithm::tuples::*; -use algorithm::{Page, RelationWrite}; -use std::num::NonZeroU64; - -pub fn vacuum( - relation: impl RelationWrite, - delay: impl Fn(), - callback: impl Fn(NonZeroU64) -> bool, -) { - // step 1: vacuum height_0_tuple - { - let meta_guard = relation.read(0); - let meta_tuple = meta_guard - .get(1) - .map(rkyv::check_archived_root::) - .expect("data corruption") - .expect("data corruption"); - let 
mut firsts = vec![meta_tuple.first]; - let make_firsts = |firsts| { - let mut results = Vec::new(); - for first in firsts { - let mut current = first; - while current != u32::MAX { - let h1_guard = relation.read(current); - for i in 1..=h1_guard.len() { - let h1_tuple = h1_guard - .get(i) - .map(rkyv::check_archived_root::) - .expect("data corruption") - .expect("data corruption"); - results.push(h1_tuple.first); - } - current = h1_guard.get_opaque().next; - } - } - results - }; - for _ in (1..meta_tuple.height_of_root).rev() { - firsts = make_firsts(firsts); - } - for first in firsts { - let mut current = first; - while current != u32::MAX { - delay(); - let mut h0_guard = relation.write(current, false); - let mut reconstruct_removes = Vec::new(); - for i in 1..=h0_guard.len() { - let h0_tuple = h0_guard - .get(i) - .map(rkyv::check_archived_root::) - .expect("data corruption") - .expect("data corruption"); - if callback(h0_tuple.payload) { - reconstruct_removes.push(i); - } - } - h0_guard.reconstruct(&reconstruct_removes); - current = h0_guard.get_opaque().next; - } - } - } - // step 2: vacuum vector_tuple - { - let mut current = { - let meta_guard = relation.read(0); - let meta_tuple = meta_guard - .get(1) - .map(rkyv::check_archived_root::) - .expect("data corruption") - .expect("data corruption"); - meta_tuple.vectors_first - }; - while current != u32::MAX { - delay(); - let read = relation.read(current); - let flag = 'flag: { - for i in 1..=read.len() { - let Some(vector_tuple) = read.get(i) else { - continue; - }; - let vector_tuple = - unsafe { rkyv::archived_root::>(vector_tuple) }; - if let Some(payload) = vector_tuple.payload.as_ref().copied() { - if callback(payload) { - break 'flag true; - } - } - } - false - }; - if flag { - drop(read); - let mut write = relation.write(current, true); - for i in 1..=write.len() { - let Some(vector_tuple) = write.get(i) else { - continue; - }; - let vector_tuple = - unsafe { rkyv::archived_root::>(vector_tuple) }; - if let Some(payload) = vector_tuple.payload.as_ref().copied() { - if callback(payload) { - write.free(i); - } - } - } - current = write.get_opaque().next; - } else { - current = read.get_opaque().next; - } - } - } -} diff --git a/src/vchordrq/algorithm/vectors.rs b/src/vchordrq/algorithm/vectors.rs deleted file mode 100644 index 72e448a..0000000 --- a/src/vchordrq/algorithm/vectors.rs +++ /dev/null @@ -1,76 +0,0 @@ -use super::tuples::Vector; -use crate::vchordrq::algorithm::tuples::VectorTuple; -use crate::vchordrq::types::DistanceKind; -use algorithm::{Page, RelationRead}; -use distance::Distance; -use std::num::NonZeroU64; - -pub fn vector_dist( - relation: impl RelationRead, - vector: V::Borrowed<'_>, - mean: (u32, u16), - payload: Option, - for_distance: Option, - for_original: bool, -) -> Option<(Option, Option)> { - if for_distance.is_none() && !for_original && payload.is_none() { - return Some((None, None)); - } - let (left_metadata, slices) = V::vector_split(vector); - let mut cursor = Ok(mean); - let mut result = for_distance.map(|x| V::distance_begin(x)); - let mut original = Vec::new(); - for i in 0..slices.len() { - let Ok(mean) = cursor else { - // fails consistency check - return None; - }; - let vector_guard = relation.read(mean.0); - let Some(vector_tuple) = vector_guard.get(mean.1) else { - // fails consistency check - return None; - }; - let vector_tuple = unsafe { rkyv::archived_root::>(vector_tuple) }; - if vector_tuple.payload != payload { - // fails consistency check - return None; - } - if let Some(result) = 
result.as_mut() { - V::distance_next(result, slices[i], &vector_tuple.slice); - } - if for_original { - original.extend_from_slice(&vector_tuple.slice); - } - cursor = match &vector_tuple.chain { - rkyv::result::ArchivedResult::Ok(x) => Ok(*x), - rkyv::result::ArchivedResult::Err(x) => Err(V::metadata_from_archived(x)), - }; - } - let Err(right_metadata) = cursor else { - panic!("data corruption") - }; - Some(( - result.map(|r| Distance::from_f32(V::distance_end(r, left_metadata, right_metadata))), - for_original.then(|| V::vector_merge(right_metadata, &original)), - )) -} - -pub fn vector_warm(relation: impl RelationRead, mean: (u32, u16)) { - let mut cursor = Ok(mean); - while let Ok(mean) = cursor { - let vector_guard = relation.read(mean.0); - let Some(vector_tuple) = vector_guard.get(mean.1) else { - // fails consistency check - return; - }; - let vector_tuple = unsafe { rkyv::archived_root::>(vector_tuple) }; - if vector_tuple.payload.is_some() { - // fails consistency check - return; - } - cursor = match &vector_tuple.chain { - rkyv::result::ArchivedResult::Ok(x) => Ok(*x), - rkyv::result::ArchivedResult::Err(x) => Err(V::metadata_from_archived(x)), - }; - } -} diff --git a/src/vchordrq/index/utils.rs b/src/vchordrq/index/utils.rs deleted file mode 100644 index 726a597..0000000 --- a/src/vchordrq/index/utils.rs +++ /dev/null @@ -1,20 +0,0 @@ -use std::num::NonZeroU64; - -pub fn pointer_to_ctid(pointer: NonZeroU64) -> pgrx::pg_sys::ItemPointerData { - let value = pointer.get(); - pgrx::pg_sys::ItemPointerData { - ip_blkid: pgrx::pg_sys::BlockIdData { - bi_hi: ((value >> 32) & 0xffff) as u16, - bi_lo: ((value >> 16) & 0xffff) as u16, - }, - ip_posid: (value & 0xffff) as u16, - } -} - -pub fn ctid_to_pointer(ctid: pgrx::pg_sys::ItemPointerData) -> NonZeroU64 { - let mut value = 0; - value |= (ctid.ip_blkid.bi_hi as u64) << 32; - value |= (ctid.ip_blkid.bi_lo as u64) << 16; - value |= ctid.ip_posid as u64; - NonZeroU64::new(value).expect("invalid pointer") -} diff --git a/src/vchordrq/mod.rs b/src/vchordrq/mod.rs deleted file mode 100644 index c2ae945..0000000 --- a/src/vchordrq/mod.rs +++ /dev/null @@ -1,11 +0,0 @@ -mod algorithm; -mod gucs; -mod index; -mod types; - -pub unsafe fn init() { - unsafe { - index::init(); - gucs::init(); - } -} diff --git a/src/vchordrqfscan/algorithm/build.rs b/src/vchordrqfscan/algorithm/build.rs deleted file mode 100644 index e46d6a7..0000000 --- a/src/vchordrqfscan/algorithm/build.rs +++ /dev/null @@ -1,445 +0,0 @@ -use crate::vchordrqfscan::algorithm::rabitq; -use crate::vchordrqfscan::algorithm::tuples::*; -use crate::vchordrqfscan::index::am_options::Opfamily; -use crate::vchordrqfscan::types::DistanceKind; -use crate::vchordrqfscan::types::VchordrqfscanBuildOptions; -use crate::vchordrqfscan::types::VchordrqfscanExternalBuildOptions; -use crate::vchordrqfscan::types::VchordrqfscanIndexingOptions; -use crate::vchordrqfscan::types::VchordrqfscanInternalBuildOptions; -use crate::vchordrqfscan::types::VectorOptions; -use algorithm::{Page, PageGuard, RelationWrite}; -use rand::Rng; -use rkyv::ser::serializers::AllocSerializer; -use simd::Floating; -use std::marker::PhantomData; -use std::num::NonZeroU64; -use std::sync::Arc; - -pub trait HeapRelation { - fn traverse(&self, progress: bool, callback: F) - where - F: FnMut((NonZeroU64, Vec)); - fn opfamily(&self) -> Opfamily; -} - -pub trait Reporter { - fn tuples_total(&mut self, tuples_total: u64); -} - -pub fn build( - vector_options: VectorOptions, - vchordrqfscan_options: VchordrqfscanIndexingOptions, 
- heap_relation: T, - relation: impl RelationWrite, - mut reporter: R, -) { - let dims = vector_options.dims; - let is_residual = - vchordrqfscan_options.residual_quantization && vector_options.d == DistanceKind::L2; - let structures = match vchordrqfscan_options.build { - VchordrqfscanBuildOptions::External(external_build) => Structure::extern_build( - vector_options.clone(), - heap_relation.opfamily(), - external_build.clone(), - ), - VchordrqfscanBuildOptions::Internal(internal_build) => { - let mut tuples_total = 0_u64; - let samples = { - let mut rand = rand::thread_rng(); - let max_number_of_samples = internal_build - .lists - .last() - .unwrap() - .saturating_mul(internal_build.sampling_factor); - let mut samples = Vec::new(); - let mut number_of_samples = 0_u32; - heap_relation.traverse(false, |(_, vector)| { - assert_eq!(dims as usize, vector.len(), "invalid vector dimensions"); - if number_of_samples < max_number_of_samples { - samples.push(vector); - number_of_samples += 1; - } else { - let index = rand.gen_range(0..max_number_of_samples) as usize; - samples[index] = vector; - } - tuples_total += 1; - }); - samples - }; - reporter.tuples_total(tuples_total); - Structure::internal_build(vector_options.clone(), internal_build.clone(), samples) - } - }; - let mut meta = Tape::create(&relation, false); - assert_eq!(meta.first(), 0); - let mut vectors = Tape::create(&relation, true); - let mut pointer_of_means = Vec::>::new(); - for i in 0..structures.len() { - let mut level = Vec::new(); - for j in 0..structures[i].len() { - let pointer = vectors.push(&VectorTuple { - payload: None, - vector: structures[i].means[j].clone(), - }); - level.push(pointer); - } - pointer_of_means.push(level); - } - let mut pointer_of_firsts = Vec::>::new(); - for i in 0..structures.len() { - let mut level = Vec::new(); - for j in 0..structures[i].len() { - if i == 0 { - let tape = Tape::::create(&relation, false); - level.push(tape.first()); - } else { - let mut tape = Tape::::create(&relation, false); - let mut cache = Vec::new(); - let h2_mean = &structures[i].means[j]; - let h2_children = &structures[i].children[j]; - for child in h2_children.iter().copied() { - let h1_mean = &structures[i - 1].means[child as usize]; - let code = if is_residual { - rabitq::code(dims, &f32::vector_sub(h1_mean, h2_mean)) - } else { - rabitq::code(dims, h1_mean) - }; - cache.push((child, code)); - if cache.len() == 32 { - let group = std::mem::take(&mut cache); - let codes = std::array::from_fn(|k| group[k].1.clone()); - let packed = rabitq::pack_codes(dims, codes); - tape.push(&Height1Tuple { - mask: [true; 32], - mean: std::array::from_fn(|k| { - pointer_of_means[i - 1][group[k].0 as usize] - }), - first: std::array::from_fn(|k| { - pointer_of_firsts[i - 1][group[k].0 as usize] - }), - dis_u_2: packed.dis_u_2, - factor_ppc: packed.factor_ppc, - factor_ip: packed.factor_ip, - factor_err: packed.factor_err, - t: packed.t, - }); - } - } - if !cache.is_empty() { - let group = std::mem::take(&mut cache); - let codes = std::array::from_fn(|k| { - if k < group.len() { - group[k].1.clone() - } else { - rabitq::dummy_code(dims) - } - }); - let packed = rabitq::pack_codes(dims, codes); - tape.push(&Height1Tuple { - mask: std::array::from_fn(|k| k < group.len()), - mean: std::array::from_fn(|k| { - if k < group.len() { - pointer_of_means[i - 1][group[k].0 as usize] - } else { - Default::default() - } - }), - first: std::array::from_fn(|k| { - if k < group.len() { - pointer_of_firsts[i - 1][group[k].0 as usize] - } else { - 
Default::default() - } - }), - dis_u_2: packed.dis_u_2, - factor_ppc: packed.factor_ppc, - factor_ip: packed.factor_ip, - factor_err: packed.factor_err, - t: packed.t, - }); - } - level.push(tape.first()); - } - } - pointer_of_firsts.push(level); - } - meta.push(&MetaTuple { - dims, - height_of_root: structures.len() as u32, - is_residual, - vectors_first: vectors.first(), - mean: pointer_of_means.last().unwrap()[0], - first: pointer_of_firsts.last().unwrap()[0], - }); -} - -struct Structure { - means: Vec>, - children: Vec>, -} - -impl Structure { - fn len(&self) -> usize { - self.children.len() - } - fn internal_build( - vector_options: VectorOptions, - internal_build: VchordrqfscanInternalBuildOptions, - mut samples: Vec>, - ) -> Vec { - use std::iter::once; - for sample in samples.iter_mut() { - *sample = crate::projection::project(sample); - } - let mut result = Vec::::new(); - for w in internal_build.lists.iter().rev().copied().chain(once(1)) { - let means = crate::utils::parallelism::RayonParallelism::scoped( - internal_build.build_threads as _, - Arc::new(|| { - pgrx::check_for_interrupts!(); - }), - |parallelism| { - crate::utils::k_means::k_means( - parallelism, - w as usize, - vector_options.dims as usize, - if let Some(structure) = result.last() { - &structure.means - } else { - &samples - }, - internal_build.spherical_centroids, - 10, - ) - }, - ) - .expect("failed to create thread pool"); - if let Some(structure) = result.last() { - let mut children = vec![Vec::new(); means.len()]; - for i in 0..structure.len() as u32 { - let target = - crate::utils::k_means::k_means_lookup(&structure.means[i as usize], &means); - children[target].push(i); - } - let (means, children) = std::iter::zip(means, children) - .filter(|(_, x)| !x.is_empty()) - .unzip::<_, _, Vec<_>, Vec<_>>(); - result.push(Structure { means, children }); - } else { - let children = vec![Vec::new(); means.len()]; - result.push(Structure { means, children }); - } - } - result - } - fn extern_build( - vector_options: VectorOptions, - _opfamily: Opfamily, - external_build: VchordrqfscanExternalBuildOptions, - ) -> Vec { - use std::collections::BTreeMap; - let VchordrqfscanExternalBuildOptions { table } = external_build; - let mut parents = BTreeMap::new(); - let mut vectors = BTreeMap::new(); - pgrx::spi::Spi::connect(|client| { - use crate::datatype::memory_pgvector_vector::PgvectorVectorOutput; - use pgrx::pg_sys::panic::ErrorReportable; - use vector::VectorBorrowed; - let schema_query = "SELECT n.nspname::TEXT - FROM pg_catalog.pg_extension e - LEFT JOIN pg_catalog.pg_namespace n ON n.oid = e.extnamespace - WHERE e.extname = 'vector';"; - let pgvector_schema: String = client - .select(schema_query, None, None) - .unwrap_or_report() - .first() - .get_by_name("nspname") - .expect("external build: cannot get schema of pgvector") - .expect("external build: cannot get schema of pgvector"); - let dump_query = - format!("SELECT id, parent, vector::{pgvector_schema}.vector FROM {table};"); - let centroids = client.select(&dump_query, None, None).unwrap_or_report(); - for row in centroids { - let id: Option = row.get_by_name("id").unwrap(); - let parent: Option = row.get_by_name("parent").unwrap(); - let vector: Option = row.get_by_name("vector").unwrap(); - let id = id.expect("external build: id could not be NULL"); - let vector = vector.expect("external build: vector could not be NULL"); - let pop = parents.insert(id, parent); - if pop.is_some() { - pgrx::error!( - "external build: there are at least two lines have same id, 
id = {id}" - ); - } - if vector_options.dims != vector.as_borrowed().dims() { - pgrx::error!("external build: incorrect dimension, id = {id}"); - } - vectors.insert(id, crate::projection::project(vector.as_borrowed().slice())); - } - }); - if parents.len() >= 2 && parents.values().all(|x| x.is_none()) { - // if there are more than one vertexs and no edges, - // assume there is an implicit root - let n = parents.len(); - let mut result = Vec::new(); - result.push(Structure { - means: vectors.values().cloned().collect::>(), - children: vec![Vec::new(); n], - }); - result.push(Structure { - means: vec![{ - // compute the vector on root, without normalizing it - let mut sum = vec![0.0f32; vector_options.dims as _]; - for vector in vectors.values() { - f32::vector_add_inplace(&mut sum, vector); - } - f32::vector_mul_scalar_inplace(&mut sum, 1.0 / n as f32); - sum - }], - children: vec![(0..n as u32).collect()], - }); - return result; - } - let mut children = parents - .keys() - .map(|x| (*x, Vec::new())) - .collect::>(); - let mut root = None; - for (&id, &parent) in parents.iter() { - if let Some(parent) = parent { - if let Some(parent) = children.get_mut(&parent) { - parent.push(id); - } else { - pgrx::error!( - "external build: parent does not exist, id = {id}, parent = {parent}" - ); - } - } else { - if let Some(root) = root { - pgrx::error!("external build: two root, id = {root}, id = {id}"); - } else { - root = Some(id); - } - } - } - let Some(root) = root else { - pgrx::error!("external build: there are no root"); - }; - let mut heights = BTreeMap::<_, _>::new(); - fn dfs_for_heights( - heights: &mut BTreeMap>, - children: &BTreeMap>, - u: i32, - ) { - if heights.contains_key(&u) { - pgrx::error!("external build: detect a cycle, id = {u}"); - } - heights.insert(u, None); - let mut height = None; - for &v in children[&u].iter() { - dfs_for_heights(heights, children, v); - let new = heights[&v].unwrap() + 1; - if let Some(height) = height { - if height != new { - pgrx::error!("external build: two heights, id = {u}"); - } - } else { - height = Some(new); - } - } - if height.is_none() { - height = Some(1); - } - heights.insert(u, height); - } - dfs_for_heights(&mut heights, &children, root); - let heights = heights - .into_iter() - .map(|(k, v)| (k, v.expect("not a connected graph"))) - .collect::>(); - if !(1..=8).contains(&(heights[&root] - 1)) { - pgrx::error!( - "external build: unexpected tree height, height = {}", - heights[&root] - ); - } - let mut cursors = vec![0_u32; 1 + heights[&root] as usize]; - let mut labels = BTreeMap::new(); - for id in parents.keys().copied() { - let height = heights[&id]; - let cursor = cursors[height as usize]; - labels.insert(id, (height, cursor)); - cursors[height as usize] += 1; - } - fn extract( - height: u32, - labels: &BTreeMap, - vectors: &BTreeMap>, - children: &BTreeMap>, - ) -> (Vec>, Vec>) { - labels - .iter() - .filter(|(_, &(h, _))| h == height) - .map(|(id, _)| { - ( - vectors[id].clone(), - children[id].iter().map(|id| labels[id].1).collect(), - ) - }) - .unzip() - } - let mut result = Vec::new(); - for height in 1..=heights[&root] { - let (means, children) = extract(height, &labels, &vectors, &children); - result.push(Structure { means, children }); - } - result - } -} - -struct Tape<'a: 'b, 'b, T, R: 'b + RelationWrite> { - relation: &'a R, - head: R::WriteGuard<'b>, - first: u32, - tracking_freespace: bool, - _phantom: PhantomData T>, -} - -impl<'a: 'b, 'b, T, R: 'b + RelationWrite> Tape<'a, 'b, T, R> { - fn create(relation: &'a R, 
tracking_freespace: bool) -> Self { - let mut head = relation.extend(tracking_freespace); - head.get_opaque_mut().skip = head.id(); - let first = head.id(); - Self { - relation, - head, - first, - tracking_freespace, - _phantom: PhantomData, - } - } - fn first(&self) -> u32 { - self.first - } -} - -impl<'a: 'b, 'b, T, R: 'b + RelationWrite> Tape<'a, 'b, T, R> -where - T: rkyv::Serialize>, -{ - fn push(&mut self, x: &T) -> (u32, u16) { - let bytes = rkyv::to_bytes(x).expect("failed to serialize"); - if let Some(i) = self.head.alloc(&bytes) { - (self.head.id(), i) - } else { - let next = self.relation.extend(self.tracking_freespace); - self.head.get_opaque_mut().next = next.id(); - self.head = next; - if let Some(i) = self.head.alloc(&bytes) { - (self.head.id(), i) - } else { - panic!("tuple is too large to fit in a fresh page") - } - } - } -} diff --git a/src/vchordrqfscan/algorithm/insert.rs b/src/vchordrqfscan/algorithm/insert.rs deleted file mode 100644 index cb30577..0000000 --- a/src/vchordrqfscan/algorithm/insert.rs +++ /dev/null @@ -1,299 +0,0 @@ -use crate::vchordrqfscan::algorithm::rabitq; -use crate::vchordrqfscan::algorithm::rabitq::fscan_process_lowerbound; -use crate::vchordrqfscan::algorithm::tuples::*; -use crate::vchordrqfscan::types::DistanceKind; -use crate::vchordrqfscan::types::distance; -use algorithm::{Page, PageGuard, RelationWrite}; -use always_equal::AlwaysEqual; -use distance::Distance; -use simd::Floating; -use std::cmp::Reverse; -use std::collections::BinaryHeap; -use std::num::NonZeroU64; - -pub fn insert( - relation: impl RelationWrite + Clone, - payload: NonZeroU64, - vector: Vec, - distance_kind: DistanceKind, - in_building: bool, -) { - let meta_guard = relation.read(0); - let meta_tuple = meta_guard - .get(1) - .map(rkyv::check_archived_root::) - .expect("data corruption") - .expect("data corruption"); - let dims = meta_tuple.dims; - assert_eq!(dims as usize, vector.len(), "invalid vector dimensions"); - let vector = crate::projection::project(&vector); - let is_residual = meta_tuple.is_residual; - let default_lut = if !is_residual { - Some(rabitq::fscan_preprocess(&vector)) - } else { - None - }; - let h0_vector = { - let tuple = rkyv::to_bytes::<_, 8192>(&VectorTuple { - vector: vector.clone(), - payload: Some(payload), - }) - .unwrap(); - append( - relation.clone(), - meta_tuple.vectors_first, - &tuple, - true, - true, - true, - ) - }; - let h0_payload = payload; - let mut list = ( - meta_tuple.first, - if is_residual { - let vector_guard = relation.read(meta_tuple.mean.0); - let vector_tuple = vector_guard - .get(meta_tuple.mean.1) - .map(rkyv::check_archived_root::) - .expect("data corruption") - .expect("data corruption"); - Some(vector_tuple.vector.to_vec()) - } else { - None - }, - ); - let make_list = |list: (u32, Option>)| { - let mut results = Vec::new(); - { - let lut = if is_residual { - &rabitq::fscan_preprocess(&f32::vector_sub(&vector, list.1.as_ref().unwrap())) - } else { - default_lut.as_ref().unwrap() - }; - let mut current = list.0; - while current != u32::MAX { - let h1_guard = relation.read(current); - for i in 1..=h1_guard.len() { - let h1_tuple = h1_guard - .get(i) - .map(rkyv::check_archived_root::) - .expect("data corruption") - .expect("data corruption"); - let lowerbounds = fscan_process_lowerbound( - distance_kind, - dims, - lut, - ( - &h1_tuple.dis_u_2, - &h1_tuple.factor_ppc, - &h1_tuple.factor_ip, - &h1_tuple.factor_err, - &h1_tuple.t, - ), - 1.9, - ); - for j in 0..32 { - if h1_tuple.mask[j] { - results.push(( - 
Reverse(lowerbounds[j]), - AlwaysEqual(h1_tuple.mean[j]), - AlwaysEqual(h1_tuple.first[j]), - )); - } - } - } - current = h1_guard.get_opaque().next; - } - } - let mut heap = BinaryHeap::from(results); - let mut cache = BinaryHeap::<(Reverse, _, _)>::new(); - { - while !heap.is_empty() && heap.peek().map(|x| x.0) > cache.peek().map(|x| x.0) { - let (_, AlwaysEqual(mean), AlwaysEqual(first)) = heap.pop().unwrap(); - let vector_guard = relation.read(mean.0); - let vector_tuple = vector_guard - .get(mean.1) - .map(rkyv::check_archived_root::) - .expect("data corruption") - .expect("data corruption"); - let dis_u = distance(distance_kind, &vector, &vector_tuple.vector); - cache.push(( - Reverse(dis_u), - AlwaysEqual(first), - AlwaysEqual(if is_residual { - Some(vector_tuple.vector.to_vec()) - } else { - None - }), - )); - } - let (_, AlwaysEqual(first), AlwaysEqual(mean)) = cache.pop().unwrap(); - (first, mean) - } - }; - for _ in (1..meta_tuple.height_of_root).rev() { - list = make_list(list); - } - let code = if is_residual { - rabitq::code(dims, &f32::vector_sub(&vector, list.1.as_ref().unwrap())) - } else { - rabitq::code(dims, &vector) - }; - let dummy = rkyv::to_bytes::<_, 8192>(&Height0Tuple { - mask: [false; 32], - mean: [(0, 0); 32], - payload: [NonZeroU64::MIN; 32], - dis_u_2: [0.0f32; 32], - factor_ppc: [0.0f32; 32], - factor_ip: [0.0f32; 32], - factor_err: [0.0f32; 32], - t: vec![0; (dims.div_ceil(4) * 16) as usize], - }) - .unwrap(); - append_by_update( - relation.clone(), - list.0, - &dummy, - in_building, - in_building, - |bytes| { - let t = rkyv::check_archived_root::(bytes).expect("data corruption"); - t.mask.iter().any(|x| *x) - }, - |bytes| put(bytes, dims, &code, h0_vector, h0_payload), - ); -} - -fn append( - relation: impl RelationWrite, - first: u32, - tuple: &[u8], - tracking_freespace: bool, - skipping_traversal: bool, - updating_skip: bool, -) -> (u32, u16) { - if tracking_freespace { - if let Some(mut write) = relation.search(tuple.len()) { - let i = write.alloc(tuple).unwrap(); - return (write.id(), i); - } - } - assert!(first != u32::MAX); - let mut current = first; - loop { - let read = relation.read(current); - if read.freespace() as usize >= tuple.len() || read.get_opaque().next == u32::MAX { - drop(read); - let mut write = relation.write(current, tracking_freespace); - if let Some(i) = write.alloc(tuple) { - return (current, i); - } - if write.get_opaque().next == u32::MAX { - let mut extend = relation.extend(tracking_freespace); - write.get_opaque_mut().next = extend.id(); - drop(write); - if let Some(i) = extend.alloc(tuple) { - let result = (extend.id(), i); - drop(extend); - if updating_skip { - let mut past = relation.write(first, tracking_freespace); - let skip = &mut past.get_opaque_mut().skip; - assert!(*skip != u32::MAX); - *skip = std::cmp::max(*skip, result.0); - } - return result; - } else { - panic!("a tuple cannot even be fit in a fresh page"); - } - } - if skipping_traversal && current == first && write.get_opaque().skip != first { - current = write.get_opaque().skip; - } else { - current = write.get_opaque().next; - } - } else { - if skipping_traversal && current == first && read.get_opaque().skip != first { - current = read.get_opaque().skip; - } else { - current = read.get_opaque().next; - } - } - } -} - -fn append_by_update( - relation: impl RelationWrite, - first: u32, - tuple: &[u8], - skipping_traversal: bool, - updating_skip: bool, - can_update: impl Fn(&[u8]) -> bool, - mut update: impl FnMut(&mut [u8]) -> bool, -) { - assert!(first != 
u32::MAX); - let mut current = first; - loop { - let read = relation.read(current); - let flag = 'flag: { - for i in 1..=read.len() { - if can_update(read.get(i).expect("data corruption")) { - break 'flag true; - } - } - if read.freespace() as usize >= tuple.len() { - break 'flag true; - } - if read.get_opaque().next == u32::MAX { - break 'flag true; - } - false - }; - if flag { - drop(read); - let mut write = relation.write(current, false); - for i in 1..=write.len() { - if update(write.get_mut(i).expect("data corruption")) { - return; - } - } - if let Some(i) = write.alloc(tuple) { - if update(write.get_mut(i).expect("data corruption")) { - return; - } - panic!("an update fails on a fresh tuple"); - } - if write.get_opaque().next == u32::MAX { - let mut extend = relation.extend(false); - write.get_opaque_mut().next = extend.id(); - drop(write); - if let Some(i) = extend.alloc(tuple) { - if update(extend.get_mut(i).expect("data corruption")) { - let id = extend.id(); - drop(extend); - if updating_skip { - let mut past = relation.write(first, false); - let skip = &mut past.get_opaque_mut().skip; - assert!(*skip != u32::MAX); - *skip = std::cmp::max(*skip, id); - } - return; - } - panic!("an update fails on a fresh tuple"); - } - panic!("a tuple cannot even be fit in a fresh page"); - } - if skipping_traversal && current == first && write.get_opaque().skip != first { - current = write.get_opaque().skip; - } else { - current = write.get_opaque().next; - } - } else { - if skipping_traversal && current == first && read.get_opaque().skip != first { - current = read.get_opaque().skip; - } else { - current = read.get_opaque().next; - } - } - } -} diff --git a/src/vchordrqfscan/algorithm/mod.rs b/src/vchordrqfscan/algorithm/mod.rs deleted file mode 100644 index 448d919..0000000 --- a/src/vchordrqfscan/algorithm/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -pub mod build; -pub mod insert; -pub mod prewarm; -pub mod rabitq; -pub mod scan; -pub mod tuples; -pub mod vacuum; diff --git a/src/vchordrqfscan/algorithm/prewarm.rs b/src/vchordrqfscan/algorithm/prewarm.rs deleted file mode 100644 index c8500d4..0000000 --- a/src/vchordrqfscan/algorithm/prewarm.rs +++ /dev/null @@ -1,102 +0,0 @@ -use crate::vchordrqfscan::algorithm::tuples::*; -use algorithm::{Page, RelationRead}; -use std::fmt::Write; - -pub fn prewarm(relation: impl RelationRead, height: i32) -> String { - let mut message = String::new(); - let meta_guard = relation.read(0); - let meta_tuple = meta_guard - .get(1) - .map(rkyv::check_archived_root::) - .expect("data corruption") - .expect("data corruption"); - writeln!(message, "height of root: {}", meta_tuple.height_of_root).unwrap(); - let prewarm_max_height = if height < 0 { 0 } else { height as u32 }; - if prewarm_max_height > meta_tuple.height_of_root { - return message; - } - let mut lists = { - let mut results = Vec::new(); - let counter = 1_usize; - { - let vector_guard = relation.read(meta_tuple.mean.0); - let vector_tuple = vector_guard - .get(meta_tuple.mean.1) - .map(rkyv::check_archived_root::) - .expect("data corruption") - .expect("data corruption"); - let _ = vector_tuple; - results.push(meta_tuple.first); - } - writeln!(message, "number of tuples: {}", results.len()).unwrap(); - writeln!(message, "number of pages: {}", counter).unwrap(); - results - }; - let mut make_lists = |lists| { - let mut counter = 0_usize; - let mut results = Vec::new(); - for list in lists { - let mut current = list; - while current != u32::MAX { - counter += 1; - pgrx::check_for_interrupts!(); - let 
h1_guard = relation.read(current); - for i in 1..=h1_guard.len() { - let h1_tuple = h1_guard - .get(i) - .map(rkyv::check_archived_root::) - .expect("data corruption") - .expect("data corruption"); - for j in 0..32 { - if h1_tuple.mask[j] { - results.push(h1_tuple.first[j]); - let mean = h1_tuple.mean[j]; - let vector_guard = relation.read(mean.0); - let vector_tuple = vector_guard - .get(mean.1) - .map(rkyv::check_archived_root::) - .expect("data corruption") - .expect("data corruption"); - let _ = vector_tuple; - } - } - } - current = h1_guard.get_opaque().next; - } - } - writeln!(message, "number of tuples: {}", results.len()).unwrap(); - writeln!(message, "number of pages: {}", counter).unwrap(); - results - }; - for _ in (std::cmp::max(1, prewarm_max_height)..meta_tuple.height_of_root).rev() { - lists = make_lists(lists); - } - if prewarm_max_height == 0 { - let mut counter = 0_usize; - let mut results = Vec::new(); - for list in lists { - let mut current = list; - while current != u32::MAX { - counter += 1; - pgrx::check_for_interrupts!(); - let h0_guard = relation.read(current); - for i in 1..=h0_guard.len() { - let h0_tuple = h0_guard - .get(i) - .map(rkyv::check_archived_root::) - .expect("data corruption") - .expect("data corruption"); - for j in 0..32 { - if h0_tuple.mask[j] { - results.push(()); - } - } - } - current = h0_guard.get_opaque().next; - } - } - writeln!(message, "number of tuples: {}", results.len()).unwrap(); - writeln!(message, "number of pages: {}", counter).unwrap(); - } - message -} diff --git a/src/vchordrqfscan/algorithm/rabitq.rs b/src/vchordrqfscan/algorithm/rabitq.rs deleted file mode 100644 index 707d81c..0000000 --- a/src/vchordrqfscan/algorithm/rabitq.rs +++ /dev/null @@ -1,22 +0,0 @@ -use crate::vchordrqfscan::types::DistanceKind; -use distance::Distance; - -pub use rabitq::block::Code; -pub use rabitq::block::code; -pub use rabitq::block::dummy_code; -pub use rabitq::block::fscan_preprocess; -pub use rabitq::block::pack_codes; -pub use rabitq::block::{fscan_process_lowerbound_dot, fscan_process_lowerbound_l2}; - -pub fn fscan_process_lowerbound( - distance_kind: DistanceKind, - dims: u32, - lut: &(f32, f32, f32, f32, Vec), - code: (&[f32; 32], &[f32; 32], &[f32; 32], &[f32; 32], &[u8]), - epsilon: f32, -) -> [Distance; 32] { - match distance_kind { - DistanceKind::L2 => fscan_process_lowerbound_l2(dims, lut, code, epsilon), - DistanceKind::Dot => fscan_process_lowerbound_dot(dims, lut, code, epsilon), - } -} diff --git a/src/vchordrqfscan/algorithm/scan.rs b/src/vchordrqfscan/algorithm/scan.rs deleted file mode 100644 index 5546af9..0000000 --- a/src/vchordrqfscan/algorithm/scan.rs +++ /dev/null @@ -1,193 +0,0 @@ -use crate::vchordrqfscan::algorithm::rabitq; -use crate::vchordrqfscan::algorithm::rabitq::fscan_process_lowerbound; -use crate::vchordrqfscan::algorithm::tuples::*; -use crate::vchordrqfscan::types::DistanceKind; -use crate::vchordrqfscan::types::distance; -use algorithm::{Page, RelationWrite}; -use always_equal::AlwaysEqual; -use distance::Distance; -use simd::Floating; -use std::cmp::Reverse; -use std::collections::BinaryHeap; -use std::num::NonZeroU64; - -pub fn scan( - relation: impl RelationWrite + Clone, - vector: Vec, - distance_kind: DistanceKind, - probes: Vec, - epsilon: f32, -) -> impl Iterator { - let meta_guard = relation.read(0); - let meta_tuple = meta_guard - .get(1) - .map(rkyv::check_archived_root::) - .expect("data corruption") - .expect("data corruption"); - let dims = meta_tuple.dims; - let height_of_root = 
meta_tuple.height_of_root; - assert_eq!(dims as usize, vector.len(), "invalid vector dimensions"); - assert_eq!(height_of_root as usize, 1 + probes.len(), "invalid probes"); - let vector = crate::projection::project(&vector); - let is_residual = meta_tuple.is_residual; - let default_lut = if !is_residual { - Some(rabitq::fscan_preprocess(&vector)) - } else { - None - }; - let mut lists: Vec<_> = vec![( - meta_tuple.first, - if is_residual { - let vector_guard = relation.read(meta_tuple.mean.0); - let vector_tuple = vector_guard - .get(meta_tuple.mean.1) - .map(rkyv::check_archived_root::) - .expect("data corruption") - .expect("data corruption"); - Some(vector_tuple.vector.to_vec()) - } else { - None - }, - )]; - let make_lists = |lists: Vec<(u32, Option>)>, probes| { - let mut results = Vec::new(); - for list in lists { - let lut = if is_residual { - &rabitq::fscan_preprocess(&f32::vector_sub(&vector, list.1.as_ref().unwrap())) - } else { - default_lut.as_ref().unwrap() - }; - let mut current = list.0; - while current != u32::MAX { - let h1_guard = relation.read(current); - for i in 1..=h1_guard.len() { - let h1_tuple = h1_guard - .get(i) - .map(rkyv::check_archived_root::) - .expect("data corruption") - .expect("data corruption"); - let lowerbounds = fscan_process_lowerbound( - distance_kind, - dims, - lut, - ( - &h1_tuple.dis_u_2, - &h1_tuple.factor_ppc, - &h1_tuple.factor_ip, - &h1_tuple.factor_err, - &h1_tuple.t, - ), - epsilon, - ); - for j in 0..32 { - if h1_tuple.mask[j] { - results.push(( - Reverse(lowerbounds[j]), - AlwaysEqual(h1_tuple.mean[j]), - AlwaysEqual(h1_tuple.first[j]), - )); - } - } - } - current = h1_guard.get_opaque().next; - } - } - let mut heap = BinaryHeap::from(results); - let mut cache = BinaryHeap::<(Reverse, _, _)>::new(); - std::iter::from_fn(|| { - while !heap.is_empty() && heap.peek().map(|x| x.0) > cache.peek().map(|x| x.0) { - let (_, AlwaysEqual(mean), AlwaysEqual(first)) = heap.pop().unwrap(); - let vector_guard = relation.read(mean.0); - let vector_tuple = vector_guard - .get(mean.1) - .map(rkyv::check_archived_root::) - .expect("data corruption") - .expect("data corruption"); - let dis_u = distance(distance_kind, &vector, &vector_tuple.vector); - cache.push(( - Reverse(dis_u), - AlwaysEqual(first), - AlwaysEqual(if is_residual { - Some(vector_tuple.vector.to_vec()) - } else { - None - }), - )); - } - let (_, AlwaysEqual(first), AlwaysEqual(mean)) = cache.pop()?; - Some((first, mean)) - }) - .take(probes as usize) - .collect() - }; - for i in (1..meta_tuple.height_of_root).rev() { - lists = make_lists(lists, probes[i as usize - 1]); - } - drop(meta_guard); - { - let mut results = Vec::new(); - for list in lists { - let lut = if is_residual { - &rabitq::fscan_preprocess(&f32::vector_sub(&vector, list.1.as_ref().unwrap())) - } else { - default_lut.as_ref().unwrap() - }; - let mut current = list.0; - while current != u32::MAX { - let h0_guard = relation.read(current); - for i in 1..=h0_guard.len() { - let h0_tuple = h0_guard - .get(i) - .map(rkyv::check_archived_root::) - .expect("data corruption") - .expect("data corruption"); - let lowerbounds = fscan_process_lowerbound( - distance_kind, - dims, - lut, - ( - &h0_tuple.dis_u_2, - &h0_tuple.factor_ppc, - &h0_tuple.factor_ip, - &h0_tuple.factor_err, - &h0_tuple.t, - ), - epsilon, - ); - for j in 0..32 { - if h0_tuple.mask[j] { - results.push(( - Reverse(lowerbounds[j]), - AlwaysEqual(h0_tuple.mean[j]), - AlwaysEqual(h0_tuple.payload[j]), - )); - } - } - } - current = h0_guard.get_opaque().next; - } - } - 
let mut heap = BinaryHeap::from(results); - let mut cache = BinaryHeap::<(Reverse, _)>::new(); - std::iter::from_fn(move || { - while !heap.is_empty() && heap.peek().map(|x| x.0) > cache.peek().map(|x| x.0) { - let (_, AlwaysEqual(mean), AlwaysEqual(pay_u)) = heap.pop().unwrap(); - let vector_guard = relation.read(mean.0); - let Some(vector_tuple) = vector_guard.get(mean.1) else { - // fails consistency check - continue; - }; - let vector_tuple = rkyv::check_archived_root::(vector_tuple) - .expect("data corruption"); - if vector_tuple.payload != Some(pay_u) { - // fails consistency check - continue; - } - let dis_u = distance(distance_kind, &vector, &vector_tuple.vector); - cache.push((Reverse(dis_u), AlwaysEqual(pay_u))); - } - let (Reverse(dis_u), AlwaysEqual(pay_u)) = cache.pop()?; - Some((dis_u, pay_u)) - }) - } -} diff --git a/src/vchordrqfscan/algorithm/tuples.rs b/src/vchordrqfscan/algorithm/tuples.rs deleted file mode 100644 index 7d1c97a..0000000 --- a/src/vchordrqfscan/algorithm/tuples.rs +++ /dev/null @@ -1,140 +0,0 @@ -use std::num::NonZeroU64; - -use crate::vchordrqfscan::algorithm::rabitq; -use rkyv::{Archive, Deserialize, Serialize}; - -#[derive(Clone, PartialEq, Archive, Serialize, Deserialize)] -#[archive(check_bytes)] -pub struct MetaTuple { - pub dims: u32, - pub height_of_root: u32, - pub is_residual: bool, - pub vectors_first: u32, - // raw vector - pub mean: (u32, u16), - // for meta tuple, it's pointers to next level - pub first: u32, -} - -#[derive(Clone, PartialEq, Archive, Serialize, Deserialize)] -#[archive(check_bytes)] -pub struct VectorTuple { - pub vector: Vec, - // this field is saved only for vacuum - pub payload: Option, -} - -#[derive(Clone, PartialEq, Archive, Serialize, Deserialize)] -#[archive(check_bytes)] -pub struct Height1Tuple { - pub mask: [bool; 32], - // raw vector - pub mean: [(u32, u16); 32], - // for height 1 tuple, it's pointers to next level - pub first: [u32; 32], - // RaBitQ algorithm - pub dis_u_2: [f32; 32], - pub factor_ppc: [f32; 32], - pub factor_ip: [f32; 32], - pub factor_err: [f32; 32], - pub t: Vec, -} - -#[derive(Clone, PartialEq, Archive, Serialize, Deserialize)] -#[archive(check_bytes)] -pub struct Height0Tuple { - pub mask: [bool; 32], - // raw vector - pub mean: [(u32, u16); 32], - // for height 0 tuple, it's pointers to heap relation - pub payload: [NonZeroU64; 32], - // RaBitQ algorithm - pub dis_u_2: [f32; 32], - pub factor_ppc: [f32; 32], - pub factor_ip: [f32; 32], - pub factor_err: [f32; 32], - pub t: Vec, -} - -pub fn put( - bytes: &mut [u8], - dims: u32, - code: &rabitq::Code, - vector: (u32, u16), - payload: NonZeroU64, -) -> bool { - // todo: use mutable api - let mut x = rkyv::from_bytes::(bytes).expect("data corruption"); - for j in 0..32 { - if !x.mask[j] { - x.mean[j] = vector; - x.payload[j] = payload; - x.mask[j] = true; - x.dis_u_2[j] = code.dis_u_2; - x.factor_ppc[j] = code.factor_ppc; - x.factor_ip[j] = code.factor_ip; - x.factor_err[j] = code.factor_err; - let width = dims.div_ceil(4) as usize; - let table = [ - (0, 0), - (2, 0), - (4, 0), - (6, 0), - (8, 0), - (10, 0), - (12, 0), - (14, 0), - (1, 0), - (3, 0), - (5, 0), - (7, 0), - (9, 0), - (11, 0), - (13, 0), - (15, 0), - (0, 1), - (2, 1), - (4, 1), - (6, 1), - (8, 1), - (10, 1), - (12, 1), - (14, 1), - (1, 1), - (3, 1), - (5, 1), - (7, 1), - (9, 1), - (11, 1), - (13, 1), - (15, 1), - ]; - let pos = table[j].0; - let mask = match table[j].1 { - 0 => 0xf0, - 1 => 0x0f, - _ => unreachable!(), - }; - let shift = match table[j].1 { - 0 => 0, - 1 => 4, - _ 
=> unreachable!(), - }; - let mut buffer = vec![0u8; width]; - for j in 0..width { - let b0 = code.signs.get(4 * j + 0).copied().unwrap_or_default(); - let b1 = code.signs.get(4 * j + 1).copied().unwrap_or_default(); - let b2 = code.signs.get(4 * j + 2).copied().unwrap_or_default(); - let b3 = code.signs.get(4 * j + 3).copied().unwrap_or_default(); - buffer[j] = b0 | b1 << 1 | b2 << 2 | b3 << 3; - } - for j in 0..width { - x.t[16 * j + pos] &= mask; - x.t[16 * j + pos] |= buffer[j] << shift; - } - bytes.copy_from_slice(&rkyv::to_bytes::<_, 8192>(&x).unwrap()); - return true; - } - } - false -} diff --git a/src/vchordrqfscan/algorithm/vacuum.rs b/src/vchordrqfscan/algorithm/vacuum.rs deleted file mode 100644 index 7773ed2..0000000 --- a/src/vchordrqfscan/algorithm/vacuum.rs +++ /dev/null @@ -1,139 +0,0 @@ -use crate::vchordrqfscan::algorithm::tuples::VectorTuple; -use crate::vchordrqfscan::algorithm::tuples::*; -use algorithm::{Page, RelationWrite}; -use std::num::NonZeroU64; - -pub fn vacuum( - relation: impl RelationWrite, - delay: impl Fn(), - callback: impl Fn(NonZeroU64) -> bool, -) { - // step 1: vacuum height_0_tuple - { - let meta_guard = relation.read(0); - let meta_tuple = meta_guard - .get(1) - .map(rkyv::check_archived_root::) - .expect("data corruption") - .expect("data corruption"); - let mut firsts = vec![meta_tuple.first]; - let make_firsts = |firsts| { - let mut results = Vec::new(); - for first in firsts { - let mut current = first; - while current != u32::MAX { - let h1_guard = relation.read(current); - for i in 1..=h1_guard.len() { - let h1_tuple = h1_guard - .get(i) - .map(rkyv::check_archived_root::) - .expect("data corruption") - .expect("data corruption"); - for j in 0..32 { - if h1_tuple.mask[j] { - results.push(h1_tuple.first[j]); - } - } - } - current = h1_guard.get_opaque().next; - } - } - results - }; - for _ in (1..meta_tuple.height_of_root).rev() { - firsts = make_firsts(firsts); - } - for first in firsts { - let mut current = first; - while current != u32::MAX { - delay(); - let mut h0_guard = relation.write(current, false); - for i in 1..=h0_guard.len() { - let h0_tuple = h0_guard - .get(i) - .map(rkyv::check_archived_root::) - .expect("data corruption") - .expect("data corruption"); - let flag = 'flag: { - for j in 0..32 { - if h0_tuple.mask[j] && callback(h0_tuple.payload[j]) { - break 'flag true; - } - } - false - }; - if flag { - // todo: use mutable API - let mut temp = h0_guard - .get(i) - .map(rkyv::from_bytes::) - .expect("data corruption") - .expect("data corruption"); - for j in 0..32 { - if temp.mask[j] && callback(temp.payload[j]) { - temp.mask[j] = false; - } - } - let temp = rkyv::to_bytes::<_, 8192>(&temp).expect("failed to serialize"); - h0_guard - .get_mut(i) - .expect("data corruption") - .copy_from_slice(&temp); - } - } - // todo: cross-tuple vacuum so that we can skip a tuple - current = h0_guard.get_opaque().next; - } - } - } - // step 2: vacuum vector_tuple - { - let mut current = { - let meta_guard = relation.read(0); - let meta_tuple = meta_guard - .get(1) - .map(rkyv::check_archived_root::) - .expect("data corruption") - .expect("data corruption"); - meta_tuple.vectors_first - }; - while current != u32::MAX { - delay(); - let read = relation.read(current); - let flag = 'flag: { - for i in 1..=read.len() { - let Some(vector_tuple) = read.get(i) else { - continue; - }; - let vector_tuple = rkyv::check_archived_root::(vector_tuple) - .expect("data corruption"); - if let Some(payload) = vector_tuple.payload.as_ref().copied() { - if 
callback(payload) { - break 'flag true; - } - } - } - false - }; - if flag { - drop(read); - let mut write = relation.write(current, true); - for i in 1..=write.len() { - let Some(vector_tuple) = write.get(i) else { - continue; - }; - let vector_tuple = rkyv::check_archived_root::(vector_tuple) - .expect("data corruption"); - if let Some(payload) = vector_tuple.payload.as_ref().copied() { - if callback(payload) { - write.free(i); - } - } - } - current = write.get_opaque().next; - } else { - current = read.get_opaque().next; - } - } - } -} diff --git a/src/vchordrqfscan/gucs/executing.rs b/src/vchordrqfscan/gucs/executing.rs deleted file mode 100644 index ba204b9..0000000 --- a/src/vchordrqfscan/gucs/executing.rs +++ /dev/null @@ -1,72 +0,0 @@ -use pgrx::guc::{GucContext, GucFlags, GucRegistry, GucSetting}; -use std::ffi::CStr; - -static PROBES: GucSetting> = GucSetting::>::new(Some(c"10")); -static EPSILON: GucSetting = GucSetting::::new(1.9); -static MAX_SCAN_TUPLES: GucSetting = GucSetting::::new(-1); - -pub unsafe fn init() { - GucRegistry::define_string_guc( - "vchordrqfscan.probes", - "`probes` argument of vchordrqfscan.", - "`probes` argument of vchordrqfscan.", - &PROBES, - GucContext::Userset, - GucFlags::default(), - ); - GucRegistry::define_float_guc( - "vchordrqfscan.epsilon", - "`epsilon` argument of vchordrqfscan.", - "`epsilon` argument of vchordrqfscan.", - &EPSILON, - 0.0, - 4.0, - GucContext::Userset, - GucFlags::default(), - ); - GucRegistry::define_int_guc( - "vchordrqfscan.max_scan_tuples", - "`max_scan_tuples` argument of vchordrqfscan.", - "`max_scan_tuples` argument of vchordrqfscan.", - &MAX_SCAN_TUPLES, - -1, - u16::MAX as _, - GucContext::Userset, - GucFlags::default(), - ); -} - -pub fn probes() -> Vec { - match PROBES.get() { - None => Vec::new(), - Some(probes) => { - let mut result = Vec::new(); - let mut current = None; - for &c in probes.to_bytes() { - match c { - b' ' => continue, - b',' => result.push(current.take().expect("empty probes")), - b'0'..=b'9' => { - if let Some(x) = current.as_mut() { - *x = *x * 10 + (c - b'0') as u32; - } else { - current = Some((c - b'0') as u32); - } - } - c => pgrx::error!("unknown character in probes: ASCII = {c}"), - } - } - result.push(current.take().expect("empty probes")); - result - } - } -} - -pub fn epsilon() -> f32 { - EPSILON.get() as f32 -} - -pub fn max_scan_tuples() -> Option { - let x = MAX_SCAN_TUPLES.get(); - if x < 0 { None } else { Some(x as u32) } -} diff --git a/src/vchordrqfscan/gucs/mod.rs b/src/vchordrqfscan/gucs/mod.rs deleted file mode 100644 index 48cc060..0000000 --- a/src/vchordrqfscan/gucs/mod.rs +++ /dev/null @@ -1,14 +0,0 @@ -pub mod executing; -pub mod prewarm; - -pub unsafe fn init() { - unsafe { - executing::init(); - prewarm::init(); - prewarm::prewarm(); - #[cfg(any(feature = "pg13", feature = "pg14"))] - pgrx::pg_sys::EmitWarningsOnPlaceholders(c"vchordrqfscan".as_ptr()); - #[cfg(any(feature = "pg15", feature = "pg16", feature = "pg17"))] - pgrx::pg_sys::MarkGUCPrefixReserved(c"vchordrqfscan".as_ptr()); - } -} diff --git a/src/vchordrqfscan/gucs/prewarm.rs b/src/vchordrqfscan/gucs/prewarm.rs deleted file mode 100644 index ae9180a..0000000 --- a/src/vchordrqfscan/gucs/prewarm.rs +++ /dev/null @@ -1,32 +0,0 @@ -use pgrx::guc::{GucContext, GucFlags, GucRegistry, GucSetting}; -use std::ffi::CStr; - -static PREWARM_DIM: GucSetting> = - GucSetting::>::new(Some(c"64,128,256,384,512,768,1024,1536")); - -pub unsafe fn init() { - GucRegistry::define_string_guc( - "vchordrqfscan.prewarm_dim", - 
"prewarm_dim when the extension is loading.", - "prewarm_dim when the extension is loading.", - &PREWARM_DIM, - GucContext::Userset, - GucFlags::default(), - ); -} - -pub fn prewarm() { - if let Some(prewarm_dim) = PREWARM_DIM.get() { - if let Ok(prewarm_dim) = prewarm_dim.to_str() { - for dim in prewarm_dim.split(',') { - if let Ok(dim) = dim.trim().parse::() { - crate::projection::prewarm(dim as _); - } else { - pgrx::warning!("{dim:?} is not a valid integer"); - } - } - } else { - pgrx::warning!("vchordrqfscan.prewarm_dim is not a valid UTF-8 string"); - } - } -} diff --git a/src/vchordrqfscan/index/am.rs b/src/vchordrqfscan/index/am.rs deleted file mode 100644 index 339b315..0000000 --- a/src/vchordrqfscan/index/am.rs +++ /dev/null @@ -1,856 +0,0 @@ -use crate::postgres::PostgresRelation; -use crate::vchordrqfscan::algorithm; -use crate::vchordrqfscan::algorithm::build::{HeapRelation, Reporter}; -use crate::vchordrqfscan::index::am_options::{Opfamily, Reloption}; -use crate::vchordrqfscan::index::am_scan::Scanner; -use crate::vchordrqfscan::index::utils::{ctid_to_pointer, pointer_to_ctid}; -use crate::vchordrqfscan::index::{am_options, am_scan}; -use pgrx::datum::Internal; -use pgrx::pg_sys::Datum; -use std::num::NonZeroU64; - -static mut RELOPT_KIND_VCHORDRQFSCAN: pgrx::pg_sys::relopt_kind::Type = 0; - -pub unsafe fn init() { - unsafe { - (&raw mut RELOPT_KIND_VCHORDRQFSCAN).write(pgrx::pg_sys::add_reloption_kind()); - pgrx::pg_sys::add_string_reloption( - (&raw const RELOPT_KIND_VCHORDRQFSCAN).read(), - c"options".as_ptr(), - c"Vector index options, represented as a TOML string.".as_ptr(), - c"".as_ptr(), - None, - pgrx::pg_sys::AccessExclusiveLock as pgrx::pg_sys::LOCKMODE, - ); - } -} - -#[pgrx::pg_extern(sql = "")] -fn _vchordrqfscan_amhandler(_fcinfo: pgrx::pg_sys::FunctionCallInfo) -> Internal { - type T = pgrx::pg_sys::IndexAmRoutine; - unsafe { - let index_am_routine = pgrx::pg_sys::palloc0(size_of::()) as *mut T; - index_am_routine.write(AM_HANDLER); - Internal::from(Some(Datum::from(index_am_routine))) - } -} - -const AM_HANDLER: pgrx::pg_sys::IndexAmRoutine = { - let mut am_routine = - unsafe { std::mem::MaybeUninit::::zeroed().assume_init() }; - - am_routine.type_ = pgrx::pg_sys::NodeTag::T_IndexAmRoutine; - - am_routine.amsupport = 1; - am_routine.amcanorderbyop = true; - - #[cfg(feature = "pg17")] - { - am_routine.amcanbuildparallel = true; - } - - // Index access methods that set `amoptionalkey` to `false` - // must index all tuples, even if the first column is `NULL`. - // However, PostgreSQL does not generate a path if there is no - // index clauses, even if there is a `ORDER BY` clause. - // So we have to set it to `true` and set costs of every path - // for vector index scans without `ORDER BY` clauses a large number - // and throw errors if someone really wants such a path. 
- am_routine.amoptionalkey = true; - - am_routine.amvalidate = Some(amvalidate); - am_routine.amoptions = Some(amoptions); - am_routine.amcostestimate = Some(amcostestimate); - - am_routine.ambuild = Some(ambuild); - am_routine.ambuildempty = Some(ambuildempty); - am_routine.aminsert = Some(aminsert); - am_routine.ambulkdelete = Some(ambulkdelete); - am_routine.amvacuumcleanup = Some(amvacuumcleanup); - - am_routine.ambeginscan = Some(ambeginscan); - am_routine.amrescan = Some(amrescan); - am_routine.amgettuple = Some(amgettuple); - am_routine.amendscan = Some(amendscan); - - am_routine -}; - -#[pgrx::pg_guard] -pub unsafe extern "C" fn amvalidate(_opclass_oid: pgrx::pg_sys::Oid) -> bool { - true -} - -#[pgrx::pg_guard] -pub unsafe extern "C" fn amoptions(reloptions: Datum, validate: bool) -> *mut pgrx::pg_sys::bytea { - let rdopts = unsafe { - pgrx::pg_sys::build_reloptions( - reloptions, - validate, - (&raw const RELOPT_KIND_VCHORDRQFSCAN).read(), - size_of::(), - Reloption::TAB.as_ptr(), - Reloption::TAB.len() as _, - ) - }; - rdopts as *mut pgrx::pg_sys::bytea -} - -#[pgrx::pg_guard] -pub unsafe extern "C" fn amcostestimate( - _root: *mut pgrx::pg_sys::PlannerInfo, - path: *mut pgrx::pg_sys::IndexPath, - _loop_count: f64, - index_startup_cost: *mut pgrx::pg_sys::Cost, - index_total_cost: *mut pgrx::pg_sys::Cost, - index_selectivity: *mut pgrx::pg_sys::Selectivity, - index_correlation: *mut f64, - index_pages: *mut f64, -) { - unsafe { - if (*path).indexorderbys.is_null() && (*path).indexclauses.is_null() { - *index_startup_cost = f64::MAX; - *index_total_cost = f64::MAX; - *index_selectivity = 0.0; - *index_correlation = 0.0; - *index_pages = 0.0; - return; - } - *index_startup_cost = 0.0; - *index_total_cost = 0.0; - *index_selectivity = 1.0; - *index_correlation = 1.0; - *index_pages = 0.0; - } -} - -#[derive(Debug, Clone)] -struct PgReporter {} - -impl Reporter for PgReporter { - fn tuples_total(&mut self, tuples_total: u64) { - unsafe { - pgrx::pg_sys::pgstat_progress_update_param( - pgrx::pg_sys::PROGRESS_CREATEIDX_TUPLES_TOTAL as _, - tuples_total as _, - ); - } - } -} - -impl PgReporter { - fn tuples_done(&mut self, tuples_done: u64) { - unsafe { - pgrx::pg_sys::pgstat_progress_update_param( - pgrx::pg_sys::PROGRESS_CREATEIDX_TUPLES_DONE as _, - tuples_done as _, - ); - } - } -} - -#[pgrx::pg_guard] -pub unsafe extern "C" fn ambuild( - heap: pgrx::pg_sys::Relation, - index: pgrx::pg_sys::Relation, - index_info: *mut pgrx::pg_sys::IndexInfo, -) -> *mut pgrx::pg_sys::IndexBuildResult { - use validator::Validate; - #[derive(Debug, Clone)] - pub struct Heap { - heap: pgrx::pg_sys::Relation, - index: pgrx::pg_sys::Relation, - index_info: *mut pgrx::pg_sys::IndexInfo, - opfamily: Opfamily, - } - impl HeapRelation for Heap { - fn traverse(&self, progress: bool, callback: F) - where - F: FnMut((NonZeroU64, Vec)), - { - pub struct State<'a, F> { - pub this: &'a Heap, - pub callback: F, - } - #[pgrx::pg_guard] - unsafe extern "C" fn call( - _index: pgrx::pg_sys::Relation, - ctid: pgrx::pg_sys::ItemPointer, - values: *mut Datum, - is_null: *mut bool, - _tuple_is_alive: bool, - state: *mut core::ffi::c_void, - ) where - F: FnMut((NonZeroU64, Vec)), - { - use crate::vchordrqfscan::types::OwnedVector; - let state = unsafe { &mut *state.cast::>() }; - let opfamily = state.this.opfamily; - let vector = unsafe { opfamily.datum_to_vector(*values.add(0), *is_null.add(0)) }; - let pointer = unsafe { ctid_to_pointer(ctid.read()) }; - if let Some(vector) = vector { - let vector = match vector { - 
OwnedVector::Vecf32(x) => x, - }; - (state.callback)((pointer, vector.into_vec())); - } - } - let table_am = unsafe { &*(*self.heap).rd_tableam }; - let mut state = State { - this: self, - callback, - }; - unsafe { - table_am.index_build_range_scan.unwrap()( - self.heap, - self.index, - self.index_info, - true, - false, - progress, - 0, - pgrx::pg_sys::InvalidBlockNumber, - Some(call::), - (&mut state) as *mut State as *mut _, - std::ptr::null_mut(), - ); - } - } - - fn opfamily(&self) -> Opfamily { - self.opfamily - } - } - let (vector_options, vchordrqfscan_options) = unsafe { am_options::options(index) }; - if let Err(errors) = Validate::validate(&vector_options) { - pgrx::error!("error while validating options: {}", errors); - } - if vector_options.dims == 0 { - pgrx::error!("error while validating options: dimension cannot be 0"); - } - if vector_options.dims > 1600 { - pgrx::error!("error while validating options: dimension is too large"); - } - if let Err(errors) = Validate::validate(&vchordrqfscan_options) { - pgrx::error!("error while validating options: {}", errors); - } - let opfamily = unsafe { am_options::opfamily(index) }; - let heap_relation = Heap { - heap, - index, - index_info, - opfamily, - }; - let mut reporter = PgReporter {}; - let index_relation = unsafe { PostgresRelation::new(index) }; - algorithm::build::build( - vector_options, - vchordrqfscan_options, - heap_relation.clone(), - index_relation.clone(), - reporter.clone(), - ); - if let Some(leader) = - unsafe { VchordrqfscanLeader::enter(heap, index, (*index_info).ii_Concurrent) } - { - unsafe { - parallel_build( - index, - heap, - index_info, - leader.tablescandesc, - leader.vchordrqfscanshared, - Some(reporter), - ); - leader.wait(); - let nparticipants = leader.nparticipants; - loop { - pgrx::pg_sys::SpinLockAcquire(&raw mut (*leader.vchordrqfscanshared).mutex); - if (*leader.vchordrqfscanshared).nparticipantsdone == nparticipants { - pgrx::pg_sys::SpinLockRelease(&raw mut (*leader.vchordrqfscanshared).mutex); - break; - } - pgrx::pg_sys::SpinLockRelease(&raw mut (*leader.vchordrqfscanshared).mutex); - pgrx::pg_sys::ConditionVariableSleep( - &raw mut (*leader.vchordrqfscanshared).workersdonecv, - pgrx::pg_sys::WaitEventIPC::WAIT_EVENT_PARALLEL_CREATE_INDEX_SCAN, - ); - } - pgrx::pg_sys::ConditionVariableCancelSleep(); - } - } else { - let mut indtuples = 0; - reporter.tuples_done(indtuples); - heap_relation.traverse(true, |(payload, vector)| { - algorithm::insert::insert( - index_relation.clone(), - payload, - vector, - opfamily.distance_kind(), - true, - ); - indtuples += 1; - reporter.tuples_done(indtuples); - }); - } - unsafe { pgrx::pgbox::PgBox::::alloc0().into_pg() } -} - -struct VchordrqfscanShared { - /* Immutable state */ - heaprelid: pgrx::pg_sys::Oid, - indexrelid: pgrx::pg_sys::Oid, - isconcurrent: bool, - - /* Worker progress */ - workersdonecv: pgrx::pg_sys::ConditionVariable, - - /* Mutex for mutable state */ - mutex: pgrx::pg_sys::slock_t, - - /* Mutable state */ - nparticipantsdone: i32, - indtuples: u64, -} - -fn is_mvcc_snapshot(snapshot: *mut pgrx::pg_sys::SnapshotData) -> bool { - matches!( - unsafe { (*snapshot).snapshot_type }, - pgrx::pg_sys::SnapshotType::SNAPSHOT_MVCC - | pgrx::pg_sys::SnapshotType::SNAPSHOT_HISTORIC_MVCC - ) -} - -struct VchordrqfscanLeader { - pcxt: *mut pgrx::pg_sys::ParallelContext, - nparticipants: i32, - vchordrqfscanshared: *mut VchordrqfscanShared, - tablescandesc: *mut pgrx::pg_sys::ParallelTableScanDescData, - snapshot: pgrx::pg_sys::Snapshot, -} - -impl 
VchordrqfscanLeader { - pub unsafe fn enter( - heap: pgrx::pg_sys::Relation, - index: pgrx::pg_sys::Relation, - isconcurrent: bool, - ) -> Option { - unsafe fn compute_parallel_workers( - heap: pgrx::pg_sys::Relation, - index: pgrx::pg_sys::Relation, - ) -> i32 { - unsafe { - if pgrx::pg_sys::plan_create_index_workers((*heap).rd_id, (*index).rd_id) == 0 { - return 0; - } - if !(*heap).rd_options.is_null() { - let std_options = (*heap).rd_options.cast::(); - std::cmp::min( - (*std_options).parallel_workers, - pgrx::pg_sys::max_parallel_maintenance_workers, - ) - } else { - pgrx::pg_sys::max_parallel_maintenance_workers - } - } - } - - let request = unsafe { compute_parallel_workers(heap, index) }; - if request <= 0 { - return None; - } - - unsafe { - pgrx::pg_sys::EnterParallelMode(); - } - let pcxt = unsafe { - pgrx::pg_sys::CreateParallelContext( - c"vchord".as_ptr(), - c"vchordrqfscan_parallel_build_main".as_ptr(), - request, - ) - }; - - let snapshot = if isconcurrent { - unsafe { pgrx::pg_sys::RegisterSnapshot(pgrx::pg_sys::GetTransactionSnapshot()) } - } else { - &raw mut pgrx::pg_sys::SnapshotAnyData - }; - - fn estimate_chunk(e: &mut pgrx::pg_sys::shm_toc_estimator, x: usize) { - e.space_for_chunks += x.next_multiple_of(pgrx::pg_sys::ALIGNOF_BUFFER as _); - } - fn estimate_keys(e: &mut pgrx::pg_sys::shm_toc_estimator, x: usize) { - e.number_of_keys += x; - } - let est_tablescandesc = - unsafe { pgrx::pg_sys::table_parallelscan_estimate(heap, snapshot) }; - unsafe { - estimate_chunk(&mut (*pcxt).estimator, size_of::()); - estimate_keys(&mut (*pcxt).estimator, 1); - estimate_chunk(&mut (*pcxt).estimator, est_tablescandesc); - estimate_keys(&mut (*pcxt).estimator, 1); - } - - unsafe { - pgrx::pg_sys::InitializeParallelDSM(pcxt); - if (*pcxt).seg.is_null() { - if is_mvcc_snapshot(snapshot) { - pgrx::pg_sys::UnregisterSnapshot(snapshot); - } - pgrx::pg_sys::DestroyParallelContext(pcxt); - pgrx::pg_sys::ExitParallelMode(); - return None; - } - } - - let vchordrqfscanshared = unsafe { - let vchordrqfscanshared = - pgrx::pg_sys::shm_toc_allocate((*pcxt).toc, size_of::()) - .cast::(); - vchordrqfscanshared.write(VchordrqfscanShared { - heaprelid: (*heap).rd_id, - indexrelid: (*index).rd_id, - isconcurrent, - workersdonecv: std::mem::zeroed(), - mutex: std::mem::zeroed(), - nparticipantsdone: 0, - indtuples: 0, - }); - pgrx::pg_sys::ConditionVariableInit(&raw mut (*vchordrqfscanshared).workersdonecv); - pgrx::pg_sys::SpinLockInit(&raw mut (*vchordrqfscanshared).mutex); - vchordrqfscanshared - }; - - let tablescandesc = unsafe { - let tablescandesc = pgrx::pg_sys::shm_toc_allocate((*pcxt).toc, est_tablescandesc) - .cast::(); - pgrx::pg_sys::table_parallelscan_initialize(heap, tablescandesc, snapshot); - tablescandesc - }; - - unsafe { - pgrx::pg_sys::shm_toc_insert( - (*pcxt).toc, - 0xA000000000000001, - vchordrqfscanshared.cast(), - ); - pgrx::pg_sys::shm_toc_insert((*pcxt).toc, 0xA000000000000002, tablescandesc.cast()); - } - - unsafe { - pgrx::pg_sys::LaunchParallelWorkers(pcxt); - } - - let nworkers_launched = unsafe { (*pcxt).nworkers_launched }; - - unsafe { - if nworkers_launched == 0 { - pgrx::pg_sys::WaitForParallelWorkersToFinish(pcxt); - if is_mvcc_snapshot(snapshot) { - pgrx::pg_sys::UnregisterSnapshot(snapshot); - } - pgrx::pg_sys::DestroyParallelContext(pcxt); - pgrx::pg_sys::ExitParallelMode(); - return None; - } - } - - Some(Self { - pcxt, - nparticipants: nworkers_launched + 1, - vchordrqfscanshared, - tablescandesc, - snapshot, - }) - } - - pub fn wait(&self) { - unsafe { - 
pgrx::pg_sys::WaitForParallelWorkersToAttach(self.pcxt); - } - } -} - -impl Drop for VchordrqfscanLeader { - fn drop(&mut self) { - if !std::thread::panicking() { - unsafe { - pgrx::pg_sys::WaitForParallelWorkersToFinish(self.pcxt); - if is_mvcc_snapshot(self.snapshot) { - pgrx::pg_sys::UnregisterSnapshot(self.snapshot); - } - pgrx::pg_sys::DestroyParallelContext(self.pcxt); - pgrx::pg_sys::ExitParallelMode(); - } - } - } -} - -#[pgrx::pg_guard] -#[unsafe(no_mangle)] -pub unsafe extern "C" fn vchordrqfscan_parallel_build_main( - _seg: *mut pgrx::pg_sys::dsm_segment, - toc: *mut pgrx::pg_sys::shm_toc, -) { - let vchordrqfscanshared = unsafe { - pgrx::pg_sys::shm_toc_lookup(toc, 0xA000000000000001, false).cast::() - }; - let tablescandesc = unsafe { - pgrx::pg_sys::shm_toc_lookup(toc, 0xA000000000000002, false) - .cast::() - }; - let heap_lockmode; - let index_lockmode; - if unsafe { !(*vchordrqfscanshared).isconcurrent } { - heap_lockmode = pgrx::pg_sys::ShareLock as pgrx::pg_sys::LOCKMODE; - index_lockmode = pgrx::pg_sys::AccessExclusiveLock as pgrx::pg_sys::LOCKMODE; - } else { - heap_lockmode = pgrx::pg_sys::ShareUpdateExclusiveLock as pgrx::pg_sys::LOCKMODE; - index_lockmode = pgrx::pg_sys::RowExclusiveLock as pgrx::pg_sys::LOCKMODE; - } - let heap = unsafe { pgrx::pg_sys::table_open((*vchordrqfscanshared).heaprelid, heap_lockmode) }; - let index = - unsafe { pgrx::pg_sys::index_open((*vchordrqfscanshared).indexrelid, index_lockmode) }; - let index_info = unsafe { pgrx::pg_sys::BuildIndexInfo(index) }; - unsafe { - (*index_info).ii_Concurrent = (*vchordrqfscanshared).isconcurrent; - } - - unsafe { - parallel_build( - index, - heap, - index_info, - tablescandesc, - vchordrqfscanshared, - None, - ); - } - - unsafe { - pgrx::pg_sys::index_close(index, index_lockmode); - pgrx::pg_sys::table_close(heap, heap_lockmode); - } -} - -unsafe fn parallel_build( - index: *mut pgrx::pg_sys::RelationData, - heap: pgrx::pg_sys::Relation, - index_info: *mut pgrx::pg_sys::IndexInfo, - tablescandesc: *mut pgrx::pg_sys::ParallelTableScanDescData, - vchordrqfscanshared: *mut VchordrqfscanShared, - mut reporter: Option, -) { - #[derive(Debug, Clone)] - pub struct Heap { - heap: pgrx::pg_sys::Relation, - index: pgrx::pg_sys::Relation, - index_info: *mut pgrx::pg_sys::IndexInfo, - opfamily: Opfamily, - scan: *mut pgrx::pg_sys::TableScanDescData, - } - impl HeapRelation for Heap { - fn traverse(&self, progress: bool, callback: F) - where - F: FnMut((NonZeroU64, Vec)), - { - pub struct State<'a, F> { - pub this: &'a Heap, - pub callback: F, - } - #[pgrx::pg_guard] - unsafe extern "C" fn call( - _index: pgrx::pg_sys::Relation, - ctid: pgrx::pg_sys::ItemPointer, - values: *mut Datum, - is_null: *mut bool, - _tuple_is_alive: bool, - state: *mut core::ffi::c_void, - ) where - F: FnMut((NonZeroU64, Vec)), - { - use crate::vchordrqfscan::types::OwnedVector; - let state = unsafe { &mut *state.cast::>() }; - let opfamily = state.this.opfamily; - let vector = unsafe { opfamily.datum_to_vector(*values.add(0), *is_null.add(0)) }; - let pointer = unsafe { ctid_to_pointer(ctid.read()) }; - if let Some(vector) = vector { - let vector = match vector { - OwnedVector::Vecf32(x) => x, - }; - (state.callback)((pointer, vector.into_vec())); - } - } - let table_am = unsafe { &*(*self.heap).rd_tableam }; - let mut state = State { - this: self, - callback, - }; - unsafe { - table_am.index_build_range_scan.unwrap()( - self.heap, - self.index, - self.index_info, - true, - false, - progress, - 0, - pgrx::pg_sys::InvalidBlockNumber, - 
Some(call::), - (&mut state) as *mut State as *mut _, - self.scan, - ); - } - } - - fn opfamily(&self) -> Opfamily { - self.opfamily - } - } - - let index_relation = unsafe { PostgresRelation::new(index) }; - let scan = unsafe { pgrx::pg_sys::table_beginscan_parallel(heap, tablescandesc) }; - let opfamily = unsafe { am_options::opfamily(index) }; - let heap_relation = Heap { - heap, - index, - index_info, - opfamily, - scan, - }; - heap_relation.traverse(reporter.is_some(), |(payload, vector)| { - algorithm::insert::insert( - index_relation.clone(), - payload, - vector, - opfamily.distance_kind(), - true, - ); - unsafe { - let indtuples; - { - pgrx::pg_sys::SpinLockAcquire(&raw mut (*vchordrqfscanshared).mutex); - (*vchordrqfscanshared).indtuples += 1; - indtuples = (*vchordrqfscanshared).indtuples; - pgrx::pg_sys::SpinLockRelease(&raw mut (*vchordrqfscanshared).mutex); - } - if let Some(reporter) = reporter.as_mut() { - reporter.tuples_done(indtuples); - } - } - }); - - unsafe { - pgrx::pg_sys::SpinLockAcquire(&raw mut (*vchordrqfscanshared).mutex); - (*vchordrqfscanshared).nparticipantsdone += 1; - pgrx::pg_sys::SpinLockRelease(&raw mut (*vchordrqfscanshared).mutex); - pgrx::pg_sys::ConditionVariableSignal(&raw mut (*vchordrqfscanshared).workersdonecv); - } -} - -#[pgrx::pg_guard] -pub unsafe extern "C" fn ambuildempty(_index: pgrx::pg_sys::Relation) { - pgrx::error!("Unlogged indexes are not supported."); -} - -#[cfg(feature = "pg13")] -#[pgrx::pg_guard] -pub unsafe extern "C" fn aminsert( - index: pgrx::pg_sys::Relation, - values: *mut Datum, - is_null: *mut bool, - heap_tid: pgrx::pg_sys::ItemPointer, - _heap: pgrx::pg_sys::Relation, - _check_unique: pgrx::pg_sys::IndexUniqueCheck::Type, - _index_info: *mut pgrx::pg_sys::IndexInfo, -) -> bool { - use crate::vchordrqfscan::types::OwnedVector; - let opfamily = unsafe { am_options::opfamily(index) }; - let vector = unsafe { opfamily.datum_to_vector(*values.add(0), *is_null.add(0)) }; - if let Some(vector) = vector { - let vector = match vector { - OwnedVector::Vecf32(x) => x, - }; - let pointer = ctid_to_pointer(unsafe { heap_tid.read() }); - algorithm::insert::insert( - unsafe { PostgresRelation::new(index) }, - pointer, - vector.into_vec(), - opfamily.distance_kind(), - false, - ); - } - false -} - -#[cfg(any(feature = "pg14", feature = "pg15", feature = "pg16", feature = "pg17"))] -#[pgrx::pg_guard] -pub unsafe extern "C" fn aminsert( - index: pgrx::pg_sys::Relation, - values: *mut Datum, - is_null: *mut bool, - heap_tid: pgrx::pg_sys::ItemPointer, - _heap: pgrx::pg_sys::Relation, - _check_unique: pgrx::pg_sys::IndexUniqueCheck::Type, - _index_unchanged: bool, - _index_info: *mut pgrx::pg_sys::IndexInfo, -) -> bool { - use crate::vchordrqfscan::types::OwnedVector; - let opfamily = unsafe { am_options::opfamily(index) }; - let vector = unsafe { opfamily.datum_to_vector(*values.add(0), *is_null.add(0)) }; - if let Some(vector) = vector { - let vector = match vector { - OwnedVector::Vecf32(x) => x, - }; - let pointer = ctid_to_pointer(unsafe { heap_tid.read() }); - algorithm::insert::insert( - unsafe { PostgresRelation::new(index) }, - pointer, - vector.into_vec(), - opfamily.distance_kind(), - false, - ); - } - false -} - -#[pgrx::pg_guard] -pub unsafe extern "C" fn ambeginscan( - index: pgrx::pg_sys::Relation, - n_keys: std::os::raw::c_int, - n_orderbys: std::os::raw::c_int, -) -> pgrx::pg_sys::IndexScanDesc { - use pgrx::memcxt::PgMemoryContexts::CurrentMemoryContext; - - let scan = unsafe { pgrx::pg_sys::RelationGetIndexScan(index, 
n_keys, n_orderbys) }; - unsafe { - let scanner = am_scan::scan_make(None, None, false); - (*scan).opaque = CurrentMemoryContext.leak_and_drop_on_delete(scanner).cast(); - } - scan -} - -#[pgrx::pg_guard] -pub unsafe extern "C" fn amrescan( - scan: pgrx::pg_sys::IndexScanDesc, - keys: pgrx::pg_sys::ScanKey, - _n_keys: std::os::raw::c_int, - orderbys: pgrx::pg_sys::ScanKey, - _n_orderbys: std::os::raw::c_int, -) { - unsafe { - if !keys.is_null() && (*scan).numberOfKeys > 0 { - std::ptr::copy(keys, (*scan).keyData, (*scan).numberOfKeys as _); - } - if !orderbys.is_null() && (*scan).numberOfOrderBys > 0 { - std::ptr::copy(orderbys, (*scan).orderByData, (*scan).numberOfOrderBys as _); - } - let opfamily = am_options::opfamily((*scan).indexRelation); - let (orderbys, spheres) = { - let mut orderbys = Vec::new(); - let mut spheres = Vec::new(); - if (*scan).numberOfOrderBys == 0 && (*scan).numberOfKeys == 0 { - pgrx::error!( - "vector search with no WHERE clause and no ORDER BY clause is not supported" - ); - } - for i in 0..(*scan).numberOfOrderBys { - let data = (*scan).orderByData.add(i as usize); - let value = (*data).sk_argument; - let is_null = ((*data).sk_flags & pgrx::pg_sys::SK_ISNULL as i32) != 0; - match (*data).sk_strategy { - 1 => orderbys.push(opfamily.datum_to_vector(value, is_null)), - _ => unreachable!(), - } - } - for i in 0..(*scan).numberOfKeys { - let data = (*scan).keyData.add(i as usize); - let value = (*data).sk_argument; - let is_null = ((*data).sk_flags & pgrx::pg_sys::SK_ISNULL as i32) != 0; - match (*data).sk_strategy { - 2 => spheres.push(opfamily.datum_to_sphere(value, is_null)), - _ => unreachable!(), - } - } - (orderbys, spheres) - }; - let (vector, threshold, recheck) = am_scan::scan_build(orderbys, spheres, opfamily); - let scanner = (*scan).opaque.cast::().as_mut().unwrap_unchecked(); - let scanner = std::mem::replace(scanner, am_scan::scan_make(vector, threshold, recheck)); - am_scan::scan_release(scanner); - } -} - -#[pgrx::pg_guard] -pub unsafe extern "C" fn amgettuple( - scan: pgrx::pg_sys::IndexScanDesc, - direction: pgrx::pg_sys::ScanDirection::Type, -) -> bool { - if direction != pgrx::pg_sys::ScanDirection::ForwardScanDirection { - pgrx::error!("vector search without a forward scan direction is not supported"); - } - // https://www.postgresql.org/docs/current/index-locking.html - // If heap entries referenced physical pointers are deleted before - // they are consumed by PostgreSQL, PostgreSQL will received wrong - // physical pointers: no rows or irreverent rows are referenced. 
- if unsafe { (*(*scan).xs_snapshot).snapshot_type } != pgrx::pg_sys::SnapshotType::SNAPSHOT_MVCC - { - pgrx::error!("scanning with a non-MVCC-compliant snapshot is not supported"); - } - let scanner = unsafe { (*scan).opaque.cast::().as_mut().unwrap_unchecked() }; - let relation = unsafe { PostgresRelation::new((*scan).indexRelation) }; - if let Some((pointer, recheck)) = am_scan::scan_next(scanner, relation) { - let ctid = pointer_to_ctid(pointer); - unsafe { - (*scan).xs_heaptid = ctid; - (*scan).xs_recheckorderby = false; - (*scan).xs_recheck = recheck; - } - true - } else { - false - } -} - -#[pgrx::pg_guard] -pub unsafe extern "C" fn amendscan(scan: pgrx::pg_sys::IndexScanDesc) { - unsafe { - let scanner = (*scan).opaque.cast::().as_mut().unwrap_unchecked(); - let scanner = std::mem::replace(scanner, am_scan::scan_make(None, None, false)); - am_scan::scan_release(scanner); - } -} - -#[pgrx::pg_guard] -pub unsafe extern "C" fn ambulkdelete( - info: *mut pgrx::pg_sys::IndexVacuumInfo, - stats: *mut pgrx::pg_sys::IndexBulkDeleteResult, - callback: pgrx::pg_sys::IndexBulkDeleteCallback, - callback_state: *mut std::os::raw::c_void, -) -> *mut pgrx::pg_sys::IndexBulkDeleteResult { - let mut stats = stats; - if stats.is_null() { - stats = unsafe { - pgrx::pg_sys::palloc0(size_of::()).cast() - }; - } - let callback = callback.unwrap(); - let callback = |p: NonZeroU64| unsafe { callback(&mut pointer_to_ctid(p), callback_state) }; - algorithm::vacuum::vacuum( - unsafe { PostgresRelation::new((*info).index) }, - || unsafe { - pgrx::pg_sys::vacuum_delay_point(); - }, - callback, - ); - stats -} - -#[pgrx::pg_guard] -pub unsafe extern "C" fn amvacuumcleanup( - _info: *mut pgrx::pg_sys::IndexVacuumInfo, - _stats: *mut pgrx::pg_sys::IndexBulkDeleteResult, -) -> *mut pgrx::pg_sys::IndexBulkDeleteResult { - std::ptr::null_mut() -} diff --git a/src/vchordrqfscan/index/am_options.rs b/src/vchordrqfscan/index/am_options.rs deleted file mode 100644 index be4154e..0000000 --- a/src/vchordrqfscan/index/am_options.rs +++ /dev/null @@ -1,218 +0,0 @@ -use crate::datatype::memory_pgvector_vector::PgvectorVectorInput; -use crate::datatype::memory_pgvector_vector::PgvectorVectorOutput; -use crate::datatype::typmod::Typmod; -use crate::vchordrqfscan::types::{BorrowedVector, OwnedVector}; -use crate::vchordrqfscan::types::{DistanceKind, VectorKind}; -use crate::vchordrqfscan::types::{VchordrqfscanIndexingOptions, VectorOptions}; -use distance::Distance; -use pgrx::datum::FromDatum; -use pgrx::heap_tuple::PgHeapTuple; -use serde::Deserialize; -use std::ffi::CStr; -use std::num::NonZero; -use vector::VectorBorrowed; - -#[derive(Copy, Clone, Debug, Default)] -#[repr(C)] -pub struct Reloption { - vl_len_: i32, - pub options: i32, -} - -impl Reloption { - pub const TAB: &'static [pgrx::pg_sys::relopt_parse_elt] = &[pgrx::pg_sys::relopt_parse_elt { - optname: c"options".as_ptr(), - opttype: pgrx::pg_sys::relopt_type::RELOPT_TYPE_STRING, - offset: std::mem::offset_of!(Reloption, options) as i32, - }]; - unsafe fn options(&self) -> &CStr { - unsafe { - let ptr = (&raw const *self) - .cast::() - .offset(self.options as _); - CStr::from_ptr(ptr) - } - } -} - -#[repr(u8)] -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] -pub enum PgDistanceKind { - L2, - Dot, - Cos, -} - -impl PgDistanceKind { - pub fn to_distance(self) -> DistanceKind { - match self { - PgDistanceKind::L2 => DistanceKind::L2, - PgDistanceKind::Dot | PgDistanceKind::Cos => DistanceKind::Dot, - } - } -} - -fn convert_name_to_vd(name: &str) -> 
Option<(VectorKind, PgDistanceKind)> { - match name.strip_suffix("_ops") { - Some("vector_l2") => Some((VectorKind::Vecf32, PgDistanceKind::L2)), - Some("vector_ip") => Some((VectorKind::Vecf32, PgDistanceKind::Dot)), - Some("vector_cosine") => Some((VectorKind::Vecf32, PgDistanceKind::Cos)), - _ => None, - } -} - -unsafe fn convert_reloptions_to_options( - reloptions: *const pgrx::pg_sys::varlena, -) -> VchordrqfscanIndexingOptions { - #[derive(Debug, Clone, Deserialize, Default)] - #[serde(deny_unknown_fields)] - struct Parsed { - #[serde(flatten)] - rabitq: VchordrqfscanIndexingOptions, - } - let reloption = reloptions as *const Reloption; - if reloption.is_null() || unsafe { (*reloption).options == 0 } { - return Default::default(); - } - let s = unsafe { (*reloption).options() }.to_string_lossy(); - match toml::from_str::(&s) { - Ok(p) => p.rabitq, - Err(e) => pgrx::error!("failed to parse options: {}", e), - } -} - -pub unsafe fn options( - index: pgrx::pg_sys::Relation, -) -> (VectorOptions, VchordrqfscanIndexingOptions) { - let att = unsafe { &mut *(*index).rd_att }; - let atts = unsafe { att.attrs.as_slice(att.natts as _) }; - if atts.is_empty() { - pgrx::error!("indexing on no columns is not supported"); - } - if atts.len() != 1 { - pgrx::error!("multicolumn index is not supported"); - } - // get dims - let typmod = Typmod::parse_from_i32(atts[0].type_mod()).unwrap(); - let dims = if let Some(dims) = typmod.dims() { - dims.get() - } else { - pgrx::error!( - "Dimensions type modifier of a vector column is needed for building the index." - ); - }; - // get v, d - let opfamily = unsafe { opfamily(index) }; - let vector = VectorOptions { - dims, - v: opfamily.vector, - d: opfamily.distance_kind(), - }; - // get indexing, segment, optimizing - let rabitq = unsafe { convert_reloptions_to_options((*index).rd_options) }; - (vector, rabitq) -} - -#[derive(Debug, Clone, Copy)] -pub struct Opfamily { - vector: VectorKind, - pg_distance: PgDistanceKind, -} - -impl Opfamily { - pub unsafe fn datum_to_vector( - self, - datum: pgrx::pg_sys::Datum, - is_null: bool, - ) -> Option { - if is_null || datum.is_null() { - return None; - } - let vector = match self.vector { - VectorKind::Vecf32 => { - let vector = unsafe { PgvectorVectorInput::from_datum(datum, false).unwrap() }; - self.preprocess(BorrowedVector::Vecf32(vector.as_borrowed())) - } - }; - Some(vector) - } - pub unsafe fn datum_to_sphere( - self, - datum: pgrx::pg_sys::Datum, - is_null: bool, - ) -> (Option, Option) { - if is_null || datum.is_null() { - return (None, None); - } - let tuple = unsafe { PgHeapTuple::from_composite_datum(datum) }; - let center = match self.vector { - VectorKind::Vecf32 => tuple - .get_by_index::(NonZero::new(1).unwrap()) - .unwrap() - .map(|vector| self.preprocess(BorrowedVector::Vecf32(vector.as_borrowed()))), - }; - let radius = tuple.get_by_index::(NonZero::new(2).unwrap()).unwrap(); - (center, radius) - } - pub fn preprocess(self, vector: BorrowedVector<'_>) -> OwnedVector { - use BorrowedVector as B; - use OwnedVector as O; - match (vector, self.pg_distance) { - (B::Vecf32(x), PgDistanceKind::L2) => O::Vecf32(x.own()), - (B::Vecf32(x), PgDistanceKind::Dot) => O::Vecf32(x.own()), - (B::Vecf32(x), PgDistanceKind::Cos) => O::Vecf32(x.function_normalize()), - } - } - pub fn process(self, x: Distance) -> f32 { - match self.pg_distance { - PgDistanceKind::Cos => f32::from(x) + 1.0f32, - PgDistanceKind::L2 => f32::from(x).sqrt(), - _ => f32::from(x), - } - } - pub fn distance_kind(self) -> DistanceKind { - 
self.pg_distance.to_distance() - } -} - -pub unsafe fn opfamily(index: pgrx::pg_sys::Relation) -> Opfamily { - use pgrx::pg_sys::Oid; - - let proc = unsafe { pgrx::pg_sys::index_getprocid(index, 1, 1) }; - - if proc == Oid::INVALID { - pgrx::error!("support function 1 is not found"); - } - - let mut flinfo = pgrx::pg_sys::FmgrInfo::default(); - unsafe { - pgrx::pg_sys::fmgr_info(proc, &mut flinfo); - } - - let fn_addr = flinfo.fn_addr.expect("null function pointer"); - - let mut fcinfo = unsafe { std::mem::zeroed::() }; - fcinfo.flinfo = &mut flinfo; - fcinfo.fncollation = pgrx::pg_sys::DEFAULT_COLLATION_OID; - fcinfo.context = std::ptr::null_mut(); - fcinfo.resultinfo = std::ptr::null_mut(); - fcinfo.isnull = true; - fcinfo.nargs = 0; - - let result_datum = unsafe { pgrx::pg_sys::ffi::pg_guard_ffi_boundary(|| fn_addr(&mut fcinfo)) }; - - let result_option = unsafe { String::from_datum(result_datum, fcinfo.isnull) }; - - let result_string = result_option.expect("null string"); - - let (vector, pg_distance) = convert_name_to_vd(&result_string).unwrap(); - - unsafe { - pgrx::pg_sys::pfree(result_datum.cast_mut_ptr()); - } - - Opfamily { - vector, - pg_distance, - } -} diff --git a/src/vchordrqfscan/index/am_scan.rs b/src/vchordrqfscan/index/am_scan.rs deleted file mode 100644 index da049ab..0000000 --- a/src/vchordrqfscan/index/am_scan.rs +++ /dev/null @@ -1,125 +0,0 @@ -use super::am_options::Opfamily; -use crate::postgres::PostgresRelation; -use crate::vchordrqfscan::algorithm::scan::scan; -use crate::vchordrqfscan::gucs::executing::epsilon; -use crate::vchordrqfscan::gucs::executing::max_scan_tuples; -use crate::vchordrqfscan::gucs::executing::probes; -use crate::vchordrqfscan::types::OwnedVector; -use distance::Distance; -use std::num::NonZeroU64; - -pub enum Scanner { - Initial { - vector: Option<(OwnedVector, Opfamily)>, - threshold: Option, - recheck: bool, - }, - Vbase { - vbase: Box>, - threshold: Option, - recheck: bool, - opfamily: Opfamily, - }, - Empty {}, -} - -pub fn scan_build( - orderbys: Vec>, - spheres: Vec<(Option, Option)>, - opfamily: Opfamily, -) -> (Option<(OwnedVector, Opfamily)>, Option, bool) { - let mut pair = None; - let mut threshold = None; - let mut recheck = false; - for orderby_vector in orderbys { - if pair.is_none() { - pair = orderby_vector; - } else if orderby_vector.is_some() { - pgrx::error!("vector search with multiple vectors is not supported"); - } - } - for (sphere_vector, sphere_threshold) in spheres { - if pair.is_none() { - pair = sphere_vector; - threshold = sphere_threshold; - } else { - recheck = true; - break; - } - } - (pair.map(|x| (x, opfamily)), threshold, recheck) -} - -pub fn scan_make( - vector: Option<(OwnedVector, Opfamily)>, - threshold: Option, - recheck: bool, -) -> Scanner { - Scanner::Initial { - vector, - threshold, - recheck, - } -} - -pub fn scan_next(scanner: &mut Scanner, relation: PostgresRelation) -> Option<(NonZeroU64, bool)> { - if let Scanner::Initial { - vector, - threshold, - recheck, - } = scanner - { - if let Some((vector, opfamily)) = vector.as_ref() { - let vbase = scan( - relation, - match vector { - OwnedVector::Vecf32(x) => x.slice().to_vec(), - }, - opfamily.distance_kind(), - probes(), - epsilon(), - ); - *scanner = Scanner::Vbase { - vbase: if let Some(max_scan_tuples) = max_scan_tuples() { - Box::new(vbase.take(max_scan_tuples as usize)) - } else { - Box::new(vbase) - }, - threshold: *threshold, - recheck: *recheck, - opfamily: *opfamily, - }; - } else { - *scanner = Scanner::Empty {}; - } - } - match 
scanner { - Scanner::Initial { .. } => unreachable!(), - Scanner::Vbase { - vbase, - threshold, - recheck, - opfamily, - } => match ( - vbase.next().map(|(d, p)| (opfamily.process(d), p)), - threshold, - ) { - (Some((_, ptr)), None) => Some((ptr, *recheck)), - (Some((distance, ptr)), Some(t)) if distance < *t => Some((ptr, *recheck)), - _ => { - let scanner = std::mem::replace(scanner, Scanner::Empty {}); - scan_release(scanner); - None - } - }, - Scanner::Empty {} => None, - } -} - -pub fn scan_release(scanner: Scanner) { - match scanner { - Scanner::Initial { .. } => {} - Scanner::Vbase { .. } => {} - Scanner::Empty {} => {} - } -} diff --git a/src/vchordrqfscan/index/functions.rs b/src/vchordrqfscan/index/functions.rs deleted file mode 100644 index 27bd9ac..0000000 --- a/src/vchordrqfscan/index/functions.rs +++ /dev/null @@ -1,26 +0,0 @@ -use crate::postgres::PostgresRelation; -use crate::vchordrqfscan::algorithm::prewarm::prewarm; -use pgrx::pg_sys::Oid; -use pgrx_catalog::{PgAm, PgClass}; - -#[pgrx::pg_extern(sql = "")] -fn _vchordrqfscan_prewarm(indexrelid: Oid, height: i32) -> String { - let pg_am = PgAm::search_amname(c"vchordrqfscan").unwrap(); - let Some(pg_am) = pg_am.get() else { - pgrx::error!("vchord is not installed"); - }; - let pg_class = PgClass::search_reloid(indexrelid).unwrap(); - let Some(pg_class) = pg_class.get() else { - pgrx::error!("there is no such index"); - }; - if pg_class.relam() != pg_am.oid() { - pgrx::error!("{:?} is not a vchordrqfscan index", pg_class.relname()); - } - let index = unsafe { pgrx::pg_sys::index_open(indexrelid, pgrx::pg_sys::ShareLock as _) }; - let relation = unsafe { PostgresRelation::new(index) }; - let message = prewarm(relation, height); - unsafe { - pgrx::pg_sys::index_close(index, pgrx::pg_sys::ShareLock as _); - } - message -} diff --git a/src/vchordrqfscan/index/mod.rs b/src/vchordrqfscan/index/mod.rs deleted file mode 100644 index 5203e4f..0000000 --- a/src/vchordrqfscan/index/mod.rs +++ /dev/null @@ -1,12 +0,0 @@ -pub mod am; -pub mod am_options; -pub mod am_scan; -pub mod functions; -pub mod opclass; -pub mod utils; - -pub unsafe fn init() { - unsafe { - am::init(); - } -} diff --git a/src/vchordrqfscan/index/opclass.rs b/src/vchordrqfscan/index/opclass.rs deleted file mode 100644 index d095b9a..0000000 --- a/src/vchordrqfscan/index/opclass.rs +++ /dev/null @@ -1,14 +0,0 @@ -#[pgrx::pg_extern(immutable, strict, parallel_safe)] -fn _vchordrqfscan_support_vector_l2_ops() -> String { - "vector_l2_ops".to_string() -} - -#[pgrx::pg_extern(immutable, strict, parallel_safe)] -fn _vchordrqfscan_support_vector_ip_ops() -> String { - "vector_ip_ops".to_string() -} - -#[pgrx::pg_extern(immutable, strict, parallel_safe)] -fn _vchordrqfscan_support_vector_cosine_ops() -> String { - "vector_cosine_ops".to_string() -} diff --git a/src/vchordrqfscan/index/utils.rs b/src/vchordrqfscan/index/utils.rs deleted file mode 100644 index 726a597..0000000 --- a/src/vchordrqfscan/index/utils.rs +++ /dev/null @@ -1,20 +0,0 @@ -use std::num::NonZeroU64; - -pub fn pointer_to_ctid(pointer: NonZeroU64) -> pgrx::pg_sys::ItemPointerData { - let value = pointer.get(); - pgrx::pg_sys::ItemPointerData { - ip_blkid: pgrx::pg_sys::BlockIdData { - bi_hi: ((value >> 32) & 0xffff) as u16, - bi_lo: ((value >> 16) & 0xffff) as u16, - }, - ip_posid: (value & 0xffff) as u16, - } -} - -pub fn ctid_to_pointer(ctid: pgrx::pg_sys::ItemPointerData) -> NonZeroU64 { - let mut value = 0; - value |= (ctid.ip_blkid.bi_hi as u64) << 32; - value |= (ctid.ip_blkid.bi_lo as u64) << 16; 
- value |= ctid.ip_posid as u64; - NonZeroU64::new(value).expect("invalid pointer") -} diff --git a/src/vchordrqfscan/mod.rs b/src/vchordrqfscan/mod.rs deleted file mode 100644 index c2ae945..0000000 --- a/src/vchordrqfscan/mod.rs +++ /dev/null @@ -1,11 +0,0 @@ -mod algorithm; -mod gucs; -mod index; -mod types; - -pub unsafe fn init() { - unsafe { - index::init(); - gucs::init(); - } -} diff --git a/src/vchordrqfscan/types.rs b/src/vchordrqfscan/types.rs deleted file mode 100644 index 91ca769..0000000 --- a/src/vchordrqfscan/types.rs +++ /dev/null @@ -1,153 +0,0 @@ -use distance::Distance; -use serde::{Deserialize, Serialize}; -use validator::{Validate, ValidationError, ValidationErrors}; -use vector::vect::{VectBorrowed, VectOwned}; - -#[derive(Debug, Clone, Serialize, Deserialize, Validate)] -#[serde(deny_unknown_fields)] -pub struct VchordrqfscanInternalBuildOptions { - #[serde(default = "VchordrqfscanInternalBuildOptions::default_lists")] - #[validate(length(min = 1, max = 8), custom(function = VchordrqfscanInternalBuildOptions::validate_lists))] - pub lists: Vec, - #[serde(default = "VchordrqfscanInternalBuildOptions::default_spherical_centroids")] - pub spherical_centroids: bool, - #[serde(default = "VchordrqfscanInternalBuildOptions::default_sampling_factor")] - #[validate(range(min = 1, max = 1024))] - pub sampling_factor: u32, - #[serde(default = "VchordrqfscanInternalBuildOptions::default_build_threads")] - #[validate(range(min = 1, max = 255))] - pub build_threads: u16, -} - -impl VchordrqfscanInternalBuildOptions { - fn default_lists() -> Vec { - vec![1000] - } - fn validate_lists(lists: &[u32]) -> Result<(), ValidationError> { - if !lists.is_sorted() { - return Err(ValidationError::new("`lists` should be in ascending order")); - } - if !lists.iter().all(|x| (1..=1 << 24).contains(x)) { - return Err(ValidationError::new("list is too long or too short")); - } - Ok(()) - } - fn default_spherical_centroids() -> bool { - false - } - fn default_sampling_factor() -> u32 { - 256 - } - fn default_build_threads() -> u16 { - 1 - } -} - -impl Default for VchordrqfscanInternalBuildOptions { - fn default() -> Self { - Self { - lists: Self::default_lists(), - spherical_centroids: Self::default_spherical_centroids(), - sampling_factor: Self::default_sampling_factor(), - build_threads: Self::default_build_threads(), - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize, Validate)] -#[serde(deny_unknown_fields)] -pub struct VchordrqfscanExternalBuildOptions { - pub table: String, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(deny_unknown_fields)] -#[serde(rename_all = "snake_case")] -pub enum VchordrqfscanBuildOptions { - Internal(VchordrqfscanInternalBuildOptions), - External(VchordrqfscanExternalBuildOptions), -} - -impl Default for VchordrqfscanBuildOptions { - fn default() -> Self { - Self::Internal(Default::default()) - } -} - -impl Validate for VchordrqfscanBuildOptions { - fn validate(&self) -> Result<(), ValidationErrors> { - use VchordrqfscanBuildOptions::*; - match self { - Internal(internal_build) => internal_build.validate(), - External(external_build) => external_build.validate(), - } - } -} - -#[derive(Debug, Clone, Default, Serialize, Deserialize, Validate)] -#[serde(deny_unknown_fields)] -pub struct VchordrqfscanIndexingOptions { - #[serde(default = "VchordrqfscanIndexingOptions::default_residual_quantization")] - pub residual_quantization: bool, - pub build: VchordrqfscanBuildOptions, -} - -impl VchordrqfscanIndexingOptions { - fn 
default_residual_quantization() -> bool {
-        false
-    }
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub enum OwnedVector {
-    Vecf32(VectOwned<f32>),
-}
-
-#[derive(Debug, Clone, Copy)]
-pub enum BorrowedVector<'a> {
-    Vecf32(VectBorrowed<'a, f32>),
-}
-
-#[repr(u8)]
-#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
-pub enum DistanceKind {
-    L2,
-    Dot,
-}
-
-#[repr(u8)]
-#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
-pub enum VectorKind {
-    Vecf32,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
-#[serde(deny_unknown_fields)]
-#[validate(schema(function = "Self::validate_self"))]
-pub struct VectorOptions {
-    #[validate(range(min = 1, max = 1_048_575))]
-    #[serde(rename = "dimensions")]
-    pub dims: u32,
-    #[serde(rename = "vector")]
-    pub v: VectorKind,
-    #[serde(rename = "distance")]
-    pub d: DistanceKind,
-}
-
-impl VectorOptions {
-    pub fn validate_self(&self) -> Result<(), ValidationError> {
-        match (self.v, self.d, self.dims) {
-            (VectorKind::Vecf32, DistanceKind::L2, 1..65536) => Ok(()),
-            (VectorKind::Vecf32, DistanceKind::Dot, 1..65536) => Ok(()),
-            _ => Err(ValidationError::new("not valid vector options")),
-        }
-    }
-}
-
-pub fn distance(d: DistanceKind, lhs: &[f32], rhs: &[f32]) -> Distance {
-    use simd::Floating;
-    match d {
-        DistanceKind::L2 => Distance::from_f32(f32::reduce_sum_of_d2(lhs, rhs)),
-        DistanceKind::Dot => Distance::from_f32(-f32::reduce_sum_of_xy(lhs, rhs)),
-    }
-}
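
For reference, a minimal standalone sketch of the distance conventions that the deleted `distance` function relies on: `DistanceKind::L2` keys are squared Euclidean distances and `DistanceKind::Dot` keys are negated inner products, so a smaller key always means a closer vector; the square root for L2 (and the `+ 1.0` shift for cosine) is applied only when the value is surfaced to SQL by `Opfamily::process`. The scalar helpers below are illustrative stand-ins for the `simd` crate reductions `f32::reduce_sum_of_d2` and `f32::reduce_sum_of_xy`, not part of this patch.

// Standalone sketch: scalar stand-ins for the SIMD reductions used by `distance`.
// The helper names are assumptions for illustration only.
fn sum_of_d2(lhs: &[f32], rhs: &[f32]) -> f32 {
    // Sum of squared differences, i.e. squared Euclidean distance.
    assert_eq!(lhs.len(), rhs.len());
    lhs.iter().zip(rhs).map(|(x, y)| (x - y) * (x - y)).sum()
}

fn sum_of_xy(lhs: &[f32], rhs: &[f32]) -> f32 {
    // Plain inner product.
    assert_eq!(lhs.len(), rhs.len());
    lhs.iter().zip(rhs).map(|(x, y)| x * y).sum()
}

fn main() {
    let a = [1.0f32, 0.0, 2.0];
    let b = [0.0f32, 1.0, 2.0];
    // L2 keys stay squared inside the index; the square root is taken only
    // when the distance is returned to the executor.
    println!("l2 key  = {}", sum_of_d2(&a, &b)); // 2
    // Dot keys are negated so that a larger inner product sorts first.
    println!("dot key = {}", -sum_of_xy(&a, &b)); // -4
}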