From 75cf0c0d3a000eaa9ed9f30a710ce47a3e34bda3 Mon Sep 17 00:00:00 2001 From: Najib <32004868+nishaq503@users.noreply.github.com> Date: Sun, 19 Jan 2025 14:54:05 -0500 Subject: [PATCH] Added MSA and separated Metric from Dataset (#226) * feat: added msa and separated metric form dataset * docs: updated main README * fmt: python formatting * fix: corrected overlapp_with method for SquishyBall * feat: added min and max methods for Number trait * wip: disconnecting msa and pancakes --- .gitignore | 1 + Cargo.toml | 43 +- Earthfile | 8 +- README.md | 24 +- benches/cakes/Cargo.toml | 16 + benches/cakes/src/README.md | 40 + benches/cakes/src/data_gen.rs | 439 ++++++++++ benches/cakes/src/main.rs | 310 ++++++++ benches/cakes/src/metric/cosine.rs | 147 ++++ benches/cakes/src/metric/dtw.rs | 90 +++ benches/cakes/src/metric/euclidean.rs | 147 ++++ benches/cakes/src/metric/hamming.rs | 82 ++ benches/cakes/src/metric/levenshtein.rs | 83 ++ benches/cakes/src/metric/mod.rs | 147 ++++ benches/cakes/src/search.rs | 221 ++++++ benches/cakes/src/trees.rs | 157 ++++ benches/cakes/src/workflow.rs | 284 +++++++ benches/py-cakes/.python-version | 1 + benches/py-cakes/README.md | 29 + benches/py-cakes/pyproject.toml | 34 + benches/py-cakes/src/py_cakes/__init__.py | 12 + benches/py-cakes/src/py_cakes/__main__.py | 104 +++ .../src/py_cakes/competitors/annoy.csv | 32 + .../src/py_cakes/competitors/faiss-ivf.csv | 34 + .../src/py_cakes/competitors/hnsw.csv | 26 + .../py-cakes/src/py_cakes/summarize_rust.py | 657 +++++++++++++++ benches/py-cakes/src/py_cakes/utils.py | 10 + benches/utils/Cargo.toml | 19 + benches/utils/src/ann_benchmarks/mod.rs | 57 ++ benches/utils/src/ann_benchmarks/reader.rs | 155 ++++ benches/utils/src/fasta/mod.rs | 115 +++ benches/utils/src/lib.rs | 222 ++++++ benches/utils/src/metrics/dtw.rs | 121 +++ benches/utils/src/metrics/jaccard.rs | 38 + benches/utils/src/metrics/levenshtein.rs | 39 + benches/utils/src/metrics/mod.rs | 9 + benches/utils/src/radio_ml/mod.rs | 7 + benches/utils/src/radio_ml/modulation_mode.rs | 101 +++ benches/utils/src/radio_ml/reader.rs | 112 +++ benches/utils/src/reports/cakes.rs | 299 +++++++ benches/utils/src/reports/mod.rs | 5 + benches/utils/src/types.rs | 64 ++ crates/abd-clam/.bumpversion.cfg | 2 +- crates/abd-clam/Cargo.toml | 50 +- crates/abd-clam/README.md | 229 +++--- crates/abd-clam/VERSION | 2 +- crates/abd-clam/benches/ann_benchmarks.rs | 256 ++++-- crates/abd-clam/benches/genomic_search.rs | 106 +-- .../benches/utils/compare_permuted.rs | 386 +++++---- crates/abd-clam/benches/utils/mod.rs | 61 ++ crates/abd-clam/benches/vector_search.rs | 144 ++-- crates/abd-clam/src/cakes/cluster/mod.rs | 20 +- .../abd-clam/src/cakes/cluster/offset_ball.rs | 329 -------- .../src/cakes/cluster/permuted_ball.rs | 333 ++++++++ .../abd-clam/src/cakes/cluster/searchable.rs | 49 -- crates/abd-clam/src/cakes/codec/codec_data.rs | 525 ------------ crates/abd-clam/src/cakes/codec/mod.rs | 265 ------ .../abd-clam/src/cakes/codec/squishy_ball.rs | 490 ------------ crates/abd-clam/src/cakes/dataset/hinted.rs | 127 +++ crates/abd-clam/src/cakes/dataset/mod.rs | 67 +- .../abd-clam/src/cakes/dataset/searchable.rs | 54 +- .../abd-clam/src/cakes/dataset/shardable.rs | 103 --- crates/abd-clam/src/cakes/mod.rs | 272 +------ .../src/cakes/search/knn_breadth_first.rs | 263 +++--- .../src/cakes/search/knn_depth_first.rs | 298 +++---- .../abd-clam/src/cakes/search/knn_hinted.rs | 114 +++ .../abd-clam/src/cakes/search/knn_linear.rs | 68 ++ .../src/cakes/search/knn_repeated_rnn.rs | 268 +++---- 
crates/abd-clam/src/cakes/search/mod.rs | 308 +++---- .../src/cakes/search/rnn_clustered.rs | 322 ++++---- .../abd-clam/src/cakes/search/rnn_linear.rs | 48 ++ crates/abd-clam/src/chaoda/cluster/vertex.rs | 261 +++--- .../src/chaoda/graph/adjacency_list.rs | 75 +- crates/abd-clam/src/chaoda/graph/component.rs | 28 +- crates/abd-clam/src/chaoda/graph/mod.rs | 133 +++- crates/abd-clam/src/chaoda/graph/node.rs | 24 +- .../src/chaoda/inference/combination.rs | 81 +- .../abd-clam/src/chaoda/inference/meta_ml.rs | 3 +- crates/abd-clam/src/chaoda/inference/mod.rs | 102 ++- .../src/chaoda/inference/trained_smc.rs | 369 +++++++++ crates/abd-clam/src/chaoda/mod.rs | 22 +- .../src/chaoda/training/algorithms/cc.rs | 10 +- .../src/chaoda/training/algorithms/gn.rs | 10 +- .../src/chaoda/training/algorithms/mod.rs | 44 +- .../src/chaoda/training/algorithms/pc.rs | 13 +- .../src/chaoda/training/algorithms/sc.rs | 13 +- .../src/chaoda/training/algorithms/sp.rs | 10 +- .../src/chaoda/training/algorithms/vd.rs | 10 +- .../src/chaoda/training/combination.rs | 26 +- crates/abd-clam/src/chaoda/training/mod.rs | 161 ++-- .../src/chaoda/training/trainable_smc.rs | 611 ++++++++++++++ crates/abd-clam/src/core/cluster/adapter.rs | 309 +++---- .../src/core/cluster/balanced_ball.rs | 493 ++++++------ crates/abd-clam/src/core/cluster/ball.rs | 535 ++++--------- crates/abd-clam/src/core/cluster/csv.rs | 146 ---- crates/abd-clam/src/core/cluster/io.rs | 147 ++++ crates/abd-clam/src/core/cluster/lfd.rs | 8 +- crates/abd-clam/src/core/cluster/mod.rs | 261 +++--- crates/abd-clam/src/core/cluster/partition.rs | 494 ++++++++---- .../src/core/dataset/associates_metadata.rs | 53 ++ crates/abd-clam/src/core/dataset/flat_vec.rs | 751 ++++++++---------- crates/abd-clam/src/core/dataset/io.rs | 55 ++ crates/abd-clam/src/core/dataset/metric.rs | 224 ------ .../abd-clam/src/core/dataset/metric_space.rs | 346 -------- crates/abd-clam/src/core/dataset/mod.rs | 503 ++++++------ .../abd-clam/src/core/dataset/permutable.rs | 8 +- .../abd-clam/src/core/dataset/sized_heap.rs | 154 ++++ .../src/core/metric/absolute_difference.rs | 41 + crates/abd-clam/src/core/metric/cosine.rs | 40 + crates/abd-clam/src/core/metric/euclidean.rs | 40 + crates/abd-clam/src/core/metric/hypotenuse.rs | 43 + .../abd-clam/src/core/metric/levenshtein.rs | 40 + crates/abd-clam/src/core/metric/manhattan.rs | 43 + crates/abd-clam/src/core/metric/mod.rs | 227 ++++++ crates/abd-clam/src/core/mod.rs | 8 +- crates/abd-clam/src/core/tree/mod.rs | 259 ++++++ crates/abd-clam/src/lib.rs | 15 +- crates/abd-clam/src/mbed/physics/mass.rs | 11 +- crates/abd-clam/src/mbed/physics/spring.rs | 8 +- crates/abd-clam/src/mbed/physics/system.rs | 48 +- .../abd-clam/src/msa/aligner/cost_matrix.rs | 357 +++++++++ crates/abd-clam/src/msa/aligner/mod.rs | 252 ++++++ crates/abd-clam/src/msa/aligner/ops.rs | 63 ++ crates/abd-clam/src/msa/dataset/columnar.rs | 425 ++++++++++ crates/abd-clam/src/msa/dataset/mod.rs | 7 + crates/abd-clam/src/msa/dataset/msa.rs | 629 +++++++++++++++ crates/abd-clam/src/msa/mod.rs | 12 + crates/abd-clam/src/msa/sequence.rs | 116 +++ crates/abd-clam/src/pancakes/cluster/mod.rs | 5 + .../src/pancakes/cluster/squishy_ball.rs | 451 +++++++++++ .../src/pancakes/dataset/codec_data.rs | 531 +++++++++++++ .../codec => pancakes/dataset}/compression.rs | 47 +- .../dataset}/decompression.rs | 40 +- crates/abd-clam/src/pancakes/dataset/mod.rs | 112 +++ crates/abd-clam/src/pancakes/mod.rs | 8 + crates/abd-clam/src/pancakes/sequence.rs | 178 +++++ crates/abd-clam/src/utils.rs | 
306 ++----- crates/abd-clam/tests/ball.rs | 229 ++++++ crates/abd-clam/tests/cakes.rs | 261 ++++++ crates/abd-clam/tests/chaoda.rs | 19 + crates/abd-clam/tests/common/cluster.rs | 47 ++ crates/abd-clam/tests/common/data_gen.rs | 43 + crates/abd-clam/tests/common/mod.rs | 7 + crates/abd-clam/tests/common/search.rs | 104 +++ crates/abd-clam/tests/common/sequence.rs | 1 + crates/abd-clam/tests/flat_vec.rs | 159 ++++ crates/abd-clam/tests/needleman_wunsch.rs | 104 +++ crates/abd-clam/tests/pancakes.rs | 142 ++++ crates/abd-clam/tests/tree.rs | 118 +++ crates/abd-clam/tests/utils.rs | 180 +++++ crates/distances/Cargo.toml | 1 + crates/distances/src/number/_bool.rs | 4 + crates/distances/src/number/_number.rs | 51 ++ crates/distances/src/number/_variants.rs | 26 +- crates/distances/src/strings/mod.rs | 13 +- .../src/strings/needleman_wunsch/helpers.rs | 12 +- .../src/strings/needleman_wunsch/mod.rs | 4 +- crates/distances/src/vectors/lp_norms.rs | 6 +- crates/distances/tests/test_edits.rs | 6 +- crates/distances/tests/test_sets.rs | 12 +- crates/distances/tests/test_vectors_f32.rs | 14 +- crates/distances/tests/test_vectors_u32.rs | 10 +- crates/results/cakes/Cargo.toml | 8 +- crates/results/cakes/src/data/mod.rs | 6 +- crates/results/cakes/src/data/raw/fasta.rs | 64 +- crates/results/cakes/src/data/raw/mod.rs | 137 ++-- crates/results/cakes/src/data/tree/aligned.rs | 224 +++--- crates/results/cakes/src/data/tree/ann_set.rs | 250 +++--- .../data/tree/instances/aligned_sequence.rs | 127 ++- .../src/data/tree/instances/member_set.rs | 135 ++-- .../cakes/src/data/tree/instances/mod.rs | 4 +- .../data/tree/instances/unaligned_sequence.rs | 69 +- crates/results/cakes/src/data/tree/mod.rs | 65 +- .../results/cakes/src/data/tree/unaligned.rs | 221 +++--- crates/results/cakes/src/lib.rs | 19 + crates/results/cakes/src/main.rs | 7 + crates/results/cakes/src/utils.rs | 6 + crates/results/chaoda/Cargo.toml | 4 +- crates/results/chaoda/src/data/mod.rs | 21 +- crates/results/chaoda/src/main.rs | 106 +-- crates/results/msa/Cargo.toml | 15 + crates/results/msa/README.md | 32 + crates/results/msa/src/data/mod.rs | 29 + crates/results/msa/src/data/raw.rs | 164 ++++ crates/results/msa/src/main.rs | 210 +++++ crates/results/msa/src/steps.rs | 179 +++++ crates/results/rite-solutions/Cargo.toml | 2 +- crates/results/rite-solutions/src/data/mod.rs | 18 +- .../src/data/neighborhood_aware.rs | 82 +- .../rite-solutions/src/data/vec_metric.rs | 17 +- crates/results/rite-solutions/src/main.rs | 23 +- pypi/distances/.bumpversion.cfg | 2 +- pypi/distances/Cargo.toml | 2 +- pypi/distances/README.md | 2 +- pypi/distances/VERSION | 2 +- pypi/distances/pyproject.toml | 4 +- .../python/abd_distances/__init__.py | 2 +- pypi/distances/src/lib.rs | 8 +- pypi/distances/src/simd.rs | 21 +- pypi/distances/src/strings.rs | 13 +- pypi/distances/src/utils.rs | 34 +- pypi/distances/src/vectors.rs | 39 +- pypi/results/cakes/.github/workflows/CI.yml | 120 --- pypi/results/cakes/.gitignore | 72 -- pypi/results/cakes/Cargo.toml | 12 - pypi/results/cakes/pyproject.toml | 32 - .../results/cakes/python/py_cakes/__init__.py | 6 - .../results/cakes/python/py_cakes/__main__.py | 52 -- pypi/results/cakes/python/py_cakes/tables.py | 96 --- .../cakes/python/py_cakes/wrangling_logs.py | 72 -- pypi/results/cakes/python/tests/__init__.py | 1 - pypi/results/cakes/python/tests/test_all.py | 5 - pypi/results/cakes/src/lib.rs | 14 - pyproject.toml | 4 +- requirements-dev.lock | 70 +- requirements.lock | 44 +- ruff.toml | 4 +- 217 files changed, 17384 
insertions(+), 8121 deletions(-) create mode 100644 benches/cakes/Cargo.toml create mode 100644 benches/cakes/src/README.md create mode 100644 benches/cakes/src/data_gen.rs create mode 100644 benches/cakes/src/main.rs create mode 100644 benches/cakes/src/metric/cosine.rs create mode 100644 benches/cakes/src/metric/dtw.rs create mode 100644 benches/cakes/src/metric/euclidean.rs create mode 100644 benches/cakes/src/metric/hamming.rs create mode 100644 benches/cakes/src/metric/levenshtein.rs create mode 100644 benches/cakes/src/metric/mod.rs create mode 100644 benches/cakes/src/search.rs create mode 100644 benches/cakes/src/trees.rs create mode 100644 benches/cakes/src/workflow.rs create mode 100644 benches/py-cakes/.python-version create mode 100644 benches/py-cakes/README.md create mode 100644 benches/py-cakes/pyproject.toml create mode 100644 benches/py-cakes/src/py_cakes/__init__.py create mode 100644 benches/py-cakes/src/py_cakes/__main__.py create mode 100644 benches/py-cakes/src/py_cakes/competitors/annoy.csv create mode 100644 benches/py-cakes/src/py_cakes/competitors/faiss-ivf.csv create mode 100644 benches/py-cakes/src/py_cakes/competitors/hnsw.csv create mode 100644 benches/py-cakes/src/py_cakes/summarize_rust.py create mode 100644 benches/py-cakes/src/py_cakes/utils.py create mode 100644 benches/utils/Cargo.toml create mode 100644 benches/utils/src/ann_benchmarks/mod.rs create mode 100644 benches/utils/src/ann_benchmarks/reader.rs create mode 100644 benches/utils/src/fasta/mod.rs create mode 100644 benches/utils/src/lib.rs create mode 100644 benches/utils/src/metrics/dtw.rs create mode 100644 benches/utils/src/metrics/jaccard.rs create mode 100644 benches/utils/src/metrics/levenshtein.rs create mode 100644 benches/utils/src/metrics/mod.rs create mode 100644 benches/utils/src/radio_ml/mod.rs create mode 100644 benches/utils/src/radio_ml/modulation_mode.rs create mode 100644 benches/utils/src/radio_ml/reader.rs create mode 100644 benches/utils/src/reports/cakes.rs create mode 100644 benches/utils/src/reports/mod.rs create mode 100644 benches/utils/src/types.rs delete mode 100644 crates/abd-clam/src/cakes/cluster/offset_ball.rs create mode 100644 crates/abd-clam/src/cakes/cluster/permuted_ball.rs delete mode 100644 crates/abd-clam/src/cakes/cluster/searchable.rs delete mode 100644 crates/abd-clam/src/cakes/codec/codec_data.rs delete mode 100644 crates/abd-clam/src/cakes/codec/mod.rs delete mode 100644 crates/abd-clam/src/cakes/codec/squishy_ball.rs create mode 100644 crates/abd-clam/src/cakes/dataset/hinted.rs delete mode 100644 crates/abd-clam/src/cakes/dataset/shardable.rs create mode 100644 crates/abd-clam/src/cakes/search/knn_hinted.rs create mode 100644 crates/abd-clam/src/cakes/search/knn_linear.rs create mode 100644 crates/abd-clam/src/cakes/search/rnn_linear.rs create mode 100644 crates/abd-clam/src/chaoda/inference/trained_smc.rs create mode 100644 crates/abd-clam/src/chaoda/training/trainable_smc.rs delete mode 100644 crates/abd-clam/src/core/cluster/csv.rs create mode 100644 crates/abd-clam/src/core/cluster/io.rs create mode 100644 crates/abd-clam/src/core/dataset/associates_metadata.rs create mode 100644 crates/abd-clam/src/core/dataset/io.rs delete mode 100644 crates/abd-clam/src/core/dataset/metric.rs delete mode 100644 crates/abd-clam/src/core/dataset/metric_space.rs create mode 100644 crates/abd-clam/src/core/dataset/sized_heap.rs create mode 100644 crates/abd-clam/src/core/metric/absolute_difference.rs create mode 100644 crates/abd-clam/src/core/metric/cosine.rs 
create mode 100644 crates/abd-clam/src/core/metric/euclidean.rs create mode 100644 crates/abd-clam/src/core/metric/hypotenuse.rs create mode 100644 crates/abd-clam/src/core/metric/levenshtein.rs create mode 100644 crates/abd-clam/src/core/metric/manhattan.rs create mode 100644 crates/abd-clam/src/core/metric/mod.rs create mode 100644 crates/abd-clam/src/core/tree/mod.rs create mode 100644 crates/abd-clam/src/msa/aligner/cost_matrix.rs create mode 100644 crates/abd-clam/src/msa/aligner/mod.rs create mode 100644 crates/abd-clam/src/msa/aligner/ops.rs create mode 100644 crates/abd-clam/src/msa/dataset/columnar.rs create mode 100644 crates/abd-clam/src/msa/dataset/mod.rs create mode 100644 crates/abd-clam/src/msa/dataset/msa.rs create mode 100644 crates/abd-clam/src/msa/mod.rs create mode 100644 crates/abd-clam/src/msa/sequence.rs create mode 100644 crates/abd-clam/src/pancakes/cluster/mod.rs create mode 100644 crates/abd-clam/src/pancakes/cluster/squishy_ball.rs create mode 100644 crates/abd-clam/src/pancakes/dataset/codec_data.rs rename crates/abd-clam/src/{cakes/codec => pancakes/dataset}/compression.rs (61%) rename crates/abd-clam/src/{cakes/codec => pancakes/dataset}/decompression.rs (51%) create mode 100644 crates/abd-clam/src/pancakes/dataset/mod.rs create mode 100644 crates/abd-clam/src/pancakes/mod.rs create mode 100644 crates/abd-clam/src/pancakes/sequence.rs create mode 100644 crates/abd-clam/tests/ball.rs create mode 100644 crates/abd-clam/tests/cakes.rs create mode 100644 crates/abd-clam/tests/chaoda.rs create mode 100644 crates/abd-clam/tests/common/cluster.rs create mode 100644 crates/abd-clam/tests/common/data_gen.rs create mode 100644 crates/abd-clam/tests/common/mod.rs create mode 100644 crates/abd-clam/tests/common/search.rs create mode 100644 crates/abd-clam/tests/common/sequence.rs create mode 100644 crates/abd-clam/tests/flat_vec.rs create mode 100644 crates/abd-clam/tests/needleman_wunsch.rs create mode 100644 crates/abd-clam/tests/pancakes.rs create mode 100644 crates/abd-clam/tests/tree.rs create mode 100644 crates/abd-clam/tests/utils.rs create mode 100644 crates/results/cakes/src/lib.rs create mode 100644 crates/results/msa/Cargo.toml create mode 100644 crates/results/msa/README.md create mode 100644 crates/results/msa/src/data/mod.rs create mode 100644 crates/results/msa/src/data/raw.rs create mode 100644 crates/results/msa/src/main.rs create mode 100644 crates/results/msa/src/steps.rs delete mode 100644 pypi/results/cakes/.github/workflows/CI.yml delete mode 100644 pypi/results/cakes/.gitignore delete mode 100644 pypi/results/cakes/Cargo.toml delete mode 100644 pypi/results/cakes/pyproject.toml delete mode 100644 pypi/results/cakes/python/py_cakes/__init__.py delete mode 100644 pypi/results/cakes/python/py_cakes/__main__.py delete mode 100644 pypi/results/cakes/python/py_cakes/tables.py delete mode 100644 pypi/results/cakes/python/py_cakes/wrangling_logs.py delete mode 100644 pypi/results/cakes/python/tests/__init__.py delete mode 100644 pypi/results/cakes/python/tests/test_all.py delete mode 100644 pypi/results/cakes/src/lib.rs diff --git a/.gitignore b/.gitignore index 501a0b246..b5b845f67 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ logs .tmp-earthly-out .vscode/settings.json .ruff_cache +*.svg ################################################################################ # Rust. 
Generated by Cargo # diff --git a/Cargo.toml b/Cargo.toml index 519673efb..28ea05c94 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,34 +3,48 @@ members = [ "crates/abd-clam", "crates/distances", "crates/symagen", - "crates/results/chaoda", "crates/results/cakes", + "crates/results/chaoda", "crates/results/rite-solutions", + "crates/results/msa", "pypi/distances", - "pypi/results/cakes", + "benches/utils", + "benches/cakes", ] resolver = "2" [workspace.dependencies] -abd-clam = { version = "0.31.0", path = "crates/abd-clam" } +abd-clam = { version = "0.32.0", path = "crates/abd-clam" } distances = { version = "1.8.0", path = "crates/distances" } symagen = { version = "0.5.0", path = "crates/symagen" } rayon = "1.8" rand = "0.8" serde = { version = "1.0", features = ["derive"] } -bincode = "1.3" -ftlog = "0.2.0" +# bitcode = { version = "0.5" } +bitcode = { git = "https://github.com/nishaq503/bitcode.git", rev = "1c393ad97288555fc3fe41b292b2bd826486a992" } libm = "0.2" -ndarray = { version = "0.15.6", features = ["rayon", "approx"] } -ndarray-npy = "0.8.0" -ordered-float = "4.2" -flate2 = { version = "1.0", features = ["zlib"] } +ndarray = { version = "0.16", features = ["rayon", "approx"] } +ndarray-npy = "0.9" +csv = { version = "1.3.0" } +flate2 = { version = "1.0" } +# For GCD and LCM calculations. +num-integer = "0.1" +# For reading fasta files. +bio = "2.0" +# For a faster implementation of Levenshtein distance. +stringzilla = "3.10" +# For CLI tools +clap = { version = "4.5", features = ["derive"] } +# For low-latency logging from multiple threads. +ftlog = { version = "0.2" } +# For reading and writing HDF5 files. +hdf5 = { package = "hdf5-metno", version = "0.9.0" } -# Python wrapper dependencies -numpy = "0.20.0" -pyo3 = { version = "0.20", features = ["extension-module", "abi3-py39"] } -pyo3-ffi = { version = "0.20", features = ["extension-module", "abi3-py39"] } +# For Python Wrappers +numpy = "0.23" +pyo3 = { version = "0.23", features = ["extension-module", "abi3-py39"] } +pyo3-ffi = { version = "0.23", features = ["extension-module", "abi3-py39"] } [profile.test] opt-level = 3 @@ -38,10 +52,13 @@ debug = true overflow-checks = true [profile.release] +# debug = true +opt-level = 3 strip = true lto = true codegen-units = 1 [profile.bench] +opt-level = 3 debug = true overflow-checks = true diff --git a/Earthfile b/Earthfile index 62913e215..287f7ee80 100644 --- a/Earthfile +++ b/Earthfile @@ -31,7 +31,7 @@ ENV PATH="${RYE_HOME}/shims:${PATH}" # This target prepares the recipe.json file for the build stage. chef-prepare: - COPY --dir crates pypi . + COPY --dir benches crates pypi . COPY Cargo.toml . RUN cargo chef prepare SAVE ARTIFACT recipe.json @@ -42,6 +42,7 @@ chef-cook: RUN cargo chef cook --release COPY Cargo.toml pyproject.toml requirements.lock requirements-dev.lock ruff.toml rustfmt.toml . # TODO: Replace with recursive globbing, blocked on https://github.com/earthly/earthly/issues/1230 + COPY --dir benches . COPY --dir crates . COPY --dir pypi . RUN rye sync --no-lock @@ -67,17 +68,18 @@ lint: # Apply any automated fixes. fix: FROM +chef-cook - RUN cargo fmt --all + RUN cargo fmt --all --all-features RUN rye fmt --all RUN cargo clippy --fix --allow-no-vcs RUN rye lint --fix + SAVE ARTIFACT benches AS LOCAL ./ SAVE ARTIFACT crates AS LOCAL ./ SAVE ARTIFACT pypi AS LOCAL ./ # This target runs the tests. 
test:
    FROM +chef-cook
-    RUN cargo test --release --lib --bins --examples --tests --all-features
+    RUN cargo test -r -p abd-clam --all-features -p distances -p symagen
    # TODO: switch to --all, blocked on https://github.com/astral-sh/rye/issues/853
    RUN rye test --package abd-distances
diff --git a/README.md b/README.md
index 62ef78ba7..6a06269ea 100644
--- a/README.md
+++ b/README.md
@@ -5,7 +5,7 @@ The Rust implementation of CLAM.
As of writing this document, the project is still in a pre-1.0 state.
This means that the API is not yet stable and breaking changes may occur frequently.
-## Components
+## Rust Crates and Python Packages
This repository is a workspace that contains the following crates:
@@ -16,14 +16,28 @@ and the following Python packages:
- `abd-distances`: A Python wrapper for the `distances` crate, providing drop-in replacements for the distance functions in `scipy.spatial.distance`. See [here](python/distances/README.md) for more information.
-## License
-
-- MIT
+## Reproducing Results from Papers
+
+This repository contains CLI tools to reproduce results from some of our papers.
+
+### CAKES
+
+This paper is currently under review at SIMODS.
+See [here](benches/cakes/README.md) for running Rust code to reproduce the results for the CAKES algorithms, and [here](benches/py-cakes/README.md) for running some Python code to generate plots from the results of running the Rust code.
+
+### MSA
+
+TODO
+
+### PANCAKES
+
+TODO
## Publications
-- [CHESS](https://arxiv.org/abs/1908.08551)
-- [CHAODA](https://arxiv.org/abs/2103.11774)
+- [CHESS](https://arxiv.org/abs/1908.08551): Hierarchical Clustering and Ranged Nearest Neighbors Search
+- [CHAODA](https://arxiv.org/abs/2103.11774): Anomaly Detection
+- [PANCAKES](https://arxiv.org/pdf/2409.12161): Compression and Compressive Search
## Citation
diff --git a/benches/cakes/Cargo.toml b/benches/cakes/Cargo.toml
new file mode 100644
index 000000000..a5ea32f1c
--- /dev/null
+++ b/benches/cakes/Cargo.toml
@@ -0,0 +1,16 @@
+[package]
+name = "bench-cakes"
+version = "0.1.0"
+edition = "2021"
+
+[dependencies]
+clap = { version = "4.5.16", features = ["derive"] }
+bench-utils = { path = "../utils" }
+ftlog = { workspace = true }
+bitcode = { workspace = true }
+abd-clam = { workspace = true, features = ["disk-io"] }
+distances = { workspace = true }
+rand = { workspace = true }
+rayon = { workspace = true }
+stringzilla = "3.9.5"
+augurs-dtw = { version = "0.8.1", features = ["parallel"] }
diff --git a/benches/cakes/src/README.md b/benches/cakes/src/README.md
new file mode 100644
index 000000000..855b7765d
--- /dev/null
+++ b/benches/cakes/src/README.md
@@ -0,0 +1,40 @@
+# Benchmarks for CAKES Search Algorithms
+
+This crate provides a CLI to run benchmarks for the CAKES search algorithms and reproduce the results from our paper.
+
+## Reproducing the Results
+
+Let's say you have data from the [ANN-Benchmarks suite](https://github.com/erikbern/ann-benchmarks?tab=readme-ov-file#data-sets) in a directory `../data/input` and you want to run the benchmarks for the CAKES search algorithms on the `sift` dataset.
+You can run the following command:
+
+```bash
+cargo run --release --package bench-cakes -- \
+  --inp-dir ../data/input/ \
+  --dataset sift \
+  --out-dir ../data/output/ \
+  --seed 42 \
+  --num-queries 10000 \
+  --max-power 7 \
+  --max-time 300 \
+  --balanced-trees \
+  --permuted-data
+```
+
+This will run the CAKES search algorithms on the `sift` dataset with 10000 search queries.
+The results will be saved in the directory `../data/output/`.
+The dataset will be augmented by powers of 2 up to 2^7.
+Each algorithm will be run for at least 300 seconds.
+The `--balanced-trees` flag will build trees with balanced partitions.
+The `--permuted-data` flag will permute the dataset into depth-first order after building the tree.
+
+There are several other available options.
+Running the following command will provide documentation on how to use the CLI:
+
+```bash
+cargo run --release --package bench-cakes -- --help
+```
+
+## Plotting the Results
+
+The outputs from the benchmarks can be plotted using the Python package we provide at `../py-cakes`.
+See the associated README for more information.
diff --git a/benches/cakes/src/data_gen.rs b/benches/cakes/src/data_gen.rs
new file mode 100644
index 000000000..0601e7756
--- /dev/null
+++ b/benches/cakes/src/data_gen.rs
@@ -0,0 +1,439 @@
+//! Utilities for generating and reading datasets and ground-truth search results.
+
+use abd_clam::{
+    dataset::{DatasetIO, ParDataset},
+    metric::ParMetric,
+    Dataset, FlatVec,
+};
+use bench_utils::{ann_benchmarks::AnnDataset, Complex};
+use distances::Number;
+use rand::prelude::*;
+use rayon::prelude::*;
+
+/// Reads the radio-ml dataset and subsamples it to a given maximum power of 2,
+/// saving each subsample to a file.
+#[allow(clippy::type_complexity)]
+pub fn read_radio_ml_and_subsample<P: AsRef<std::path::Path> + Send + Sync>(
+    inp_dir: &P,
+    out_dir: &P,
+    n_queries: usize,
+    max_power: u32,
+    seed: Option<u64>,
+    snr: Option<i32>,
+) -> Result<(Vec<Vec<Complex<f64>>>, Vec<std::path::PathBuf>), String> {
+    let name = "radio-ml"; // hard-coded until we add more datasets
+    let data_path = out_dir.as_ref().join(format!("{name}-0.flat_vec"));
+    let queries_path = out_dir.as_ref().join(format!("{name}-queries.bin"));
+    let mut all_paths = Vec::with_capacity(max_power as usize + 2);
+    for power in 1..=max_power {
+        all_paths.push(out_dir.as_ref().join(format!("{name}-{power}.flat_vec")));
+    }
+    all_paths.push(data_path.clone());
+    all_paths.push(queries_path.clone());
+
+    let queries = if all_paths.iter().all(|p| p.exists()) {
+        ftlog::info!("Subsampled datasets already exist. Reading queries from {queries_path:?}...");
+        let bytes = std::fs::read(queries_path).map_err(|e| e.to_string())?;
+        bitcode::decode(&bytes).map_err(|e| e.to_string())?
+    } else {
+        ftlog::info!("Reading radio-ml dataset from {:?}...", inp_dir.as_ref());
+        let mut rng = rand::rngs::StdRng::seed_from_u64(seed.unwrap_or(42));
+        let modulation_modes = bench_utils::radio_ml::ModulationMode::all();
+        let (signals, queries) = {
+            let signals = modulation_modes
+                .par_iter()
+                .map(|mode| bench_utils::radio_ml::read_mod(inp_dir, mode, snr))
+                .collect::<Result<Vec<_>, _>>()?;
+            let mut signals = signals.into_iter().flatten().collect::<Vec<_>>();
+            signals.shuffle(&mut rng);
+            let queries = signals.split_off(n_queries);
+            (queries, signals)
+        };
+        ftlog::info!("Read {} signals and {} queries.", signals.len(), queries.len());
+
+        ftlog::info!("Writing queries to {queries_path:?}...");
+        let query_bytes = bitcode::encode(&queries).map_err(|e| e.to_string())?;
+        std::fs::write(queries_path, query_bytes).map_err(|e| e.to_string())?;
+
+        let dim = signals[0].len();
+        let mut data = FlatVec::new(signals)?
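+            // Tag the dataset with its name and per-item dimensionality before writing it to disk.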
+ .with_name(name) + .with_dim_lower_bound(dim) + .with_dim_upper_bound(dim); + + ftlog::info!("Writing dataset to bitcode encoding: {data_path:?}"); + data.write_to(&data_path)?; + + for (power, path) in (1..=max_power).zip(all_paths.iter()) { + let size = data.cardinality() / 2; + ftlog::info!("Subsampling dataset with cardinality {size} to {path:?}..."); + data = data + .random_subsample(&mut rng, size) + .with_name(&format!("{name}-{power}")); + ftlog::info!("Writing subsampled dataset with cardinality {size} to {path:?}..."); + data.write_to(path)?; + } + + queries + }; + + all_paths.pop(); // remove queries_path + let full_path = all_paths + .pop() + .unwrap_or_else(|| unreachable!("We added the path ourselves.")); // remove data_path + all_paths.insert(0, full_path); + all_paths.reverse(); // return in ascending order of cardinality + + Ok((queries, all_paths)) +} + +/// Reads a fasta dataset and subsamples it to a given maximum power of 2, +/// saving each subsample to a file. +#[allow(clippy::type_complexity)] +pub fn read_fasta_and_subsample>( + inp_dir: &P, + out_dir: &P, + remove_gaps: bool, + n_queries: usize, + max_power: u32, + seed: Option, +) -> Result<(Vec<(String, String)>, Vec), String> { + let name = "silva-SSU-Ref"; // hard-coded until we add more datasets + let fasta_path = inp_dir.as_ref().join(format!("{name}.fasta")); + let data_path = out_dir.as_ref().join(format!("{name}-0.flat_vec")); + let queries_path = out_dir.as_ref().join(format!("{name}-queries.bin")); + let mut all_paths = Vec::with_capacity(max_power as usize + 2); + for power in 1..=max_power { + all_paths.push(out_dir.as_ref().join(format!("{name}-{power}.flat_vec"))); + } + all_paths.push(data_path.clone()); + all_paths.push(queries_path.clone()); + + let queries = if all_paths.iter().all(|p| p.exists()) { + ftlog::info!("Subsampled datasets already exist. Reading queries from {queries_path:?}..."); + let bytes = std::fs::read(queries_path).map_err(|e| e.to_string())?; + bitcode::decode(&bytes).map_err(|e| e.to_string())? + } else { + if !fasta_path.exists() { + return Err(format!("Dataset {name} not found: {fasta_path:?}")); + } + + ftlog::info!("Reading fasta dataset from fasta file: {fasta_path:?}"); + let (mut data, queries) = bench_utils::fasta::read(&fasta_path, n_queries, remove_gaps)?; + data = data.with_name(name); + + ftlog::info!("Writing dataset to bitcode encoding: {data_path:?}"); + data.write_to(&data_path)?; + + let query_bytes = bitcode::encode(&queries).map_err(|e| e.to_string())?; + std::fs::write(queries_path, query_bytes).map_err(|e| e.to_string())?; + + let mut rng = rand::rngs::StdRng::seed_from_u64(seed.unwrap_or(42)); + for (power, path) in (1..=max_power).zip(all_paths.iter()) { + let size = data.cardinality() / 2; + ftlog::info!("Subsampling dataset with cardinality {size} to {path:?}..."); + data = data + .random_subsample(&mut rng, size) + .with_name(&format!("{name}-{power}")); + ftlog::info!("Writing subsampled dataset with cardinality {size} to {path:?}..."); + data.write_to(path)?; + } + + queries + }; + + all_paths.pop(); // remove queries_path + let full_path = all_paths + .pop() + .unwrap_or_else(|| unreachable!("We added the path ourselves.")); // remove data_path + all_paths.insert(0, full_path); + all_paths.reverse(); // return in ascending order of cardinality + + Ok((queries, all_paths)) +} + +/// Read the tabular floating-point dataset and augment it to a given maximum +/// power of 2. +/// +/// # Arguments +/// +/// - `dataset`: The dataset to read. 
+/// - `metric`: The metric, if provided, to use for linear search to find true +/// neighbors of augmented datasets. +/// - `max_power`: The maximum power of 2 to which the cardinality of the +/// dataset should be augmented. +/// - `seed`: The seed for the random number generator. +/// - `inp_dir`: The directory containing the input dataset. +/// - `out_dir`: The directory to which the augmented datasets and ground-truth +/// neighbors and distances should be written. +/// +/// # Returns +/// +/// The queries to use for benchmarking. +/// +/// # Errors +/// +/// - If there is an error reading the input dataset. +/// - If there is an error writing the augmented datasets. +#[allow(clippy::too_many_lines)] +pub fn read_tabular_and_augment, M: ParMetric, f32>>( + dataset: &bench_utils::RawData, + metric: Option<&M>, + max_power: u32, + seed: Option, + inp_dir: &P, + out_dir: &P, +) -> Result>, String> { + let ([data_orig_path, queries_path, neighbors_path, distances_path], all_paths) = + gen_all_paths(dataset, max_power, out_dir); + + if all_paths.iter().all(|p| p.exists()) { + ftlog::info!("Augmented datasets already exist. Reading queries from {queries_path:?}..."); + return FlatVec::, usize>::read_npy(&queries_path).map(FlatVec::take_items); + } + + ftlog::info!("Reading data from {:?}...", inp_dir.as_ref()); + let data = dataset.read_vector::<_, f32>(&inp_dir)?; + let (train, queries, neighbors) = (data.train, data.queries, data.neighbors); + let (neighbors, distances): (Vec<_>, Vec<_>) = neighbors + .into_iter() + .map(|n| { + let (n, d): (Vec<_>, Vec<_>) = n.into_iter().unzip(); + let n = n.into_iter().map(Number::as_u64).collect::>(); + (n, d) + }) + .unzip(); + let k = neighbors[0].len(); + let neighbors = FlatVec::new(neighbors)?.with_dim_lower_bound(k).with_dim_upper_bound(k); + let distances = FlatVec::new(distances)?.with_dim_lower_bound(k).with_dim_upper_bound(k); + + let (min_dim, max_dim) = train + .iter() + .chain(queries.iter()) + .fold((usize::MAX, 0), |(min, max), x| { + (Ord::min(min, x.len()), Ord::max(max, x.len())) + }); + + let data = FlatVec::new(train)? + .with_name(dataset.name()) + .with_dim_lower_bound(min_dim) + .with_dim_upper_bound(max_dim); + + ftlog::info!("Writing original data as npy array to {data_orig_path:?}..."); + data.write_npy(&data_orig_path)?; + + let query_data = FlatVec::new(queries)? + .with_name(&format!("{}-queries", dataset.name())) + .with_dim_lower_bound(min_dim) + .with_dim_upper_bound(max_dim); + ftlog::info!("Writing queries as npy array to {queries_path:?}..."); + query_data.write_npy(&queries_path)?; + let queries = query_data.take_items(); + + ftlog::info!("Writing neighbors to {neighbors_path:?}..."); + neighbors.write_npy(&neighbors_path)?; + distances.write_npy(&distances_path)?; + + ftlog::info!("Augmenting data..."); + let train = data.take_items(); + let base_cardinality = train.len(); + let data = AnnDataset { + train, + queries: Vec::new(), + neighbors: Vec::new(), + } + .augment(1 << max_power, 0.1); + + // The value of k is hardcoded to 100 to find the true neighbors of the + // augmented datasets. + let k = 100; + + let mut data = FlatVec::new(data.train)? 
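+    // The augmented items keep the dimensionality bounds measured on the original train and query vectors.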
+ .with_dim_lower_bound(min_dim) + .with_dim_upper_bound(max_dim); + + for power in (1..=max_power).rev() { + let name = format!("{}-{power}", dataset.name()); + let data_path = out_dir.as_ref().join(format!("{name}.npy")); + let neighbors_path = out_dir.as_ref().join(format!("{name}-neighbors.npy")); + let distances_path = out_dir.as_ref().join(format!("{name}-distances.npy")); + + let size = base_cardinality * (1 << power); + let mut rng = rand::rngs::StdRng::seed_from_u64(seed.unwrap_or(42)); + data = data.random_subsample(&mut rng, size).with_name(&name); + ftlog::info!("Writing {}x augmented data to {data_path:?}...", 1 << power); + data.write_npy(&data_path)?; + + if let Some(metric) = metric { + ftlog::info!("Finding true neighbors for {name}..."); + let indices = (0..data.cardinality()).collect::>(); + let true_hits = queries + .par_iter() + .map(|query| { + let mut hits = data.par_query_to_many(query, &indices, metric).collect::>(); + hits.sort_by(|(_, a), (_, b)| a.total_cmp(b)); + let _ = hits.split_off(k); + hits + }) + .collect::>(); + ftlog::info!("Writing true neighbors to {neighbors_path:?} and distances to {distances_path:?}..."); + let (neighbors, distances): (Vec<_>, Vec<_>) = true_hits + .into_iter() + .map(|mut nds| { + // Sort the neighbors by distance from the query. + nds.sort_by(|(_, a), (_, b)| a.total_cmp(b)); + + let (n, d): (Vec<_>, Vec<_>) = nds.into_iter().unzip(); + let n = n.into_iter().map(Number::as_u64).collect::>(); + + (n, d) + }) + .unzip(); + FlatVec::new(neighbors)? + .with_dim_lower_bound(k) + .with_dim_upper_bound(k) + .write_npy(&neighbors_path)?; + FlatVec::new(distances)? + .with_dim_lower_bound(k) + .with_dim_upper_bound(k) + .write_npy(&distances_path)?; + } + } + + Ok(queries) +} + +/// Read or generate random datasets and ground-truth search results. +/// +/// # Arguments +/// +/// - `metric`: The metric, if provided, to use for linear search to find true +/// neighbors of augmented datasets. +/// - `max_power`: The maximum power of 2 to which the cardinality of the +/// dataset should be augmented. +/// - `seed`: The seed for the random number generator. +/// - `out_dir`: The directory to which the augmented datasets and ground-truth +/// neighbors and distances should be written. +/// +/// # Returns +/// +/// The queries to use for benchmarking. +pub fn read_or_gen_random, M: ParMetric, f32>>( + metric: Option<&M>, + max_power: u32, + seed: Option, + out_dir: &P, +) -> Result>, String> { + let dataset = bench_utils::RawData::Random; + let ([_, queries_path, _, _], all_paths) = gen_all_paths(&dataset, max_power, out_dir); + + if all_paths.iter().all(|p| p.exists()) { + ftlog::info!("Random datasets already exist. Reading queries from {queries_path:?}..."); + return FlatVec::, usize>::read_npy(&queries_path).map(FlatVec::take_items); + } + let k = 100; + let n_queries = 100; + let base_cardinality = 1_000_000; + let dimensionality = 128; + let data = AnnDataset::gen_random(base_cardinality, 1 << max_power, dimensionality, n_queries, 42); + let (train, queries, _) = (data.train, data.queries, data.neighbors); + + let queries = FlatVec::new(queries)? + .with_dim_lower_bound(dimensionality) + .with_dim_upper_bound(dimensionality); + let queries_path = out_dir.as_ref().join(format!("{}-queries.npy", dataset.name())); + ftlog::info!("Writing queries as npy array to {queries_path:?}..."); + queries.write_npy(&queries_path)?; + let queries = queries.take_items(); + + let mut data = FlatVec::new(train)? 
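+    // Every randomly generated vector shares the same dimensionality, so both bounds coincide.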
+ .with_dim_lower_bound(dimensionality) + .with_dim_upper_bound(dimensionality); + for power in (0..=max_power).rev() { + let name = format!("{}-{}", dataset.name(), power); + let data_path = out_dir.as_ref().join(format!("{name}.npy")); + let neighbors_path = out_dir.as_ref().join(format!("{name}-neighbors.npy")); + let distances_path = out_dir.as_ref().join(format!("{name}-distances.npy")); + + let size = base_cardinality * (1 << power); + let mut rng = rand::rngs::StdRng::seed_from_u64(seed.unwrap_or(42)); + data = data.random_subsample(&mut rng, size).with_name(&name); + ftlog::info!("Writing {}x random data to {data_path:?}...", 1 << power); + data.write_npy(&data_path)?; + + if let Some(metric) = metric { + ftlog::info!("Finding true neighbors for {name}..."); + let indices = (0..data.cardinality()).collect::>(); + let true_hits = queries + .par_iter() + .map(|query| { + let mut hits = data.par_query_to_many(query, &indices, metric).collect::>(); + hits.sort_by(|(_, a), (_, b)| a.total_cmp(b)); + let _ = hits.split_off(k); + hits + }) + .collect::>(); + ftlog::info!("Writing true neighbors to {neighbors_path:?} and distances to {distances_path:?}..."); + let (neighbors, distances): (Vec<_>, Vec<_>) = true_hits + .into_iter() + .map(|mut nds| { + // Sort the neighbors by distance from the query. + nds.sort_by(|(_, a), (_, b)| a.total_cmp(b)); + + let (n, d): (Vec<_>, Vec<_>) = nds.into_iter().unzip(); + let n = n.into_iter().map(Number::as_u64).collect::>(); + + (n, d) + }) + .unzip(); + FlatVec::new(neighbors)? + .with_dim_lower_bound(k) + .with_dim_upper_bound(k) + .write_npy(&neighbors_path)?; + FlatVec::new(distances)? + .with_dim_lower_bound(k) + .with_dim_upper_bound(k) + .write_npy(&distances_path)?; + } + } + + Ok(queries) +} + +/// Generate all paths for the augmented datasets and ground-truth neighbors and +/// distances. +fn gen_all_paths>( + dataset: &bench_utils::RawData, + max_power: u32, + out_dir: &P, +) -> ([std::path::PathBuf; 4], Vec) { + let data_orig_path = out_dir.as_ref().join(format!("{}-0.npy", dataset.name())); + let queries_path = out_dir.as_ref().join(format!("{}-queries.npy", dataset.name())); + let neighbors_path = out_dir.as_ref().join(format!("{}-0-neighbors.npy", dataset.name())); + let distances_path = out_dir.as_ref().join(format!("{}-0-distances.npy", dataset.name())); + + let all_paths = { + let mut paths = Vec::with_capacity(max_power as usize + 3); + for power in 1..=max_power { + paths.push(out_dir.as_ref().join(format!("{}-{}.npy", dataset.name(), power))); + paths.push( + out_dir + .as_ref() + .join(format!("{}-{}-neighbors.npy", dataset.name(), power)), + ); + paths.push( + out_dir + .as_ref() + .join(format!("{}-{}-distances.npy", dataset.name(), power)), + ); + } + paths.push(data_orig_path.clone()); + paths.push(queries_path.clone()); + paths + }; + + ( + [data_orig_path, queries_path, neighbors_path, distances_path], + all_paths, + ) +} diff --git a/benches/cakes/src/main.rs b/benches/cakes/src/main.rs new file mode 100644 index 000000000..d8d53eff4 --- /dev/null +++ b/benches/cakes/src/main.rs @@ -0,0 +1,310 @@ +#![deny(clippy::correctness)] +#![warn( + missing_docs, + clippy::all, + clippy::suspicious, + clippy::style, + clippy::complexity, + clippy::perf, + clippy::pedantic, + clippy::nursery, + clippy::missing_docs_in_private_items, + clippy::unwrap_used, + clippy::expect_used, + clippy::panic, + clippy::cast_lossless +)] +//! Benchmarks for the CAKES paper. 
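Editorial note (not part of this patch): a minimal sketch of reading back the artifacts that `data_gen` writes, mirroring the `read_npy` calls above. The dataset name (`sift`) and the element types are assumptions; the file-name layout follows `gen_all_paths`.

```rust
use abd_clam::{Dataset, FlatVec};

/// Hypothetical helper: load one augmented dataset with its ground truth.
fn load_power(out_dir: &std::path::Path, power: u32) -> Result<(), String> {
    let data = FlatVec::<Vec<f32>, usize>::read_npy(&out_dir.join(format!("sift-{power}.npy")))?;
    let neighbors = FlatVec::<Vec<u64>, usize>::read_npy(&out_dir.join(format!("sift-{power}-neighbors.npy")))?;
    let distances = FlatVec::<Vec<f32>, usize>::read_npy(&out_dir.join(format!("sift-{power}-distances.npy")))?;
    // Each ground-truth row holds k = 100 neighbor indices (or distances), sorted by distance.
    assert_eq!(neighbors.cardinality(), distances.cardinality());
    println!("{}: {} items", data.name(), data.cardinality());
    Ok(())
}
```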
+ +use std::path::PathBuf; + +use abd_clam::{dataset::DatasetIO, Dataset, FlatVec}; +use bench_utils::Complex; +use clap::Parser; + +mod data_gen; +mod metric; +mod search; +mod trees; +mod workflow; + +use distances::Number; +use metric::{CountingMetric, ParCountingMetric}; + +/// Reproducible results for the CAKES paper. +#[derive(Parser, Debug)] +#[command(version, about, long_about = None)] +#[allow(clippy::struct_excessive_bools)] +struct Args { + /// Path to the input file. + #[arg(short('i'), long)] + inp_dir: PathBuf, + + /// The dataset to benchmark. + #[arg(short('d'), long)] + dataset: bench_utils::RawData, + + /// The number of queries to use for benchmarking. + #[arg(short('q'), long)] + num_queries: usize, + + /// Whether to count the number of distance computations during search. + #[arg(short('c'), long, default_value = "false")] + count_distance_calls: bool, + + /// This parameter is used differently depending on the dataset: + /// + /// - For any vector datasets, this is the maximum power of 2 to which the + /// cardinality should be augmented for scaling experiments. + /// - For 'omic datasets, this is the maximum power of 2 by which the + /// cardinality should be divided (sub-sampled) for scaling experiments. + /// - For the complex-valued radio-ml dataset, this works identically as + /// with the sequence datasets. + /// - For set datasets (kosarak, etc.), this is ignored. + #[arg(short('m'), long)] + max_power: Option, + + /// The minimum power of 2 to which the cardinality of the dataset should be + /// augmented for scaling experiments. + /// + /// This is only used with the tabular floating-point datasets and is + /// ignored otherwise. + #[arg(short('n'), long, default_value = "0")] + min_power: Option, + + /// The seed for the random number generator. + #[arg(short('s'), long)] + seed: Option, + + /// The maximum time, in seconds, to run each algorithm. + #[arg(short('t'), long, default_value = "10.0")] + max_time: f32, + + /// Whether to run benchmarks with balanced trees. + #[arg(short('b'), long)] + balanced_trees: bool, + + /// Whether to run benchmarks with permuted data. + #[arg(short('p'), long)] + permuted_data: bool, + + /// Whether to run ranged search benchmarks. + #[arg(short('r'), long)] + ranged_search: bool, + + /// Path to the output directory. + #[arg(short('o'), long)] + out_dir: Option, + + /// Stop after generating the augmented datasets. + #[arg(short('g'), long)] + generate_only: bool, + + /// Whether to run linear search on the datasets to find the ground truth. + #[arg(short('l'), long)] + linear_search: bool, + + /// Whether to rebuild the trees. + #[arg(short('w'), long)] + rebuild_trees: bool, +} + +#[allow(clippy::too_many_lines, clippy::cognitive_complexity)] +fn main() -> Result<(), String> { + let args = Args::parse(); + println!("Args: {args:?}"); + + let log_name = format!("cakes-{}", args.dataset.name()); + let (_guard, log_path) = bench_utils::configure_logger(&log_name)?; + println!("Log file: {log_path:?}"); + + ftlog::info!("{args:?}"); + + // Check the input and output directories. + let inp_dir = args.inp_dir.canonicalize().map_err(|e| e.to_string())?; + ftlog::info!("Input directory: {inp_dir:?}"); + + let out_dir = if let Some(out_dir) = args.out_dir { + out_dir + } else { + ftlog::info!("No output directory specified. Using default."); + let mut out_dir = inp_dir + .parent() + .ok_or("No parent directory of `inp_dir`")? 
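+            // Default: create a `<dataset>_results` directory next to the input directory.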
+ .to_path_buf(); + out_dir.push(format!("{}_results", args.dataset.name())); + if !out_dir.exists() { + std::fs::create_dir(&out_dir).map_err(|e| e.to_string())?; + } + out_dir + } + .canonicalize() + .map_err(|e| e.to_string())?; + ftlog::info!("Output directory: {out_dir:?}"); + + let radial_fractions = [0.1, 0.25]; + let ks = [10, 100]; + let seed = args.seed; + let min_power = args.min_power.unwrap_or_default(); + let max_power = args.max_power.unwrap_or(5); + let max_time = std::time::Duration::from_secs_f32(args.max_time); + + if min_power > max_power { + return Err("min_power must be less than or equal to max_power".to_string()); + } + + if args.dataset.is_tabular() { + let metric = { + let mut metric: Box> = match args.dataset.metric() { + "cosine" => Box::new(metric::Cosine::new()), + "euclidean" => Box::new(metric::Euclidean::new()), + _ => return Err(format!("Unknown metric: {}", args.dataset.metric())), + }; + if !args.count_distance_calls { + metric.disable_counting(); + } + metric + }; + + let gt_metric = if args.linear_search { Some(&metric) } else { None }; + let queries = if matches!(args.dataset, bench_utils::RawData::Random) { + data_gen::read_or_gen_random(gt_metric, max_power, seed, &out_dir)? + } else { + data_gen::read_tabular_and_augment(&args.dataset, gt_metric, max_power, args.seed, &inp_dir, &out_dir)? + }; + if args.generate_only { + return Ok(()); + } + + for power in min_power..=max_power { + let data_name = format!("{}-{power}", args.dataset.name()); + ftlog::info!("Reading {}x augmented data...", 1 << power); + let run_linear = power < 4; + + workflow::run_tabular( + &out_dir, + &data_name, + &metric, + &queries, + args.num_queries, + &radial_fractions, + &ks, + seed, + max_time, + run_linear, + args.balanced_trees, + args.permuted_data, + args.ranged_search, + args.rebuild_trees, + )?; + } + } else if args.dataset.is_sequence() { + let metric = { + let mut metric: Box> = match args.dataset.metric() { + "levenshtein" => Box::new(metric::Levenshtein::new()), + "hamming" => Box::new(metric::Hamming::new()), + _ => return Err(format!("Unknown metric: {}", args.dataset.metric())), + }; + if !args.count_distance_calls { + metric.disable_counting(); + } + metric + }; + + let (queries, subsampled_paths) = + data_gen::read_fasta_and_subsample(&inp_dir, &out_dir, false, args.num_queries, max_power, seed)?; + let queries = queries.into_iter().map(|(_, q)| q).collect::>(); + if args.generate_only { + return Ok(()); + } + + ftlog::info!("Found {} sub-sampled datasets:", subsampled_paths.len()); + for p in &subsampled_paths { + ftlog::info!("{p:?}"); + } + + for (i, sample_path) in subsampled_paths.iter().enumerate() { + if i.as_u32() < min_power { + continue; + } + + ftlog::info!("Reading sub-sampled data from {sample_path:?}..."); + let data = FlatVec::::read_from(sample_path)?; + ftlog::info!("Data from {sample_path:?} has {} sequences...", data.cardinality()); + + let run_linear = i < 4; + + ftlog::info!("Running workflow for {}...", data.name()); + workflow::run_fasta( + &out_dir, + &data, + &metric, + &queries, + &radial_fractions, + &ks, + seed, + max_time, + run_linear, + args.balanced_trees, + args.permuted_data, + args.ranged_search, + args.rebuild_trees, + )?; + } + } else if matches!(args.dataset, bench_utils::RawData::RadioML) { + let metric = { + let mut metric = metric::DynamicTimeWarping::new(); + if !args.count_distance_calls { + >, f64>>::disable_counting(&mut metric); + } + metric + }; + + let snr = Some(10); + let (queries, subsampled_paths) = + 
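+            // Keep only signals at the chosen signal-to-noise ratio (an assumption based on the
+            // `snr` argument), subsampled by halving as with the fasta data.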
data_gen::read_radio_ml_and_subsample(&inp_dir, &out_dir, args.num_queries, max_power, seed, snr)?; + if args.generate_only { + return Ok(()); + } + + ftlog::info!("Found {} sub-sampled datasets:", subsampled_paths.len()); + for p in &subsampled_paths { + ftlog::info!("{p:?}"); + } + + for (i, sample_path) in subsampled_paths.iter().enumerate() { + if i.as_u32() < min_power { + continue; + } + + ftlog::info!("Reading sub-sampled data from {sample_path:?}..."); + let data = FlatVec::>, usize>::read_from(sample_path)?; + ftlog::info!("Data from {sample_path:?} has {} signals...", data.cardinality()); + + let run_linear = i < 3; + + ftlog::info!("Running workflow for {}...", data.name()); + workflow::run_radio_ml( + &out_dir, + &data, + &metric, + &queries, + &radial_fractions, + &ks, + seed, + max_time, + run_linear, + args.balanced_trees, + args.permuted_data, + args.ranged_search, + args.rebuild_trees, + )?; + } + } else { + let msg = format!("Unsupported dataset: {}", args.dataset.name()); + ftlog::error!("{msg}"); + return Err(msg); + } + + Ok(()) +} diff --git a/benches/cakes/src/metric/cosine.rs b/benches/cakes/src/metric/cosine.rs new file mode 100644 index 000000000..febdda56c --- /dev/null +++ b/benches/cakes/src/metric/cosine.rs @@ -0,0 +1,147 @@ +//! The Cosine distance function. + +use std::sync::{Arc, RwLock}; + +use abd_clam::{metric::ParMetric, Metric}; + +use super::{CountingMetric, ParCountingMetric}; + +/// The Cosine distance function. +pub struct Cosine(Arc>, bool); + +impl Cosine { + /// Creates a new `Euclidean` distance metric. + pub fn new() -> Self { + Self(Arc::new(RwLock::new(0)), false) + } +} + +impl> Metric for Cosine { + fn distance(&self, a: &I, b: &I) -> f32 { + if self.1 { + >::increment(self); + } + distances::simd::cosine_f32(a.as_ref(), b.as_ref()) + } + + fn name(&self) -> &str { + "cosine" + } + + fn has_identity(&self) -> bool { + true + } + + fn has_non_negativity(&self) -> bool { + true + } + + fn has_symmetry(&self) -> bool { + true + } + + fn obeys_triangle_inequality(&self) -> bool { + true + } + + fn is_expensive(&self) -> bool { + false + } +} + +impl> CountingMetric for Cosine { + fn disable_counting(&mut self) { + self.1 = false; + } + + fn enable_counting(&mut self) { + self.1 = true; + } + + #[allow(clippy::unwrap_used)] + fn count(&self) -> usize { + *self.0.read().unwrap() + } + + #[allow(clippy::unwrap_used)] + fn reset_count(&self) -> usize { + let mut count = self.0.write().unwrap(); + let old = *count; + *count = 0; + old + } + + #[allow(clippy::unwrap_used)] + fn increment(&self) { + *self.0.write().unwrap() += 1; + } +} + +impl + Send + Sync> ParMetric for Cosine {} + +impl + Send + Sync> ParCountingMetric for Cosine {} + +impl> Metric for Cosine { + fn distance(&self, a: &I, b: &I) -> f64 { + if self.1 { + >::increment(self); + } + distances::simd::cosine_f64(a.as_ref(), b.as_ref()) + } + + fn name(&self) -> &str { + "cosine" + } + + fn has_identity(&self) -> bool { + true + } + + fn has_non_negativity(&self) -> bool { + true + } + + fn has_symmetry(&self) -> bool { + true + } + + fn obeys_triangle_inequality(&self) -> bool { + true + } + + fn is_expensive(&self) -> bool { + false + } +} + +impl> CountingMetric for Cosine { + fn disable_counting(&mut self) { + self.1 = false; + } + + fn enable_counting(&mut self) { + self.1 = true; + } + + #[allow(clippy::unwrap_used)] + fn count(&self) -> usize { + *self.0.read().unwrap() + } + + #[allow(clippy::unwrap_used)] + fn reset_count(&self) -> usize { + let mut count = 
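+        // Swap a fresh zero in under the write lock and hand back the previous total.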
self.0.write().unwrap(); + let old = *count; + *count = 0; + old + } + + #[allow(clippy::unwrap_used)] + fn increment(&self) { + *self.0.write().unwrap() += 1; + } +} + +impl + Send + Sync> ParMetric for Cosine {} + +impl + Send + Sync> ParCountingMetric for Cosine {} diff --git a/benches/cakes/src/metric/dtw.rs b/benches/cakes/src/metric/dtw.rs new file mode 100644 index 000000000..3e5c415a0 --- /dev/null +++ b/benches/cakes/src/metric/dtw.rs @@ -0,0 +1,90 @@ +//! The `DynamicTimeWarping` distance metric. + +use std::sync::{Arc, RwLock}; + +use abd_clam::{metric::ParMetric, Metric}; +use bench_utils::{metrics::dtw_distance, Complex}; + +use super::{CountingMetric, ParCountingMetric}; + +/// The `DynamicTimeWarping` distance metric. +pub struct DynamicTimeWarping(Arc>, bool); + +impl DynamicTimeWarping { + /// Creates a new `DynamicTimeWarping` distance metric. + pub fn new() -> Self { + Self(Arc::new(RwLock::new(0)), true) + } +} + +impl]>> Metric for DynamicTimeWarping { + fn distance(&self, a: &I, b: &I) -> f64 { + if self.1 { + >::increment(self); + } + dtw_distance(a.as_ref(), b.as_ref()) + } + + fn name(&self) -> &str { + "dtw" + } + + fn has_identity(&self) -> bool { + true + } + + fn has_non_negativity(&self) -> bool { + true + } + + fn has_symmetry(&self) -> bool { + true + } + + fn obeys_triangle_inequality(&self) -> bool { + true + } + + fn is_expensive(&self) -> bool { + false + } +} + +impl]>> CountingMetric for DynamicTimeWarping { + fn disable_counting(&mut self) { + self.1 = false; + } + + fn enable_counting(&mut self) { + self.1 = true; + } + + #[allow(clippy::unwrap_used)] + fn count(&self) -> usize { + *self.0.read().unwrap() + } + + #[allow(clippy::unwrap_used)] + fn reset_count(&self) -> usize { + let mut count = self.0.write().unwrap(); + let old = *count; + *count = 0; + old + } + + #[allow(clippy::unwrap_used)] + fn increment(&self) { + *self.0.write().unwrap() += 1; + } +} + +impl]> + Send + Sync> ParMetric for DynamicTimeWarping { + fn par_distance(&self, a: &I, b: &I) -> f64 { + if self.1 { + >::increment(self); + } + dtw_distance(a.as_ref(), b.as_ref()) + } +} + +impl]> + Send + Sync> ParCountingMetric for DynamicTimeWarping {} diff --git a/benches/cakes/src/metric/euclidean.rs b/benches/cakes/src/metric/euclidean.rs new file mode 100644 index 000000000..d4ac662d2 --- /dev/null +++ b/benches/cakes/src/metric/euclidean.rs @@ -0,0 +1,147 @@ +//! The `Euclidean` distance metric. + +use std::sync::{Arc, RwLock}; + +use abd_clam::{metric::ParMetric, Metric}; + +use super::{CountingMetric, ParCountingMetric}; + +/// The `Euclidean` distance metric. +pub struct Euclidean(Arc>, bool); + +impl Euclidean { + /// Creates a new `Euclidean` distance metric. 
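+    ///
+    /// Counting of distance calls is enabled on construction; use
+    /// `CountingMetric::disable_counting` to turn it off. A usage sketch
+    /// (editorial addition, not compiled as a doc-test):
+    ///
+    /// ```ignore
+    /// let metric = Euclidean::new();
+    /// let (a, b) = (vec![0.0_f32, 0.0], vec![3.0, 4.0]);
+    /// assert_eq!(metric.distance(&a, &b), 5.0);
+    /// ```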
+ pub fn new() -> Self { + Self(Arc::new(RwLock::new(0)), true) + } +} + +impl> Metric for Euclidean { + fn distance(&self, a: &I, b: &I) -> f32 { + if self.1 { + >::increment(self); + } + distances::simd::euclidean_f32(a.as_ref(), b.as_ref()) + } + + fn name(&self) -> &str { + "euclidean" + } + + fn has_identity(&self) -> bool { + true + } + + fn has_non_negativity(&self) -> bool { + true + } + + fn has_symmetry(&self) -> bool { + true + } + + fn obeys_triangle_inequality(&self) -> bool { + true + } + + fn is_expensive(&self) -> bool { + false + } +} + +impl> CountingMetric for Euclidean { + fn disable_counting(&mut self) { + self.1 = false; + } + + fn enable_counting(&mut self) { + self.1 = true; + } + + #[allow(clippy::unwrap_used)] + fn count(&self) -> usize { + *self.0.read().unwrap() + } + + #[allow(clippy::unwrap_used)] + fn reset_count(&self) -> usize { + let mut count = self.0.write().unwrap(); + let old = *count; + *count = 0; + old + } + + #[allow(clippy::unwrap_used)] + fn increment(&self) { + *self.0.write().unwrap() += 1; + } +} + +impl + Send + Sync> ParMetric for Euclidean {} + +impl + Send + Sync> ParCountingMetric for Euclidean {} + +impl> Metric for Euclidean { + fn distance(&self, a: &I, b: &I) -> f64 { + if self.1 { + >::increment(self); + } + distances::simd::euclidean_f64(a.as_ref(), b.as_ref()) + } + + fn name(&self) -> &str { + "euclidean" + } + + fn has_identity(&self) -> bool { + true + } + + fn has_non_negativity(&self) -> bool { + true + } + + fn has_symmetry(&self) -> bool { + true + } + + fn obeys_triangle_inequality(&self) -> bool { + true + } + + fn is_expensive(&self) -> bool { + false + } +} + +impl> CountingMetric for Euclidean { + fn disable_counting(&mut self) { + self.1 = false; + } + + fn enable_counting(&mut self) { + self.1 = true; + } + + #[allow(clippy::unwrap_used)] + fn count(&self) -> usize { + *self.0.read().unwrap() + } + + #[allow(clippy::unwrap_used)] + fn reset_count(&self) -> usize { + let mut count = self.0.write().unwrap(); + let old = *count; + *count = 0; + old + } + + #[allow(clippy::unwrap_used)] + fn increment(&self) { + *self.0.write().unwrap() += 1; + } +} + +impl + Send + Sync> ParMetric for Euclidean {} + +impl + Send + Sync> ParCountingMetric for Euclidean {} diff --git a/benches/cakes/src/metric/hamming.rs b/benches/cakes/src/metric/hamming.rs new file mode 100644 index 000000000..0bc82a112 --- /dev/null +++ b/benches/cakes/src/metric/hamming.rs @@ -0,0 +1,82 @@ +//! The `Hamming` distance metric. + +use std::sync::{Arc, RwLock}; + +use abd_clam::{metric::ParMetric, Metric}; + +use super::{CountingMetric, ParCountingMetric}; + +/// The `Hamming` distance metric. +pub struct Hamming(Arc>, bool); + +impl Hamming { + /// Creates a new `Hamming` distance metric. 
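+    ///
+    /// Counting is enabled on construction. Hamming distance compares strings
+    /// position by position, so it is only meaningful when both strings have
+    /// the same length.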
+ pub fn new() -> Self { + Self(Arc::new(RwLock::new(0)), true) + } +} + +impl Metric for Hamming { + fn distance(&self, a: &String, b: &String) -> u32 { + if self.1 { + self.increment(); + } + distances::strings::hamming(a, b) + } + + fn name(&self) -> &str { + "hamming" + } + + fn has_identity(&self) -> bool { + true + } + + fn has_non_negativity(&self) -> bool { + true + } + + fn has_symmetry(&self) -> bool { + true + } + + fn obeys_triangle_inequality(&self) -> bool { + true + } + + fn is_expensive(&self) -> bool { + false + } +} + +impl CountingMetric for Hamming { + fn disable_counting(&mut self) { + self.1 = false; + } + + fn enable_counting(&mut self) { + self.1 = true; + } + + #[allow(clippy::unwrap_used)] + fn count(&self) -> usize { + *self.0.read().unwrap() + } + + #[allow(clippy::unwrap_used)] + fn reset_count(&self) -> usize { + let mut count = self.0.write().unwrap(); + let old = *count; + *count = 0; + old + } + + #[allow(clippy::unwrap_used)] + fn increment(&self) { + *self.0.write().unwrap() += 1; + } +} + +impl ParMetric for Hamming {} + +impl ParCountingMetric for Hamming {} diff --git a/benches/cakes/src/metric/levenshtein.rs b/benches/cakes/src/metric/levenshtein.rs new file mode 100644 index 000000000..32b5b5990 --- /dev/null +++ b/benches/cakes/src/metric/levenshtein.rs @@ -0,0 +1,83 @@ +//! The `Levenshtein` distance metric. + +use std::sync::{Arc, RwLock}; + +use abd_clam::{metric::ParMetric, Metric}; +use distances::Number; + +use super::{CountingMetric, ParCountingMetric}; + +/// The `Levenshtein` distance metric. +pub struct Levenshtein(Arc>, bool); + +impl Levenshtein { + /// Creates a new `Levenshtein` distance metric. + pub fn new() -> Self { + Self(Arc::new(RwLock::new(0)), true) + } +} + +impl Metric for Levenshtein { + fn distance(&self, a: &String, b: &String) -> u32 { + if self.1 { + self.increment(); + } + stringzilla::sz::edit_distance(a, b).as_u32() + } + + fn name(&self) -> &str { + "levenshtein" + } + + fn has_identity(&self) -> bool { + true + } + + fn has_non_negativity(&self) -> bool { + true + } + + fn has_symmetry(&self) -> bool { + true + } + + fn obeys_triangle_inequality(&self) -> bool { + true + } + + fn is_expensive(&self) -> bool { + false + } +} + +impl CountingMetric for Levenshtein { + fn disable_counting(&mut self) { + self.1 = false; + } + + fn enable_counting(&mut self) { + self.1 = true; + } + + #[allow(clippy::unwrap_used)] + fn count(&self) -> usize { + *self.0.read().unwrap() + } + + #[allow(clippy::unwrap_used)] + fn reset_count(&self) -> usize { + let mut count = self.0.write().unwrap(); + let old = *count; + *count = 0; + old + } + + #[allow(clippy::unwrap_used)] + fn increment(&self) { + *self.0.write().unwrap() += 1; + } +} + +impl ParMetric for Levenshtein {} + +impl ParCountingMetric for Levenshtein {} diff --git a/benches/cakes/src/metric/mod.rs b/benches/cakes/src/metric/mod.rs new file mode 100644 index 000000000..5f1d549b9 --- /dev/null +++ b/benches/cakes/src/metric/mod.rs @@ -0,0 +1,147 @@ +//! Metrics that count the number of distance computations. + +use abd_clam::{metric::ParMetric, Metric}; +use distances::Number; + +mod cosine; +mod dtw; +mod euclidean; +mod hamming; +mod levenshtein; + +pub use cosine::Cosine; +pub use dtw::DynamicTimeWarping; +pub use euclidean::Euclidean; +pub use hamming::Hamming; +pub use levenshtein::Levenshtein; + +/// A metric that counts the number of distance computations. 
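+///
+/// The benchmarks use these metrics in a reset-run-read pattern: reset the
+/// counter, run a batch of queries, then read off how many distances were
+/// computed. A hedged sketch, using the `Euclidean` metric from this module:
+///
+/// ```ignore
+/// let metric = Euclidean::new();
+/// metric.reset_count();
+/// // ... run searches that call `metric.distance` ...
+/// let num_distance_computations = metric.count();
+/// ```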
+#[allow(clippy::module_name_repetitions)] +pub trait CountingMetric: Metric { + /// Disables counting the number of distance computations. + fn disable_counting(&mut self); + + /// Enables counting the number of distance computations. + fn enable_counting(&mut self); + + /// Returns the number of distance computations made. + fn count(&self) -> usize; + + /// Resets the counter and returns the previous value. + fn reset_count(&self) -> usize; + + /// Increments the counter. + fn increment(&self); +} + +/// Parallel version of the `CountingMetric` trait. +#[allow(clippy::module_name_repetitions)] +pub trait ParCountingMetric: ParMetric + CountingMetric {} + +impl Metric for Box> { + fn distance(&self, a: &I, b: &I) -> T { + self.as_ref().distance(a, b) + } + + fn name(&self) -> &str { + self.as_ref().name() + } + + fn has_identity(&self) -> bool { + self.as_ref().has_identity() + } + + fn has_non_negativity(&self) -> bool { + self.as_ref().has_non_negativity() + } + + fn has_symmetry(&self) -> bool { + self.as_ref().has_symmetry() + } + + fn obeys_triangle_inequality(&self) -> bool { + self.as_ref().obeys_triangle_inequality() + } + + fn is_expensive(&self) -> bool { + self.as_ref().is_expensive() + } +} + +impl Metric for Box> { + fn distance(&self, a: &I, b: &I) -> T { + self.as_ref().distance(a, b) + } + + fn name(&self) -> &str { + self.as_ref().name() + } + + fn has_identity(&self) -> bool { + self.as_ref().has_identity() + } + + fn has_non_negativity(&self) -> bool { + self.as_ref().has_non_negativity() + } + + fn has_symmetry(&self) -> bool { + self.as_ref().has_symmetry() + } + + fn obeys_triangle_inequality(&self) -> bool { + self.as_ref().obeys_triangle_inequality() + } + + fn is_expensive(&self) -> bool { + self.as_ref().is_expensive() + } +} + +impl ParMetric for Box> {} + +impl CountingMetric for Box> { + fn disable_counting(&mut self) { + self.as_mut().disable_counting(); + } + + fn enable_counting(&mut self) { + self.as_mut().enable_counting(); + } + + fn count(&self) -> usize { + self.as_ref().count() + } + + fn reset_count(&self) -> usize { + self.as_ref().reset_count() + } + + fn increment(&self) { + self.as_ref().increment(); + } +} + +impl CountingMetric for Box> { + fn disable_counting(&mut self) { + self.as_mut().disable_counting(); + } + + fn enable_counting(&mut self) { + self.as_mut().enable_counting(); + } + + fn count(&self) -> usize { + self.as_ref().count() + } + + fn reset_count(&self) -> usize { + self.as_ref().reset_count() + } + + fn increment(&self) { + self.as_ref().increment(); + } +} + +impl ParCountingMetric for Box> {} diff --git a/benches/cakes/src/search.rs b/benches/cakes/src/search.rs new file mode 100644 index 000000000..da33a1ea9 --- /dev/null +++ b/benches/cakes/src/search.rs @@ -0,0 +1,221 @@ +//! Helpers for using the search modules from `abd_clam` in benchmarks. + +use core::time::Duration; + +use abd_clam::{cakes::ParSearchAlgorithm, cluster::ParCluster, Dataset, FlatVec}; +use bench_utils::reports::CakesResults; +use distances::Number; + +use crate::metric::ParCountingMetric; + +/// Run all the search algorithms on a cluster. 
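+///
+/// When `ranged_search` is set, this benchmarks `RnnClustered` at every
+/// radius in `radii`; for every `k` in `ks`, it benchmarks `KnnRepeatedRnn`,
+/// `KnnBreadthFirst`, and `KnnDepthFirst`. The linear baselines (`RnnLinear`
+/// and `KnnLinear`) run only for the first radius or `k`, and only when
+/// `run_linear` is set, since the number of distances a linear scan computes
+/// does not depend on the radius or `k`.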
+#[allow(clippy::too_many_arguments, clippy::fn_params_excessive_bools)] +pub fn bench_all_algs( + report: &mut CakesResults, + metric: &M, + queries: &[I], + neighbors: Option<&[Vec<(usize, T)>]>, + root: &C, + data: &FlatVec, + is_balanced: bool, + is_permuted: bool, + max_time: Duration, + ks: &[usize], + radii: &[T], + run_linear: bool, + ranged_search: bool, +) where + I: Send + Sync, + T: Number, + M: ParCountingMetric, + C: ParCluster, + Me: Send + Sync, +{ + if ranged_search { + for (i, &radius) in radii.iter().enumerate() { + if i == 0 && run_linear { + bench_algorithm( + report, + metric, + queries, + neighbors, + &abd_clam::cakes::RnnLinear(radius), + root, + data, + is_balanced, + is_permuted, + max_time, + ); + } + + bench_algorithm( + report, + metric, + queries, + neighbors, + &abd_clam::cakes::RnnClustered(radius), + root, + data, + is_balanced, + is_permuted, + max_time, + ); + } + } + + for (i, &k) in ks.iter().enumerate() { + if i == 0 && run_linear { + bench_algorithm( + report, + metric, + queries, + neighbors, + &abd_clam::cakes::KnnLinear(k), + root, + data, + is_balanced, + is_permuted, + max_time, + ); + } + + bench_algorithm( + report, + metric, + queries, + neighbors, + &abd_clam::cakes::KnnRepeatedRnn(k, T::ONE.double()), + root, + data, + is_balanced, + is_permuted, + max_time, + ); + + bench_algorithm( + report, + metric, + queries, + neighbors, + &abd_clam::cakes::KnnBreadthFirst(k), + root, + data, + is_balanced, + is_permuted, + max_time, + ); + + bench_algorithm( + report, + metric, + queries, + neighbors, + &abd_clam::cakes::KnnDepthFirst(k), + root, + data, + is_balanced, + is_permuted, + max_time, + ); + } +} + +/// Run a single search algorithm on a cluster. +#[allow(clippy::too_many_arguments)] +fn bench_algorithm( + report: &mut CakesResults, + metric: &M, + queries: &[I], + neighbors: Option<&[Vec<(usize, T)>]>, + alg: &A, + root: &C, + data: &FlatVec, + is_balanced: bool, + is_permuted: bool, + max_time: Duration, +) where + I: Send + Sync, + T: Number, + M: ParCountingMetric, + C: ParCluster, + A: ParSearchAlgorithm>, + Me: Send + Sync, +{ + let cluster_name = { + let mut parts = Vec::with_capacity(3); + if is_permuted { + parts.push("Permuted"); + } + if is_balanced { + parts.push("Balanced"); + } + parts.push("Ball"); + parts.join("") + }; + + ftlog::info!("Running {} on {cluster_name} with {}...", alg.name(), data.name()); + let mut hits = Vec::with_capacity(100); + metric.reset_count(); + let start = std::time::Instant::now(); + while start.elapsed() < max_time { + hits.push(alg.par_batch_search(data, metric, root, queries)); + } + let total_time = start.elapsed().as_secs_f32(); + let distance_count = metric.count().as_f32() / (queries.len() * hits.len()).as_f32(); + + let n_runs = queries.len() * hits.len(); + let time = total_time / n_runs.as_f32(); + let throughput = n_runs.as_f32() / total_time; + ftlog::info!( + "With {cluster_name}, Algorithm {} achieved Throughput {throughput} q/s", + alg.name() + ); + + let last_hits = hits.last().unwrap_or_else(|| unreachable!("We ran it at least once")); + let output_sizes = last_hits.iter().map(Vec::len).collect::>(); + + #[allow(clippy::option_if_let_else, clippy::branches_sharing_code)] + let recalls = if let Some(_neighbors) = neighbors { + // let mut recalls = Vec::with_capacity(neighbors.len()); + // for (i, neighbors) in neighbors.iter().enumerate() { + // let mut recall = Vec::with_capacity(neighbors.len()); + // for (j, (idx, _)) in neighbors.iter().enumerate() { + // let mut count = 0; + // 
for hit in last_hits.iter() { + // if hit.iter().any(|(h, _)| *h == *idx) { + // count += 1; + // } + // } + // recall.push(count.as_f32() / last_hits.len().as_f32()); + // } + // recalls.push(recall); + // } + // recalls + vec![1.0; queries.len()] + } else { + vec![1.0; queries.len()] + }; + if let Some(radius) = alg.radius() { + report.append_radial_result( + &cluster_name, + alg.name(), + radius, + time, + throughput, + &output_sizes, + &recalls, + distance_count, + ); + } else if let Some(k) = alg.k() { + report.append_k_result( + &cluster_name, + alg.name(), + k, + time, + throughput, + &output_sizes, + &recalls, + distance_count, + ); + } +} diff --git a/benches/cakes/src/trees.rs b/benches/cakes/src/trees.rs new file mode 100644 index 000000000..9a129e8d7 --- /dev/null +++ b/benches/cakes/src/trees.rs @@ -0,0 +1,157 @@ +//! Building, saving and loading trees for benchmarking. + +use abd_clam::{ + cakes::PermutedBall, + cluster::{adapter::ParBallAdapter, BalancedBall, Csv, ParClusterIO, ParPartition}, + dataset::DatasetIO, + Ball, Dataset, FlatVec, +}; +use distances::Number; + +use crate::metric::ParCountingMetric; + +/// The seven output paths for a given dataset. +pub struct AllPaths { + /// The output directory. + pub out_dir: std::path::PathBuf, + /// The path to the `Ball` tree. + pub ball: std::path::PathBuf, + /// The path to the data. + pub data: std::path::PathBuf, + /// The path to the `BalancedBall` tree. + pub balanced_ball: std::path::PathBuf, + /// The path to the permuted `Ball` tree. + pub permuted_ball: std::path::PathBuf, + /// The path to the permuted `BalancedBall` tree. + pub permuted_balanced_ball: std::path::PathBuf, + /// The path to the permuted data. + pub permuted_data: std::path::PathBuf, + /// The path to the permuted balanced data. + pub permuted_balanced_data: std::path::PathBuf, +} + +impl AllPaths { + /// Creates a new `AllPaths` instance. + pub fn new>(out_dir: &P, data_name: &str) -> Self { + Self { + out_dir: out_dir.as_ref().to_path_buf(), + ball: out_dir.as_ref().join(format!("{data_name}.ball")), + data: out_dir.as_ref().join(format!("{data_name}.flat_vec")), + balanced_ball: out_dir.as_ref().join(format!("{data_name}.balanced_ball")), + permuted_ball: out_dir.as_ref().join(format!("{data_name}.permuted_ball")), + permuted_balanced_ball: out_dir.as_ref().join(format!("{data_name}.permuted_balanced_ball")), + permuted_data: out_dir.as_ref().join(format!("{data_name}-permuted.flat_vec")), + permuted_balanced_data: out_dir.as_ref().join(format!("{data_name}-permuted_balanced.flat_vec")), + } + } + + /// Whether all paths exist. + pub fn all_exist(&self, balanced: bool, permuted: bool) -> bool { + let mut base = self.ball.exists() && self.data.exists(); + + if balanced { + base = base && self.balanced_ball.exists(); + } + if permuted { + base = base && self.permuted_ball.exists() && self.permuted_data.exists(); + if balanced { + base = base && self.permuted_balanced_ball.exists() && self.permuted_balanced_data.exists(); + } + } + + base + } +} + +/// Builds all types of trees for the given dataset. 
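+///
+/// Depending on the flags, up to four trees are built and written to the
+/// paths in `AllPaths`: a `Ball`, a `BalancedBall`, and a `PermutedBall`
+/// adaptation of each. Every tree is also summarized to a CSV file, and the
+/// permuted variants additionally write their reordered copies of the data.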
+pub fn build_all( + out_dir: &P, + data: &FlatVec, + metric: &M, + seed: Option, + build_permuted: bool, + build_balanced: bool, + depth_stride: Option, +) -> Result<(), String> +where + P: AsRef, + I: Send + Sync + Clone + bitcode::Encode + bitcode::Decode, + T: Number + bitcode::Encode + bitcode::Decode, + M: ParCountingMetric, + Me: Send + Sync + Clone + bitcode::Encode + bitcode::Decode, +{ + ftlog::info!("Building all trees for {}...", data.name()); + let all_paths = AllPaths::new(out_dir, data.name()); + + if !all_paths.data.exists() { + ftlog::info!("Writing data to {:?}...", all_paths.data); + data.write_to(&all_paths.data)?; + } + + ftlog::info!("Building Ball..."); + metric.reset_count(); + let ball = depth_stride.map_or_else( + || Ball::par_new_tree(data, metric, &|_| true, seed), + |depth_stride| Ball::par_new_tree_iterative(data, metric, &|_| true, seed, depth_stride), + ); + ftlog::info!("Built Ball by calculating {} distances.", metric.count()); + + ftlog::info!("Writing Ball to {:?}...", all_paths.ball); + ball.par_write_to(&all_paths.ball)?; + let csv_path = out_dir.as_ref().join(format!("{}-ball.csv", data.name())); + ball.write_to_csv(&csv_path)?; + + if build_permuted { + ftlog::info!("Building Permuted Ball..."); + let (ball, data) = PermutedBall::par_from_ball_tree(ball, data.clone(), metric); + + ftlog::info!("Writing Permuted Ball to {:?}...", all_paths.permuted_ball); + ball.par_write_to(&all_paths.permuted_ball)?; + let csv_path = out_dir.as_ref().join(format!("{}-permuted-ball.csv", data.name())); + ball.write_to_csv(&csv_path)?; + + ftlog::info!("Writing Permuted data to {:?}...", all_paths.permuted_data); + data.write_to(&all_paths.permuted_data)?; + } + + if build_balanced { + ftlog::info!("Building Balanced Ball..."); + metric.reset_count(); + let ball = depth_stride + .map_or_else( + || BalancedBall::par_new_tree(data, metric, &|_| true, seed), + |depth_stride| BalancedBall::par_new_tree_iterative(data, metric, &|_| true, seed, depth_stride), + ) + .into_ball(); + ftlog::info!("Built Balanced Ball by calculating {} distances.", metric.count()); + + ftlog::info!("Writing Balanced Ball to {:?}...", all_paths.balanced_ball); + ball.par_write_to(&all_paths.balanced_ball)?; + let csv_path = out_dir.as_ref().join(format!("{}-balanced-ball.csv", data.name())); + ball.write_to_csv(&csv_path)?; + + if build_permuted { + ftlog::info!("Building Permuted Balanced Ball..."); + let (ball, data) = PermutedBall::par_from_ball_tree(ball, data.clone(), metric); + + ftlog::info!( + "Writing Permuted Balanced Ball to {:?}...", + all_paths.permuted_balanced_ball + ); + ball.par_write_to(&all_paths.permuted_balanced_ball)?; + let csv_path = out_dir + .as_ref() + .join(format!("{}-permuted-balanced-ball.csv", data.name())); + ball.write_to_csv(&csv_path)?; + + ftlog::info!( + "Writing Permuted Balanced data to {:?}...", + all_paths.permuted_balanced_data + ); + data.write_to(&all_paths.permuted_balanced_data)?; + } + } + ftlog::info!("Built all trees for {}.", data.name()); + + Ok(()) +} diff --git a/benches/cakes/src/workflow.rs b/benches/cakes/src/workflow.rs new file mode 100644 index 000000000..990cbfb6f --- /dev/null +++ b/benches/cakes/src/workflow.rs @@ -0,0 +1,284 @@ +//! Steps in the workflow of running CAKES benchmarks. 
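+//!
+//! Each public `run_*` entry point does the same two things: it ensures the
+//! trees exist on disk (building them with `trees::build_all` when missing
+//! or when a rebuild is forced), and then hands off to the private `run`
+//! helper, which benchmarks every enabled algorithm on every enabled tree
+//! and writes one CSV report.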
+
+use core::time::Duration;
+
+use abd_clam::{cakes::PermutedBall, cluster::ParClusterIO, dataset::DatasetIO, Ball, Cluster, Dataset, FlatVec};
+use bench_utils::{reports::CakesResults, Complex};
+use distances::Number;
+use rand::prelude::*;
+
+use crate::{metric::ParCountingMetric, trees::AllPaths};
+
+/// Run the workflow of the CAKES benchmarks on a radio-ml dataset.
+#[allow(clippy::fn_params_excessive_bools, clippy::too_many_arguments)]
+pub fn run_radio_ml, M: ParCountingMetric>, f64>>(
+    out_dir: &P,
+    data: &FlatVec>, usize>,
+    metric: &M,
+    queries: &[Vec>],
+    radial_fractions: &[f32],
+    ks: &[usize],
+    seed: Option,
+    max_time: Duration,
+    run_linear: bool,
+    balanced_trees: bool,
+    permuted_data: bool,
+    ranged_search: bool,
+    rebuild_trees: bool,
+) -> Result<(), String> {
+    let all_paths = AllPaths::new(out_dir, data.name());
+    if rebuild_trees || !all_paths.all_exist(balanced_trees, permuted_data) {
+        super::trees::build_all(out_dir, data, metric, seed, permuted_data, balanced_trees, None)?;
+    }
+    run::<_, _, _, usize>(
+        &all_paths,
+        metric,
+        queries,
+        None,
+        radial_fractions,
+        ks,
+        max_time,
+        run_linear,
+        balanced_trees,
+        permuted_data,
+        ranged_search,
+    )
+}
+
+/// Run the workflow of the CAKES benchmarks on a fasta dataset.
+#[allow(clippy::fn_params_excessive_bools, clippy::too_many_arguments)]
+pub fn run_fasta, M: ParCountingMetric>(
+    out_dir: &P,
+    data: &FlatVec,
+    metric: &M,
+    queries: &[String],
+    radial_fractions: &[f32],
+    ks: &[usize],
+    seed: Option,
+    max_time: Duration,
+    run_linear: bool,
+    balanced_trees: bool,
+    permuted_data: bool,
+    ranged_search: bool,
+    rebuild_trees: bool,
+) -> Result<(), String> {
+    let all_paths = AllPaths::new(out_dir, data.name());
+    if rebuild_trees || !all_paths.all_exist(balanced_trees, permuted_data) {
+        super::trees::build_all(out_dir, data, metric, seed, permuted_data, balanced_trees, Some(128))?;
+    }
+    run::<_, _, _, String>(
+        &all_paths,
+        metric,
+        queries,
+        None,
+        radial_fractions,
+        ks,
+        max_time,
+        run_linear,
+        balanced_trees,
+        permuted_data,
+        ranged_search,
+    )
+}
+
+/// Run the workflow of the CAKES benchmarks on a tabular dataset.
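+///
+/// Unlike the fasta and radio-ml entry points, this reads the dataset from
+/// `{data_name}.npy` in `out_dir`, and when `{data_name}-neighbors.npy` and
+/// `{data_name}-distances.npy` are present, it loads them as ground truth
+/// for recall measurements. In either case, `num_queries` queries are
+/// sampled at random.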
+#[allow( + clippy::too_many_arguments, + clippy::fn_params_excessive_bools, + clippy::too_many_lines +)] +pub fn run_tabular( + out_dir: &P, + data_name: &str, + metric: &M, + queries: &[Vec], + num_queries: usize, + radial_fractions: &[f32], + ks: &[usize], + seed: Option, + max_time: Duration, + run_linear: bool, + balanced_trees: bool, + permuted_data: bool, + ranged_search: bool, + rebuild_trees: bool, +) -> Result<(), String> +where + P: AsRef, + M: ParCountingMetric, f32>, +{ + let data_path = out_dir.as_ref().join(format!("{data_name}.npy")); + let data = FlatVec::, usize>::read_npy(&data_path)?; + + let neighbors_path = out_dir.as_ref().join(format!("{data_name}-neighbors.npy")); + let distances_path = out_dir.as_ref().join(format!("{data_name}-distances.npy")); + let (queries, neighbors) = if neighbors_path.exists() && distances_path.exists() { + let neighbors = FlatVec::, usize>::read_npy(&neighbors_path)?.take_items(); + let neighbors = neighbors + .into_iter() + .map(|n| n.into_iter().map(Number::as_usize).collect::>()); + + let distances = FlatVec::, usize>::read_npy(&distances_path)?.take_items(); + + let neighbors = neighbors + .zip(distances) + .map(|(n, d)| n.into_iter().zip(d).collect::>()) + .collect::>(); + + let mut queries = queries.iter().cloned().zip(neighbors).collect::>(); + + let mut rng = rand::thread_rng(); + queries.shuffle(&mut rng); + let _ = queries.split_off(num_queries); + + let (queries, neighbors): (Vec<_>, Vec<_>) = queries.into_iter().unzip(); + (queries, Some(neighbors)) + } else { + let mut rng = rand::thread_rng(); + let mut queries = queries.to_vec(); + queries.shuffle(&mut rng); + let _ = queries.split_off(num_queries); + (queries, None) + }; + let neighbors = neighbors.as_deref(); + + let all_paths = AllPaths::new(out_dir, data.name()); + if rebuild_trees || !all_paths.all_exist(balanced_trees, permuted_data) { + super::trees::build_all(out_dir, &data, metric, seed, permuted_data, balanced_trees, None)?; + } + run::<_, _, _, usize>( + &all_paths, + metric, + &queries, + neighbors, + radial_fractions, + ks, + max_time, + run_linear, + balanced_trees, + permuted_data, + ranged_search, + ) +} + +/// Run the full workflow of the CAKES benchmarks on a dataset. 
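+///
+/// This is the common tail of the `run_*` entry points. It reads each
+/// requested tree (and, for the permuted variants, the matching permuted
+/// data) from `all_paths`, derives the search radii by scaling the root's
+/// radius by each of the `radial_fractions`, and appends the results for
+/// every tree to a single `CakesResults` report that is finally written to
+/// CSV. The permuted runs pass `None` for the ground-truth neighbors, whose
+/// indices refer to the unpermuted ordering of the data.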
+#[allow(clippy::fn_params_excessive_bools, clippy::too_many_arguments)]
+fn run(
+    all_paths: &AllPaths,
+    metric: &M,
+    queries: &[I],
+    neighbors: Option<&[Vec<(usize, T)>]>,
+    radial_fractions: &[f32],
+    ks: &[usize],
+    max_time: Duration,
+    run_linear: bool,
+    balanced_trees: bool,
+    permuted_data: bool,
+    ranged_search: bool,
+) -> Result<(), String>
+where
+    I: Send + Sync + Clone + bitcode::Encode + bitcode::Decode,
+    T: Number + bitcode::Encode + bitcode::Decode,
+    M: ParCountingMetric,
+    Me: Send + Sync + Clone + bitcode::Encode + bitcode::Decode,
+{
+    ftlog::info!("Reading Ball from {:?}...", all_paths.ball);
+    let ball = Ball::::par_read_from(&all_paths.ball)?;
+    let radii = radial_fractions
+        .iter()
+        .map(|&f| f * ball.radius().as_f32())
+        .map(T::from)
+        .collect::>();
+
+    ftlog::info!("Reading data from {:?}...", all_paths.data);
+    let data = FlatVec::::read_from(&all_paths.data)?;
+
+    let (min_dim, max_dim) = data.dimensionality_hint();
+    let mut report = CakesResults::new(
+        data.name(),
+        data.cardinality(),
+        max_dim.unwrap_or(min_dim),
+        metric.name(),
+    );
+
+    ftlog::info!("Running search algorithms on Ball...");
+    crate::search::bench_all_algs(
+        &mut report,
+        metric,
+        queries,
+        neighbors,
+        &ball,
+        &data,
+        false,
+        false,
+        max_time,
+        ks,
+        &radii,
+        run_linear,
+        ranged_search,
+    );
+
+    if permuted_data {
+        let ball = PermutedBall::>::par_read_from(&all_paths.permuted_ball)?;
+        let data = FlatVec::::read_from(&all_paths.permuted_data)?;
+        ftlog::info!("Running search algorithms on PermutedBall...");
+        crate::search::bench_all_algs(
+            &mut report,
+            metric,
+            queries,
+            None,
+            &ball,
+            &data,
+            false,
+            true,
+            max_time,
+            ks,
+            &radii,
+            run_linear,
+            ranged_search,
+        );
+    }
+
+    if balanced_trees {
+        let ball = Ball::::par_read_from(&all_paths.balanced_ball)?;
+        ftlog::info!("Running search algorithms on BalancedBall...");
+        crate::search::bench_all_algs(
+            &mut report,
+            metric,
+            queries,
+            neighbors,
+            &ball,
+            &data,
+            true,
+            false,
+            max_time,
+            ks,
+            &radii,
+            run_linear,
+            ranged_search,
+        );
+
+        if permuted_data {
+            let ball = PermutedBall::>::par_read_from(&all_paths.permuted_balanced_ball)?;
+            let data = FlatVec::::read_from(&all_paths.permuted_balanced_data)?;
+            ftlog::info!("Running search algorithms on PermutedBalancedBall...");
+            crate::search::bench_all_algs(
+                &mut report,
+                metric,
+                queries,
+                None,
+                &ball,
+                &data,
+                true,
+                true,
+                max_time,
+                ks,
+                &radii,
+                run_linear,
+                ranged_search,
+            );
+        }
+    }
+
+    report.write_to_csv(&all_paths.out_dir)
+}
diff --git a/benches/py-cakes/.python-version b/benches/py-cakes/.python-version
new file mode 100644
index 000000000..8531a3b7e
--- /dev/null
+++ b/benches/py-cakes/.python-version
@@ -0,0 +1 @@
+3.12.2
diff --git a/benches/py-cakes/README.md b/benches/py-cakes/README.md
new file mode 100644
index 000000000..71ad0d985
--- /dev/null
+++ b/benches/py-cakes/README.md
@@ -0,0 +1,29 @@
+# Plotting results of the CAKES benchmarks
+
+This package provides a CLI for plotting the results of the CAKES benchmarks.
+
+You must first run any benchmarks you want to plot using the `bench-cakes` crate we provide at `../cakes`.
+See the associated README for more information.
+ +## Usage + +Create a virtual environment with Python 3.9 or later and activate it: + +```bash +python -m venv venv +source venv/bin/activate +``` + +Install the package with the following command: + +```bash +python -m pip install -e benches/py-cakes +``` + +Let's say you ran benchmarks for CAKES and saved results in a directory `../data/output`. +You now want to generate the plots and save them in a directory `../data/summary`. +You can do this with the following command: + +```bash +python -m py_cakes summarize-rust --inp-dir ../data/output --out-dir ../data/summary +``` diff --git a/benches/py-cakes/pyproject.toml b/benches/py-cakes/pyproject.toml new file mode 100644 index 000000000..e924d4ea4 --- /dev/null +++ b/benches/py-cakes/pyproject.toml @@ -0,0 +1,34 @@ +[project] +name = "py-cakes" +version = "0.1.0" +description = "Benchmarks for search algorithms for the CAKES paper." +authors = [ + { name = "Najib Ishaq", email = "najib_ishaq@zoho.com" } +] +dependencies = [ + "typer>=0.15.1", + "pandas>=2.2.3", + "matplotlib>=3.9.4", + "numpy>=2.0.2", +] +readme = "README.md" +requires-python = ">= 3.9" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.rye] +managed = true +dev-dependencies = [] + +[tool.hatch.metadata] +allow-direct-references = true + +[tool.hatch.build.targets.wheel] +packages = ["src/py_cakes"] + +[tool.ruff] +lint.extend-ignore = [ + "G004", # Use of f-string in logging statements +] diff --git a/benches/py-cakes/src/py_cakes/__init__.py b/benches/py-cakes/src/py_cakes/__init__.py new file mode 100644 index 000000000..291361fb7 --- /dev/null +++ b/benches/py-cakes/src/py_cakes/__init__.py @@ -0,0 +1,12 @@ +"""Benchmarking search algorithms for the CAKES paper.""" + +from .summarize_rust import summarize_rust +from . import utils + + +def hello(suffix: str) -> str: + """Return a greeting message.""" + return f"Working on {suffix}..." + + +__all__ = ["utils", "hello", "summarize_rust"] diff --git a/benches/py-cakes/src/py_cakes/__main__.py b/benches/py-cakes/src/py_cakes/__main__.py new file mode 100644 index 000000000..86994709d --- /dev/null +++ b/benches/py-cakes/src/py_cakes/__main__.py @@ -0,0 +1,104 @@ +"""CLI for running the benchmark suite.""" + +import logging +import pathlib + +import typer + +import py_cakes + +logging.basicConfig( + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) +logger = py_cakes.utils.configure_logger("CakesBenchmarks", "INFO") + +app = typer.Typer() + + +@app.command() +def summarize_rust( + inp_dir: pathlib.Path = typer.Option( # noqa: B008 + ..., + "--inp-dir", + "-i", + help="Path to the directory containing the input files.", + exists=True, + file_okay=False, + dir_okay=True, + readable=True, + resolve_path=True, + ), + out_dir: pathlib.Path = typer.Option( # noqa: B008 + ..., + "--out-dir", + "-o", + help="Path to the directory to store the output files.", + exists=True, + file_okay=False, + dir_okay=True, + writable=True, + resolve_path=True, + ), +) -> None: + """Summarize the results from the Rust implementation. + + The input directory should contain the output files generated by the Rust + implementation of the CAKES search algorithm. The output directory will + store the collected results in a CSV file. + """ + logger.info( + "Collecting the results of the Rust implementation of the CAKES search algorithm." 
+ ) + logger.info(f"Input directory: {inp_dir}") + logger.info(f"Output directory: {out_dir}") + logger.info("") + + py_cakes.summarize_rust(inp_dir, out_dir) + + logger.info("Done.") + + +@app.command() +def run_faiss( + inp_dir: pathlib.Path = typer.Option( # noqa: B008 + ..., + "--inp-dir", + "-i", + help="Path to the directory containing the input data.", + exists=True, + file_okay=False, + dir_okay=True, + readable=True, + resolve_path=True, + ), + out_dir: pathlib.Path = typer.Option( # noqa: B008 + ..., + "--out-dir", + "-o", + help="Path to the directory to store the output files.", + exists=True, + file_okay=False, + dir_okay=True, + writable=True, + resolve_path=True, + ), +) -> None: + """Run search algorithms from the FAISS library. + + The input directory should contain the input files. The output directory + will store the output files generated by the FAISS implementation of the + CAKES search algorithm. + """ + logger.info("Running the FAISS implementation of the CAKES search algorithm.") + logger.info(f"Input directory: {inp_dir}") + logger.info(f"Output directory: {out_dir}") + logger.info("") + + logger.info(py_cakes.hello("running FAISS algorithms")) + + logger.info("Done.") + + +if __name__ == "__main__": + app() diff --git a/benches/py-cakes/src/py_cakes/competitors/annoy.csv b/benches/py-cakes/src/py_cakes/competitors/annoy.csv new file mode 100644 index 000000000..aee9c9c8d --- /dev/null +++ b/benches/py-cakes/src/py_cakes/competitors/annoy.csv @@ -0,0 +1,32 @@ +dataset,cardinality,k,throughput,recall +fashion-mnist,60000,10,2190,0.950 +fashion-mnist,120000,10,2120,0.927 +fashion-mnist,240000,10,2040,0.898 +fashion-mnist,480000,10,1930,0.857 +fashion-mnist,960000,10,1840,0.862 +fashion-mnist,1920000,10,1850,0.775 +fashion-mnist,3840000,10,1780,0.677 +fashion-mnist,7680000,10,1660,0.538 +fashion-mnist,15360000,10,1600,0.592 +fashion-mnist,30720000,10,1830,0.581 +glove-25,1183514,10,2830,0.835 +glove-25,2367028,10,2700,0.832 +glove-25,4734056,10,2610,0.839 +glove-25,9468112,10,2510,0.834 +glove-25,18936224,10,2230,0.885 +glove-25,37872448,10,2010,0.764 +glove-25,75744896,10,1830,0.631 +sift,1000000,10,3980,0.686 +sift,2000000,10,3800,0.614 +sift,4000000,10,3690,0.637 +sift,8000000,10,3580,0.710 +sift,16000000,10,3500,0.690 +sift,32000000,10,3440,0.639 +sift,64000000,10,3390,0.678 +sift,128000000,10,3360,0.643 +random,1000000,10,4280,0.028 +random,2000000,10,4040,0.021 +random,4000000,10,3640,0.014 +random,8000000,10,3370,0.013 +random,16000000,10,3170,0.006 +random,32000000,10,3010,0.007 diff --git a/benches/py-cakes/src/py_cakes/competitors/faiss-ivf.csv b/benches/py-cakes/src/py_cakes/competitors/faiss-ivf.csv new file mode 100644 index 000000000..732d95141 --- /dev/null +++ b/benches/py-cakes/src/py_cakes/competitors/faiss-ivf.csv @@ -0,0 +1,34 @@ +dataset,cardinality,k,throughput +fashion-mnist,60000,10,2010 +fashion-mnist,120000,10,939 +fashion-mnist,240000,10,461 +fashion-mnist,480000,10,226 +fashion-mnist,960000,10,117 +fashion-mnist,1920000,10,59.1 +fashion-mnist,3840000,10,26.1 +fashion-mnist,7680000,10,13.3 +fashion-mnist,15360000,10,6.65 +fashion-mnist,30720000,10,3.56 +glove-25,1183514,10,2380 +glove-25,2367028,10,1190 +glove-25,4734056,10,619 +glove-25,9468112,10,303 +glove-25,18936224,10,151 +glove-25,37872448,10,74.0 +glove-25,75744896,10,37.7 +glove-25,151489792,10,19.0 +glove-25,302979584,10,9.47 +sift,1000000,10,698 +sift,2000000,10,330 +sift,4000000,10,165 +sift,8000000,10,77.2 +sift,16000000,10,39.8 +sift,32000000,10,20.9 +sift,64000000,10,8.87 
+sift,128000000,10,4.78
+random,1000000,10,734
+random,2000000,10,358
+random,4000000,10,190
+random,8000000,10,88.4
+random,16000000,10,43.6
+random,32000000,10,17.2
diff --git a/benches/py-cakes/src/py_cakes/competitors/hnsw.csv b/benches/py-cakes/src/py_cakes/competitors/hnsw.csv
new file mode 100644
index 000000000..87dcf3eff
--- /dev/null
+++ b/benches/py-cakes/src/py_cakes/competitors/hnsw.csv
@@ -0,0 +1,26 @@
+dataset,cardinality,k,throughput,recall
+fashion-mnist,60000,10,13300,0.954
+fashion-mnist,120000,10,13800,0.803
+fashion-mnist,240000,10,16600,0.681
+fashion-mnist,480000,10,16800,0.525
+fashion-mnist,960000,10,18700,0.494
+fashion-mnist,1920000,10,15600,0.542
+fashion-mnist,3840000,10,15000,0.378
+fashion-mnist,7680000,10,14900,0.357
+glove-25,1183514,10,22800,0.801
+glove-25,2367028,10,23800,0.607
+glove-25,4734056,10,25000,0.443
+glove-25,9468112,10,27800,0.294
+glove-25,18936224,10,31100,0.213
+glove-25,37872448,10,32400,0.178
+sift,1000000,10,19300,0.686
+sift,2000000,10,20300,0.552
+sift,4000000,10,21800,0.394
+sift,8000000,10,24800,0.298
+sift,16000000,10,26800,0.210
+sift,32000000,10,27500,0.193
+random,1000000,10,11700,0.060
+random,2000000,10,10100,0.048
+random,4000000,10,9120,0.031
+random,8000000,10,8350,0.022
+random,16000000,10,8250,0.008
diff --git a/benches/py-cakes/src/py_cakes/summarize_rust.py b/benches/py-cakes/src/py_cakes/summarize_rust.py
new file mode 100644
index 000000000..7a1fea760
--- /dev/null
+++ b/benches/py-cakes/src/py_cakes/summarize_rust.py
@@ -0,0 +1,657 @@
+"""Summarize the results from the Rust implementation."""
+
+import math
+import pathlib
+
+import pandas
+import matplotlib.pyplot as plt
+import numpy
+
+from . import utils
+
+logger = utils.configure_logger(__name__, "INFO")
+
+
+ALG_MARKERS_COLORS = {
+    "KnnLinear": ("+", "tab:blue"),
+    "KnnRepeatedRnn": ("s", "tab:orange"),
+    "KnnDepthFirst": ("x", "tab:green"),
+    "KnnBreadthFirst": ("o", "tab:red"),
+    "RnnLinear": ("v", "tab:purple"),
+    "RnnClustered": ("<", "tab:brown"),
+}
+
+COMP_COLOR_MARKERS = {
+    "hnsw": ("*", "tab:pink"),
+    "annoy": ("p", "tab:olive"),
+    "faiss-ivf": ("d", "tab:cyan"),
+}
+
+
+CLUSTER_MARKERS_COLORS = {
+    "Ball": ("s", "tab:blue"),
+    "BalancedBall": ("o", "tab:green"),
+    "PermutedBall": ("x", "tab:orange"),
+    "PermutedBalancedBall": ("+", "tab:red"),
+}
+
+
+def summarize_rust(
+    inp_dir: pathlib.Path,
+    out_dir: pathlib.Path,
+) -> None:
+    """Summarize the results from the Rust implementation.
+
+    The input directory should contain the output files generated by the Rust
+    implementation of the CAKES search algorithm. The output directory will
+    store the collected results in a CSV file.
+    """
+
+    datasets: dict[str, Dataset] = {}
+    ball_csv_lists: dict[str, list[pathlib.Path]] = {}
+
+    # Collect all csv files in the input directory
+    csv_files = sorted(inp_dir.glob("*.csv"))
+    ball_csv_paths = list(filter(lambda f: f.name.endswith("ball.csv"), csv_files))
+    bench_csv_files = filter(lambda f: not f.name.endswith("ball.csv"), csv_files)
+    for f in bench_csv_files:
+        logger.info(f"Processing {f.name}")
+        # The names are in the format "<dataset_name>_<cardinality>_<dimensionality>_<metric>.csv"
+        dataset_name, cardinality_str, dimensionality_str, metric = f.stem.split("_")
+        cardinality = int(cardinality_str)
+        dimensionality = int(dimensionality_str)
+
+        # If the `dataset` ends in a number, it has been synthetically augmented
+        # for scalability testing.
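+        # For example, a hypothetical file stem starting with "sift-3" would
+        # mean the "sift" dataset was augmented by a multiplier of 2 ** 3 = 8.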
+ is_augmented = dataset_name[-1].isdigit() + if is_augmented: + parts = dataset_name.split("-") + multiplier = 2 ** int(parts[-1]) + dataset_name = "-".join(parts[:-1]) + else: + multiplier = 1 + + if dataset_name not in datasets: + datasets[dataset_name] = Dataset( + name=dataset_name, + cardinality=cardinality, + dimensionality=dimensionality, + metric=metric, + multiplier=multiplier, + ) + datasets[dataset_name].add_csv_path(cardinality, f) + ball_csv_lists[dataset_name] = list( + filter(lambda x: x.stem.startswith(dataset_name), ball_csv_paths) + ) + + for dataset in datasets.values(): + dataset.summarize_results(out_dir) + + for dataset_name, paths in ball_csv_lists.items(): + # if "fashion" not in dataset_name: + # continue + logger.info(f"Making plots for {dataset_name}") + paths = list(filter(lambda x: "balanced" not in x.stem, paths)) + paths = list(filter(lambda x: "permuted" not in x.stem, paths)) + keyed_paths = [ + (int(path.stem[len(dataset_name) + 1 :].split("-")[0]), path) + for path in paths + ] + keyed_paths.sort(key=lambda x: x[0]) + paths = [path for _, path in keyed_paths] + + all_props = [ + ("lfd", "LFD"), + ("radius", "Radius"), + ("fractal_density", "Fractal Density"), + ] + for prop, alias in all_props[1:]: + plot_prop_percentiles( + out_dir=out_dir, + dataset=dataset_name, + ball_csv_path=paths[0], + prop=prop, + prop_alias=alias, + ) + + logger.info("Done.") + + +class Dataset: + def __init__( + self, + *, + name: str, + cardinality: int, + dimensionality: int, + metric: str, + multiplier: int, + ): + self.name = name + self.base_cardinality = cardinality // multiplier + self.dimensionality = dimensionality + self.metric = metric + self.__csv_paths: dict[int, pathlib.Path] = {} + + def __repr__(self): + parts = [ + f"name={self.name}", + f"cardinality={self.base_cardinality}", + f"dimensionality={self.dimensionality}", + f"metric={self.metric}", + ] + return f"Dataset({', '.join(parts)})" + + def __hash__(self): + return hash(self.name) + + def add_csv_path(self, cardinality: int, path: pathlib.Path) -> None: + self.__csv_paths[cardinality] = path + + @property + def csv_paths(self) -> list[tuple[int, pathlib.Path]]: + paths = list(self.__csv_paths.items()) + paths.sort(key=lambda x: x[0]) + return paths + + def summarize_results(self, out_dir: pathlib.Path) -> None: + csv_paths = self.csv_paths + + if len(csv_paths) == 0: + logger.warning(f" No CSV files found for {self.name}") + return + + columns = list(pandas.read_csv(csv_paths[0][1]).columns) + if "mean_distance_computations" not in columns: + columns.append("mean_distance_computations") + + columns.insert(0, "cardinality") + out_df = pandas.DataFrame(columns=columns) + for i, (cardinality, path) in enumerate(csv_paths): + logger.info(f" Reading {path.name}") + df = pandas.read_csv(path) + + if i > 4: + # Remove any rows whose "algorithm" column is "KnnLinear" + df = df[df["algorithm"] != "KnnLinear"] + df.reset_index(drop=True, inplace=True) + + # insert a column with the multiplier + df.insert(0, "cardinality", cardinality) + + # If there is no "mean_distance_computations" column, add it with a value of 0 + if "mean_distance_computations" not in df.columns: + df.insert(0, "mean_distance_computations", 0) + + # join the dataframes + out_df = pandas.concat([out_df, df], ignore_index=True) + + # Change the "throughput" column to floats + out_df["throughput"] = out_df["throughput"].astype(float) + # Drop all rows where the k is a nan + out_df = out_df.dropna(subset=["k"]) + # Change the "k" column to 
integers + out_df["k"] = out_df["k"].astype(int) + + # Sort the dataframe by the "cardinality", "cluster" and "algorithm" columns + out_df.sort_values( + by=["k", "cardinality", "cluster", "algorithm"], inplace=True + ) + # Drop the "radius" column + out_df.drop(columns=["radius"], inplace=True) + + # Reset the index + out_df.reset_index(drop=True, inplace=True) + # Save the dataframe to a CSV file + out_df.to_csv(out_dir / f"{self.name}_{self.metric}.csv", index=False) + + # Get the subset of the dataframe where the "algorithm" column is "KnnLinear" + # And "k" is 10. + knn_linear_10 = out_df[ + (out_df["algorithm"] == "KnnLinear") & (out_df["k"] == 10) + ] + # Set k to "100" in this subset + knn_linear_10["k"] = 100 + # Append this subset to the dataframe + out_df = pandas.concat([out_df, knn_linear_10], ignore_index=True) + + # Sort the dataframe by the "cluster" and "algorithm" columns + out_df.sort_values(by=["cluster", "algorithm"], inplace=True) + # Reset the index + out_df.reset_index(drop=True, inplace=True) + + # Group by the "cluster" column and plot the throughput + grouped = out_df.groupby(["cluster", "k"]) + for (cluster, k), group in grouped: + plot_throughput( + out_dir=out_dir, + dataset=self.name, + cluster=cluster, + k=k, + group=group, + ) + + # If any "mean_distance_computations" value is non-zero, group by the + # "algorithm" column and plot the distance counts + if (out_df["mean_distance_computations"] > 0).any(): + grouped = out_df.groupby(["algorithm", "k"]) + for (algorithm, k), group in grouped: + plot_distance_counts( + out_dir=out_dir, + dataset=self.name, + algorithm=algorithm, + k=k, + group=group, + ) + + +def plot_throughput( + *, + out_dir: pathlib.Path, + dataset: str, + cluster: str, + k: int, + group: pandas.DataFrame, +) -> None: + """Plot the throughput of all algorithms.""" + title = f"{dataset} - {cluster} - {k=}" + logger.info(f" Plotting Throughput {title}") + + # Create a figure and axis + fig: plt.Figure + ax: plt.Axes + m = 0.8 + fig, ax = plt.subplots(figsize=(6 * m, 6 * m)) + + # Group by the "algorithm" column + for alg, g in group.groupby("algorithm"): + # Sort the rows by the "cardinality" column + g.sort_values("cardinality", inplace=True) + # Plot the "cardinality" column on the x-axis against the "throughput" + # column on the y-axis + marker, color = ALG_MARKERS_COLORS[alg] + ax.plot( + g["cardinality"], + g["throughput"], + label=f"{alg}", + marker=marker, + color=color, + ) + + min_throughput = float(group["throughput"].min()) + max_throughput = float(group["throughput"].max()) + + # If the dataset is one of the ann-benchmarks datasets, we need to add + # results of HNSW, ANNOY, and FAISS-IVF. 
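+    # The competitor results are read from CSV files bundled with this
+    # package under `competitors/`; the "silva" and "radio" datasets are
+    # excluded because those files have no entries for them.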
+ if not ("silva" in dataset or "radio" in dataset): + comp_dir = pathlib.Path(__file__).parent / "competitors" + assert comp_dir.exists(), f"{comp_dir} does not exist" + assert comp_dir.is_dir(), f"{comp_dir} is not a directory" + comp_paths = [ + comp_dir / f"{name}.csv" for name in ["hnsw", "annoy", "faiss-ivf"] + ] + assert all(p.exists() for p in comp_paths), f"Missing files: {comp_paths}" + for path in comp_paths: + comp_df = pandas.read_csv(path) + # filter for the current dataset and k + comp_df = comp_df[(comp_df["dataset"] == dataset) & (comp_df["k"] == k)] + alg = path.stem + marker, color = COMP_COLOR_MARKERS[alg] + ax.plot( + comp_df["cardinality"], + comp_df["throughput"], + label=alg.upper(), + marker=marker, + color=color, + ) + + max_throughput = max(max_throughput, comp_df["throughput"].max()) + min_throughput = min(min_throughput, comp_df["throughput"].min()) + + # if "recall" is in the columns, add it as a floating annotation + # above each point, but with a smaller font size + if "recall" in comp_df.columns: + for _, row in comp_df.iterrows(): + ax.annotate( + f"{row['recall']:.2f}", + (row["cardinality"], row["throughput"]), + textcoords="offset points", + xytext=(0, 6), + ha="center", + fontsize=8, + ) + + # Set a good y-axis limit + y_min = (10 ** math.floor(math.log10(min_throughput))) * 0.9 + y_max = (10 ** math.ceil(math.log10(max_throughput))) * 1.1 + ax.set_ylim(y_min, y_max) + + # Set the title and labels + # ax.set_title(title) + ax.set_xlabel("Cardinality") + ax.set_ylabel("Queries per second") + + # Make the top and right spines invisible + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + + # Set both axes to a logarithmic scale + ax.set_xscale("log") + ax.set_yscale("log") + + # Shrink y-axis by 20% + box = ax.get_position() + ax.set_position([box.x0, box.y0, box.width, box.height * 0.8]) + + # Put a legend above the plot + ax.legend( + loc="upper center", + bbox_to_anchor=(0.49, 1.45), + fancybox=True, + ncol=3, + ) + + # # Tighten the layout + # plt.tight_layout() + + # Save the figure + (out_dir / "plots").mkdir(parents=False, exist_ok=True) + fig.savefig(out_dir / "plots" / f"{dataset}_{cluster}_{k}_throughput.png", dpi=300) + + # Close the figure + plt.close(fig) + + +def plot_distance_counts( + *, + out_dir: pathlib.Path, + dataset: str, + algorithm: str, + k: int, + group: pandas.DataFrame, +) -> None: + """Plot the number of distance computations of all clusters.""" + title = f"{dataset} - {algorithm} - {k=}" + logger.info(f" Plotting Distance Counts {title}") + + # Create a figure and axis + fig: plt.Figure + ax: plt.Axes + m = 0.8 + fig, ax = plt.subplots(figsize=(6 * m, 6 * m)) + + # Group by the "cluster" column + for cluster, g in group.groupby("cluster"): + # Sort the rows by the "cardinality" column + g.sort_values("cardinality", inplace=True) + # Plot the "cardinality" column on the x-axis against the + # "mean_distance_computations" column on the y-axis + marker, color = CLUSTER_MARKERS_COLORS[cluster] + ax.plot( + g["cardinality"], + g["mean_distance_computations"], + label=f"{cluster}", + marker=marker, + color=color, + ) + + # Set the title and labels + ax.set_xlabel("Cardinality") + ax.set_ylabel("# Distance Computations") + + # Make the top and right spines invisible + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + + # Set both axes to a logarithmic scale + ax.set_xscale("log") + ax.set_yscale("log") + + # Shrink y-axis by 20% + box = ax.get_position() + 
ax.set_position([box.x0, box.y0, box.width, box.height * 0.8]) + + # Put a legend above the plot + ax.legend( + loc="upper center", + bbox_to_anchor=(0.49, 1.45), + fancybox=True, + ncol=2, + ) + + # # Tighten the layout + # plt.tight_layout() + + # Save the figure + out_path = out_dir / "distance_counts" / f"{dataset}_{algorithm}_{k}_counts.png" + out_path.parent.mkdir(parents=False, exist_ok=True) + fig.savefig(out_path, dpi=300) + + # Close the figure + plt.close(fig) + + # Plot the throughput of all algorithms. + fig, ax = plt.subplots(figsize=(6 * m, 6 * m)) + + # Group by the "cluster" column + for cluster, g in group.groupby("cluster"): + # Sort the rows by the "cardinality" column + g.sort_values("cardinality", inplace=True) + # Plot the "cardinality" column on the x-axis against the + # "mean_distance_computations" column on the y-axis + marker, color = CLUSTER_MARKERS_COLORS[cluster] + ax.plot( + g["cardinality"], + g["throughput"], + label=f"{cluster}", + marker=marker, + color=color, + ) + + # Set the title and labels + ax.set_xlabel("Cardinality") + ax.set_ylabel("Queries Per Second") + + # Make the top and right spines invisible + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + + # Set both axes to a logarithmic scale + ax.set_xscale("log") + ax.set_yscale("log") + + # Shrink y-axis by 20% + box = ax.get_position() + ax.set_position([box.x0, box.y0, box.width, box.height * 0.8]) + + # Put a legend above the plot + ax.legend( + loc="upper center", + bbox_to_anchor=(0.49, 1.45), + fancybox=True, + ncol=2, + ) + + # # Tighten the layout + # plt.tight_layout() + + # Save the figure + out_path = out_path.parent / f"{dataset}_{algorithm}_{k}_throughput.png" + fig.savefig(out_path, dpi=300) + + # Close the figure + plt.close(fig) + + +def plot_prop_percentiles( + *, + out_dir: pathlib.Path, + dataset: str, + ball_csv_path: pathlib.Path, + prop: str, + prop_alias: str, +): + logger.info( + f" Plotting {prop_alias} percentiles {dataset = } with ball = {ball_csv_path.name}" + ) + + col_tuples = [ + ("minimum", 0, "tab:pink", "dotted", 0.2 * 2), + (" 5th percentile", 5, "tab:brown", "dashed", 0.3 * 2), + ("25th percentile", 25, "tab:purple", "solid", 0.4 * 2), + ("median", 50, "tab:red", "solid", 0.5 * 2), + ("75th percentile", 75, "tab:green", "solid", 0.4 * 2), + ("95th percentile", 95, "tab:orange", "dashed", 0.3 * 2), + ("maximum", 100, "tab:blue", "dotted", 0.2 * 2), + ] + columns = [t[0] for t in col_tuples] + percentiles = [t[1] for t in col_tuples] + colors = [t[2] for t in col_tuples] + styles = [t[3] for t in col_tuples] + widths = [t[4] for t in col_tuples] + shades_alphas = [ + ("blue", 0.05 * 2), + ("orange", 0.1 * 2), + ("green", 0.2 * 2), + ("green", 0.2 * 2), + ("orange", 0.1 * 2), + ("blue", 0.05 * 2), + ] + + # Read the CSV file and manipulate the data as needed for specific properties + inp_df = pandas.read_csv(ball_csv_path) + if prop in ["radius", "fractal_density"]: + # Find the largest radius in the dataset + max_radius = inp_df["radius"].max() + # Divide all radii by the largest radius + inp_df["radius"] /= max_radius + + if prop == "fractal_density": + # If the radius is smaller than f32::epsilon, remove the row + inp_df = inp_df[inp_df["radius"] > 1e-6] + + # Fractal density is defined as cardinality / (radius) ^ lfd + # Add a new column for fractal density + inp_df["fractal_density"] = inp_df["cardinality"] / ( + inp_df["radius"] ** inp_df["lfd"] + ) + + # We will make a new dataframe in which each row is one depth level in the + 
# ball tree and the columns are the percentiles of the property values for + # that depth level. + prop_df = pandas.DataFrame(columns=columns) + + # Group by the "depth" column + for depth, group in inp_df.groupby("depth"): + if depth > 100: + continue + + prop_values = [] + # Each value must be repeated a number of times equal to the cardinality + # of the corresponding cluster + for _, row in group.iterrows(): + prop_values.extend([row[prop]] * row["cardinality"]) + + # Calculate the percentile values + percentile_values = list(map(float, numpy.percentile(prop_values, percentiles))) + # Add them to the dataframe + prop_df.loc[depth] = percentile_values + + logger.info(f" Created {prop} dataframe with {prop_df.shape[0]} rows") + logger.info(f" {prop_df.head(10)}") + + # Create a figure and axis + fig: plt.Figure + ax: plt.Axes + m = 0.8 + fig, ax = plt.subplots(figsize=(6 * m, 6 * m)) + + y_min = prop_df.max().max() + y_max = prop_df.min().min() + + # Plot the percentiles + x = prop_df.index + for i, col in enumerate(columns): + y = prop_df[col] + y_min = min(y_min, y.min()) + y_max = max(y_max, y.max()) + ax.plot( + x, y, label=col, color=colors[i], linestyle=styles[i], linewidth=widths[i] + ) + + # Shade the are between each pair of percentiles + for y_lower, y_upper, (color, alpha) in zip( + columns[:-1], columns[1:], shades_alphas + ): + ax.fill_between( + prop_df.index, + prop_df[y_lower], + prop_df[y_upper], + color=color, + alpha=alpha, + ) + + # Set the title and labels + ax.set_xlabel("Depth") + ax.set_ylabel(prop_alias) + + if prop == "lfd": + max_lfd = 21 if "random" in dataset else 13 + # Set the y-axis limit to (-1, max_lfd) + ax.set_ylim(-1, max_lfd) + # Set the y-ticks to [0, 2, 4, ..., max_lfd] + y_ticks = numpy.arange(0, max_lfd, 2) + ax.set_yticks(y_ticks) + # Add a horizontal line at each y-tick + for y in y_ticks: + ax.axhline(y, color="gray", linestyle="solid", linewidth=0.1) + elif prop == "radius": + if "silva" in dataset or "radio" in dataset: + # Set the y-axis to a logarithmic scale + ax.set_yscale("log") + else: + # Add a horizontal line at each of [0.0, 0.2, ..., 1.0] + y_ticks = [r / 10 for r in range(0, 11, 2)] + ax.set_yticks(y_ticks) + for y in y_ticks: + ax.axhline(y, color="gray", linestyle="solid", linewidth=0.1) + elif prop == "fractal_density": + # Set the y-axis to a logarithmic scale + ax.set_yscale("log") + + # Set a good y-axis limit + y_min = (10 ** math.floor(math.log10(y_min))) * 0.9 + y_max = (10 ** math.ceil(math.log10(y_max))) * 1.1 + ax.set_ylim(y_min, y_max) + + # Make the top and right spines invisible + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + + # Shrink y-axis by 20% + box = ax.get_position() + ax.set_position([box.x0, box.y0, box.width, box.height * 0.8]) + + # Put a legend above the plot + handles, labels = plt.gca().get_legend_handles_labels() + order = [0, 3, 6, 1, 5, 2, 4] + ax.legend( + [handles[idx] for idx in order], + [labels[idx] for idx in order], + loc="upper center", + bbox_to_anchor=(0.49, 1.45), + fancybox=True, + ncol=3, + ) + + # # Tighten the layout + # plt.tight_layout() + + # Save the figure + out_path = out_dir / prop / f"{dataset}.png" + out_path.parent.mkdir(parents=False, exist_ok=True) + logger.info(f" Saving to {out_path}") + fig.savefig(out_path, dpi=300) + + # Close the figure + plt.close(fig) diff --git a/benches/py-cakes/src/py_cakes/utils.py b/benches/py-cakes/src/py_cakes/utils.py new file mode 100644 index 000000000..e916992ad --- /dev/null +++ 
b/benches/py-cakes/src/py_cakes/utils.py
@@ -0,0 +1,10 @@
+"""Helpers for the package."""
+
+import logging
+
+
+def configure_logger(name: str, level: str) -> logging.Logger:
+    """Configure a logger with the given name and level."""
+    logger = logging.getLogger(name)
+    logger.setLevel(level)
+    return logger
diff --git a/benches/utils/Cargo.toml b/benches/utils/Cargo.toml
new file mode 100644
index 000000000..3ecc00071
--- /dev/null
+++ b/benches/utils/Cargo.toml
@@ -0,0 +1,19 @@
+[package]
+name = "bench-utils"
+version = "0.1.0"
+edition = "2021"
+
+[dependencies]
+clap = { version = "4.5.16", features = ["derive"] }
+hdf5 = { workspace = true }
+ndarray = { workspace = true }
+ftlog = { workspace = true }
+symagen = { workspace = true }
+distances = { workspace = true }
+abd-clam = { workspace = true, features = ["all"]}
+bitcode = { workspace = true }
+csv = { workspace = true }
+stringzilla = "3.9.5"
+bio = { workspace = true }
+rand = { workspace = true }
+rayon = { workspace = true }
diff --git a/benches/utils/src/ann_benchmarks/mod.rs b/benches/utils/src/ann_benchmarks/mod.rs
new file mode 100644
index 000000000..73f62ad04
--- /dev/null
+++ b/benches/utils/src/ann_benchmarks/mod.rs
@@ -0,0 +1,57 @@
+//! Helpers for running benchmarks on ANN datasets.
+
+use distances::Number;
+use rayon::prelude::*;
+
+mod reader;
+
+pub use reader::read;
+
+/// A helper for storing training and query data for `ann-benchmarks` datasets
+/// along with the ground truth nearest neighbors and distances.
+pub struct AnnDataset {
+    /// The data to use for clustering.
+    pub train: Vec>,
+    /// The queries to use for search.
+    pub queries: Vec>,
+    /// The true neighbors of each query, given as a tuple of:
+    /// * index into `train`, and
+    /// * distance to the query.
+    pub neighbors: Vec>,
+}
+
+impl AnnDataset {
+    /// Augment the dataset by adding noisy copies of the data.
+    #[must_use]
+    pub fn augment(mut self, multiplier: usize, error_rate: f32) -> Self {
+        ftlog::info!("Augmenting dataset to {multiplier}x...");
+        self.train = symagen::augmentation::augment_data(&self.train, multiplier, error_rate);
+
+        self
+    }
+
+    /// Generate a random dataset with the given cardinality and dimensionality.
+    #[must_use]
+    pub fn gen_random(cardinality: usize, n_copies: usize, dimensionality: usize, n_queries: usize, seed: u64) -> Self {
+        let train = (0..n_copies)
+            .into_par_iter()
+            .flat_map(|i| {
+                let seed = seed + i.as_u64();
+                symagen::random_data::random_tabular_seedable(cardinality, dimensionality, -1.0, 1.0, seed)
+            })
+            .collect::>();
+        let queries = symagen::random_data::random_tabular_seedable(
+            n_queries,
+            dimensionality,
+            -1.0,
+            1.0,
+            seed + n_copies.as_u64(),
+        );
+
+        Self {
+            train,
+            queries,
+            neighbors: Vec::new(),
+        }
+    }
+}
diff --git a/benches/utils/src/ann_benchmarks/reader.rs b/benches/utils/src/ann_benchmarks/reader.rs
new file mode 100644
index 000000000..5a53064f8
--- /dev/null
+++ b/benches/utils/src/ann_benchmarks/reader.rs
@@ -0,0 +1,155 @@
+//! Reading from the hdf5 datasets provided by the `ann-benchmarks` repository
+//! on GitHub.
+
+use super::AnnDataset;
+
+/// `ann-benchmarks` datasets are expected to have this many neighbors.
+const ANN_NUM_NEIGHBORS: usize = 100;
+
+/// Reads an `ann-benchmarks` dataset from the given path.
+///
+/// The dataset is expected to be an HDF5 group with the following members:
+///
+/// * `train`: The data to use for clustering.
+/// * `test`: The queries to use for search.
+/// * `neighbors`: The true neighbors of each query, given as indices into
+/// * `distances`: The distances from each query to its true neighbors. +/// +/// If the dataset is flattened (as for `Kosarak` and `MovieLens10M`), then we +/// expect the following additional members: +/// +/// * `size_train`: The lengths of each inner vector in `train`. +/// * `size_test`: The lengths of each inner vector in `test`. +/// +/// # Arguments +/// +/// * `path`: The path to the dataset. +/// * `flattened`: Whether to read the `train` and `test` datasets as flattened +/// vectors. +/// +/// # Returns +/// +/// The dataset, if it was read successfully. +/// +/// # Errors +/// +/// * If the dataset is not readable. +/// * If the dataset is not in the expected format. +pub fn read, T: hdf5::H5Type + Clone>( + path: &P, + flattened: bool, +) -> Result, String> { + let path = path.as_ref(); + + let file = hdf5::File::open(path).map_err(|e| e.to_string())?; + ftlog::info!("Opened file: {path:?}"); + + ftlog::info!("Reading raw train and test datasets..."); + let train_raw = file.dataset("train").map_err(|e| e.to_string())?; + let test_raw = file.dataset("test").map_err(|e| e.to_string())?; + + let (train, queries) = if flattened { + // Read flattened dataset + let train = train_raw.read_raw::().map_err(|e| e.to_string())?; + let size_train = file + .dataset("size_train") + .map_err(|e| e.to_string())? + .read_raw::() + .map_err(|e| e.to_string())?; + + // Read flattened queries + let queries = test_raw.read_raw::().map_err(|e| e.to_string())?; + let size_test = file + .dataset("size_test") + .map_err(|e| e.to_string())? + .read_raw::() + .map_err(|e| e.to_string())?; + + ftlog::info!("Un-flattening datasets..."); + (un_flatten(train, &size_train)?, un_flatten(queries, &size_test)?) + } else { + // Read 2d-array dataset + let train = train_raw + .read_2d::() + .map_err(|e| e.to_string())? + .rows() + .into_iter() + .map(|r| r.to_vec()) + .collect(); + + // Read 2d-array queries + let queries = test_raw + .read_2d::() + .map_err(|e| e.to_string())? + .rows() + .into_iter() + .map(|r| r.to_vec()) + .collect(); + + (train, queries) + }; + ftlog::info!("Parsed {} train and {} query items.", train.len(), queries.len()); + + ftlog::info!("Reading true neighbors and distances..."); + let neighbors = { + let neighbors = file.dataset("neighbors").map_err(|e| e.to_string())?; + let neighbors = neighbors.read_raw::().map_err(|e| e.to_string())?; + + let distances = file.dataset("distances").map_err(|e| e.to_string())?; + let distances = distances.read_raw::().map_err(|e| e.to_string())?; + + if neighbors.len() != distances.len() { + return Err(format!( + "`neighbors` and `distances` have different lengths! {} vs {}", + neighbors.len(), + distances.len() + )); + } + + let neighbors = neighbors.into_iter().zip(distances).collect(); + let sizes = vec![ANN_NUM_NEIGHBORS; queries.len()]; + un_flatten(neighbors, &sizes)? + }; + + Ok(AnnDataset { + train, + queries, + neighbors, + }) +} + +/// Un-flattens a vector of data into a vector of vectors. +/// +/// # Arguments +/// +/// * `data` - The data to un-flatten. +/// * `sizes` - The sizes of the inner vectors. +/// +/// # Returns +/// +/// A vector of vectors where each inner vector has the size specified in `sizes`. +/// +/// # Errors +/// +/// * If the number of elements in `data` is not equal to the sum of the elements in `sizes`. +fn un_flatten(data: Vec, sizes: &[usize]) -> Result>, String> { + let num_elements: usize = sizes.iter().sum(); + if data.len() != num_elements { + return Err(format!( + "Incorrect number of elements. 
Expected: {num_elements}. Found: {}.", + data.len() + )); + } + + let mut iter = data.into_iter(); + let mut items = Vec::with_capacity(sizes.len()); + for &s in sizes { + let mut inner = Vec::with_capacity(s); + for _ in 0..s { + inner.push(iter.next().ok_or("Not enough elements!")?); + } + items.push(inner); + } + Ok(items) +} diff --git a/benches/utils/src/fasta/mod.rs b/benches/utils/src/fasta/mod.rs new file mode 100644 index 000000000..b04d912d4 --- /dev/null +++ b/benches/utils/src/fasta/mod.rs @@ -0,0 +1,115 @@ +//! Utilities for dealing with FASTA files. + +use std::path::Path; + +use abd_clam::{dataset::AssociatesMetadataMut, FlatVec}; +use rand::prelude::*; + +/// Reads a FASTA file from the given path. +/// +/// # Arguments +/// +/// * `path`: The path to the FASTA file. +/// * `holdout`: The number of sequences to hold out for queries. +/// * `remove_gaps`: Whether to remove gaps from the sequences. +/// +/// # Returns +/// +/// * The training sequences in a `FlatVec`. +/// * A `Vec` of the queries. +/// +/// # Errors +/// +/// * If the file does not exist. +/// * If the extension is not `.fasta`. +/// * If the file cannot be read as a FASTA file. +/// * If any ID or sequence is empty. +#[allow(clippy::type_complexity)] +pub fn read>( + path: &P, + holdout: usize, + remove_gaps: bool, +) -> Result<(FlatVec, Vec<(String, String)>), String> { + let path = path.as_ref(); + if !path.exists() { + return Err(format!("Path {path:?} does not exist!")); + } + + if !path.extension().map_or(false, |ext| ext == "fasta") { + return Err(format!("Path {path:?} does not have the `.fasta` extension!")); + } + + ftlog::info!("Reading FASTA file from {path:?}."); + + let mut records = bio::io::fasta::Reader::from_file(path) + .map_err(|e| e.to_string())? + .records(); + + // Create accumulator for sequences and track the min and max lengths. + let mut seqs = Vec::new(); + let (mut min_len, mut max_len) = (usize::MAX, 0); + + // Check each record for an empty ID or sequence. + while let Some(Ok(record)) = records.next() { + let name = record.id().to_string(); + if name.is_empty() { + return Err(format!("Empty ID for record {}.", seqs.len())); + } + + let seq: String = if remove_gaps { + record + .seq() + .iter() + .filter(|&c| *c != b'-') + .map(|&c| c as char) + .collect() + } else { + record.seq().iter().map(|&c| c as char).collect() + }; + + if seq.is_empty() { + return Err(format!("Empty sequence for record {}.", seqs.len())); + } + + if seq.len() < min_len { + min_len = seq.len(); + } + if seq.len() > max_len { + max_len = seq.len(); + } + + // Add the sequence to the accumulator. + seqs.push((name, seq)); + + if seqs.len() % 10_000 == 0 { + // Log progress every 10,000 sequences. + ftlog::info!("Read {} sequences...", seqs.len()); + } + } + + if seqs.is_empty() { + return Err("No sequences found!".to_string()); + } + ftlog::info!("Read {} sequences.", seqs.len()); + + // Shuffle the sequences and split off the queries. + let queries = if holdout > 0 { + let mut rng = rand::thread_rng(); + seqs.shuffle(&mut rng); + seqs.split_off(seqs.len() - holdout) + } else { + Vec::new() + }; + ftlog::info!("Holding out {} queries.", queries.len()); + + // Unzip the IDs and sequences. + let (ids, seqs): (Vec<_>, Vec<_>) = seqs.into_iter().unzip(); + + // Create the FlatVec. + let data = FlatVec::new(seqs)? 
+ .with_dim_lower_bound(min_len) + .with_dim_upper_bound(max_len) + .with_metadata(&ids)?; + + Ok((data, queries)) +} diff --git a/benches/utils/src/lib.rs b/benches/utils/src/lib.rs new file mode 100644 index 000000000..34566e7c8 --- /dev/null +++ b/benches/utils/src/lib.rs @@ -0,0 +1,222 @@ +#![deny(clippy::correctness)] +#![warn( + missing_docs, + clippy::all, + clippy::suspicious, + clippy::style, + clippy::complexity, + clippy::perf, + clippy::pedantic, + clippy::nursery, + clippy::missing_docs_in_private_items, + clippy::unwrap_used, + clippy::expect_used, + clippy::panic, + clippy::cast_lossless +)] +//! Utilities for running benchmarks in CLAM. + +use ftlog::{ + appender::{FileAppender, Period}, + LevelFilter, LoggerGuard, +}; + +pub mod ann_benchmarks; +pub mod fasta; +pub mod metrics; +pub mod radio_ml; +pub mod reports; +pub mod types; + +pub use metrics::Complex; + +/// Configures the logger. +/// +/// # Errors +/// +/// - If a logs directory could not be located/created. +/// - If the logger could not be initialized. +pub fn configure_logger(file_name: &str) -> Result<(LoggerGuard, std::path::PathBuf), String> { + let root_dir = std::path::PathBuf::from(".") + .canonicalize() + .map_err(|e| e.to_string())?; + let logs_dir = root_dir.join("logs"); + if !logs_dir.exists() { + std::fs::create_dir(&logs_dir).map_err(|e| e.to_string())?; + } + let log_path = logs_dir.join(format!("{file_name}.log")); + + let writer = FileAppender::builder().path(&log_path).rotate(Period::Day).build(); + + let err_path = log_path.with_extension("err.log"); + + let guard = ftlog::Builder::new() + // global max log level + .max_log_level(LevelFilter::Info) + // define root appender, pass None would write to stderr + .root(writer) + // write `Debug` and higher logs in ftlog::appender to `err_path` instead of `log_path` + .filter("ftlog::appender", "ftlog-appender", LevelFilter::Debug) + .appender("ftlog-appender", FileAppender::new(err_path)) + .try_init() + .map_err(|e| e.to_string())?; + + Ok((guard, log_path)) +} + +/// The datasets for benchmarking. +#[derive(clap::ValueEnum, Debug, Clone)] +#[allow(non_camel_case_types, clippy::doc_markdown, clippy::module_name_repetitions)] +#[non_exhaustive] +pub enum RawData { + /// The DeepImage-1B dataset. + #[clap(name = "deep-image")] + DeepImage, + /// The Fashion-MNIST dataset. + #[clap(name = "fashion-mnist")] + FashionMNIST, + /// The GIST dataset. + #[clap(name = "gist")] + GIST, + /// The GloVe 25 dataset. + #[clap(name = "glove-25")] + GloVe_25, + /// The GloVe 50 dataset. + #[clap(name = "glove-50")] + GloVe_50, + /// The GloVe 100 dataset. + #[clap(name = "glove-100")] + GloVe_100, + /// The GloVe 200 dataset. + #[clap(name = "glove-200")] + GloVe_200, + /// The Kosarak dataset. + #[clap(name = "kosarak")] + Kosarak, + /// The LastFM dataset. + #[clap(name = "lastfm")] + LastFM, + /// The MNIST dataset. + #[clap(name = "mnist")] + MNIST, + /// The MovieLens-10M dataset. + #[clap(name = "movielens")] + MovieLens, + /// The NyTimes dataset. + #[clap(name = "nytimes")] + NyTimes, + /// The SIFT-1M dataset. + #[clap(name = "sift")] + SIFT, + /// A Random dataset with the same dimensions as the SIFT dataset. + #[clap(name = "random")] + Random, + /// The Silva-SSU-Ref dataset. + #[clap(name = "silva-ssu-ref")] + SilvaSSURef, + /// The RadioML dataset. + #[clap(name = "radio-ml")] + RadioML, +} + +impl RawData { + /// The name of the dataset. 
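+    ///
+    /// For example, `Self::GloVe_25.name()` returns `"glove-25"`, matching
+    /// the corresponding `clap` value name.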
+ #[must_use] + pub const fn name(&self) -> &str { + match self { + Self::DeepImage => "deep-image", + Self::FashionMNIST => "fashion-mnist", + Self::GIST => "gist", + Self::GloVe_25 => "glove-25", + Self::GloVe_50 => "glove-50", + Self::GloVe_100 => "glove-100", + Self::GloVe_200 => "glove-200", + Self::Kosarak => "kosarak", + Self::LastFM => "lastfm", + Self::MNIST => "mnist", + Self::MovieLens => "movielens", + Self::NyTimes => "nytimes", + Self::SIFT => "sift", + Self::Random => "random", + Self::SilvaSSURef => "silva-SSU-Ref", + Self::RadioML => "radio-ml", + } + } + + /// Read a vector dataset from the given directory. + /// + /// The path of the dataset will be inferred from the dataset name and the + /// given `inp_dir` as `inp_dir/{name}.hdf5`. + /// + /// # Errors + /// + /// * If the path does not exist. + /// * If the dataset is not readable. + /// * If the dataset is not in the expected format. + pub fn read_vector, T: hdf5::H5Type + Clone>( + &self, + inp_dir: &P, + ) -> Result, String> { + let path = inp_dir.as_ref().join(format!("{}.hdf5", self.name())); + if path.exists() { + ann_benchmarks::read(&path, self.is_flattened()) + } else { + Err(format!("Dataset {} not found: {path:?}", self.name())) + } + } + + /// Whether the dataset is flattened. + #[must_use] + const fn is_flattened(&self) -> bool { + matches!(self, Self::Kosarak | Self::MovieLens) + } + + /// Whether the dataset is tabular. + #[must_use] + pub const fn is_tabular(&self) -> bool { + matches!( + self, + Self::DeepImage + | Self::FashionMNIST + | Self::GIST + | Self::GloVe_25 + | Self::GloVe_50 + | Self::GloVe_100 + | Self::GloVe_200 + | Self::MNIST + | Self::NyTimes + | Self::SIFT + | Self::Random + ) + } + + /// Whether the dataset is of member-sets. + #[must_use] + pub const fn is_set(&self) -> bool { + matches!(self, Self::Kosarak | Self::MovieLens | Self::LastFM) + } + + /// Whether the dataset is of 'omic sequences. + #[must_use] + pub const fn is_sequence(&self) -> bool { + matches!(self, Self::SilvaSSURef) + } + + /// The name of the metric to use for the dataset. + #[must_use] + pub const fn metric(&self) -> &str { + match self { + Self::FashionMNIST | Self::GIST | Self::MNIST | Self::SIFT | Self::Random => "euclidean", + Self::DeepImage + | Self::GloVe_25 + | Self::GloVe_50 + | Self::GloVe_100 + | Self::GloVe_200 + | Self::LastFM + | Self::NyTimes => "cosine", + Self::Kosarak | Self::MovieLens => "jaccard", + Self::SilvaSSURef => "hamming", + Self::RadioML => "dtw", + } + } +} diff --git a/benches/utils/src/metrics/dtw.rs b/benches/utils/src/metrics/dtw.rs new file mode 100644 index 000000000..882d2208e --- /dev/null +++ b/benches/utils/src/metrics/dtw.rs @@ -0,0 +1,121 @@ +//! The Dynamic Time Warping distance metric and `Complex` number. 
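+//!
+//! For reference, the distance is computed with the standard
+//! dynamic-programming recurrence over a cost matrix `D` (this matches the
+//! `dtw_distance` implementation below):
+//!
+//! ```text
+//! D[0][0] = 0,    D[i][0] = D[0][j] = INF    for i, j > 0
+//! D[i][j] = |a[i-1] - b[j-1]| + min(D[i-1][j], D[i][j-1], D[i-1][j-1])
+//! dtw(a, b) = D[len(a)][len(b)]
+//! ```
+//!
+//! where `|x - y|` is the magnitude of the difference between two complex
+//! samples.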
+
+use abd_clam::{metric::ParMetric, Metric};
+use distances::number::Float;
+
+/// A complex number
+#[derive(Debug, Clone, Copy, bitcode::Decode, bitcode::Encode)]
+pub struct Complex<F: Float> {
+    /// Real part
+    re: F,
+    /// Imaginary part
+    im: F,
+}
+
+impl<F: Float> From<(F, F)> for Complex<F> {
+    fn from((re, im): (F, F)) -> Self {
+        Self { re, im }
+    }
+}
+
+impl<F: Float> Complex<F> {
+    /// Calculate the magnitude of the complex number
+    fn magnitude(&self) -> F {
+        (self.re * self.re + self.im * self.im).sqrt()
+    }
+
+    /// Calculate the absolute difference between two complex numbers
+    fn abs_diff(&self, other: &Self) -> F {
+        Self {
+            re: self.re - other.re,
+            im: self.im - other.im,
+        }
+        .magnitude()
+    }
+}
+
+/// Calculate Dynamic Time Warping distance between two sequences
+#[allow(clippy::module_name_repetitions)]
+pub fn dtw_distance<F: Float>(a: &[Complex<F>], b: &[Complex<F>]) -> F {
+    // Initialize the DP matrix; it has one more row and column than `a` and
+    // `b` have items, respectively.
+    let mut data = vec![vec![F::MAX; b.len() + 1]; a.len() + 1];
+    data[0][0] = F::ZERO;
+
+    // Calculate cost matrix. The loops run to `a.len()` and `b.len()`
+    // inclusive so that the indices into `data` stay within bounds.
+    for ai in 1..=a.len() {
+        for bi in 1..=b.len() {
+            let cost = a[ai - 1].abs_diff(&b[bi - 1]);
+            data[ai][bi] = cost + F::min(F::min(data[ai - 1][bi], data[ai][bi - 1]), data[ai - 1][bi - 1]);
+        }
+    }
+
+    // Return final cost
+    data[a.len()][b.len()]
+}
+
+/// The `Dynamic Time Warping` distance metric.
+pub struct Dtw;
+
+impl<I: AsRef<[Complex<f32>]>> Metric<I, f32> for Dtw {
+    fn distance(&self, a: &I, b: &I) -> f32 {
+        dtw_distance(a.as_ref(), b.as_ref())
+    }
+
+    fn name(&self) -> &str {
+        "dtw"
+    }
+
+    fn has_identity(&self) -> bool {
+        true
+    }
+
+    fn has_non_negativity(&self) -> bool {
+        true
+    }
+
+    fn has_symmetry(&self) -> bool {
+        true
+    }
+
+    fn obeys_triangle_inequality(&self) -> bool {
+        true
+    }
+
+    fn is_expensive(&self) -> bool {
+        false
+    }
+}
+
+impl<I: AsRef<[Complex<f32>]> + Send + Sync> ParMetric<I, f32> for Dtw {}
+
+impl<I: AsRef<[Complex<f64>]>> Metric<I, f64> for Dtw {
+    fn distance(&self, a: &I, b: &I) -> f64 {
+        dtw_distance(a.as_ref(), b.as_ref())
+    }
+
+    fn name(&self) -> &str {
+        "dtw"
+    }
+
+    fn has_identity(&self) -> bool {
+        true
+    }
+
+    fn has_non_negativity(&self) -> bool {
+        true
+    }
+
+    fn has_symmetry(&self) -> bool {
+        true
+    }
+
+    fn obeys_triangle_inequality(&self) -> bool {
+        true
+    }
+
+    fn is_expensive(&self) -> bool {
+        false
+    }
+}
+
+impl<I: AsRef<[Complex<f64>]> + Send + Sync> ParMetric<I, f64> for Dtw {}
diff --git a/benches/utils/src/metrics/jaccard.rs b/benches/utils/src/metrics/jaccard.rs
new file mode 100644
index 000000000..8ed5de23b
--- /dev/null
+++ b/benches/utils/src/metrics/jaccard.rs
@@ -0,0 +1,38 @@
+//! The `Jaccard` distance metric.
+
+use abd_clam::{metric::ParMetric, Metric};
+
+/// The `Jaccard` distance metric.
+pub struct Jaccard;
+
+impl> Metric for Jaccard {
+    fn distance(&self, a: &I, b: &I) -> f32 {
+        distances::sets::jaccard(a.as_ref(), b.as_ref())
+    }
+
+    fn name(&self) -> &str {
+        "jaccard"
+    }
+
+    fn has_identity(&self) -> bool {
+        true
+    }
+
+    fn has_non_negativity(&self) -> bool {
+        true
+    }
+
+    fn has_symmetry(&self) -> bool {
+        true
+    }
+
+    fn obeys_triangle_inequality(&self) -> bool {
+        true
+    }
+
+    fn is_expensive(&self) -> bool {
+        false
+    }
+}
+
+impl + Send + Sync> ParMetric for Jaccard {}
diff --git a/benches/utils/src/metrics/levenshtein.rs b/benches/utils/src/metrics/levenshtein.rs
new file mode 100644
index 000000000..df296344a
--- /dev/null
+++ b/benches/utils/src/metrics/levenshtein.rs
@@ -0,0 +1,39 @@
+//! The `Levenshtein` edit distance metric.
+
+use abd_clam::{metric::ParMetric, Metric};
+use distances::Number;
+
+/// The `Levenshtein` edit distance metric.
+pub struct Levenshtein; + +impl, T: Number> Metric for Levenshtein { + fn distance(&self, a: &I, b: &I) -> T { + T::from(stringzilla::sz::edit_distance(a.as_ref(), b.as_ref())) + } + + fn name(&self) -> &str { + "levenshtein" + } + + fn has_identity(&self) -> bool { + true + } + + fn has_non_negativity(&self) -> bool { + true + } + + fn has_symmetry(&self) -> bool { + true + } + + fn obeys_triangle_inequality(&self) -> bool { + true + } + + fn is_expensive(&self) -> bool { + true + } +} + +impl + Send + Sync, T: Number> ParMetric for Levenshtein {} diff --git a/benches/utils/src/metrics/mod.rs b/benches/utils/src/metrics/mod.rs new file mode 100644 index 000000000..a93508b9f --- /dev/null +++ b/benches/utils/src/metrics/mod.rs @@ -0,0 +1,9 @@ +//! Distance metrics for benchmarks. + +mod dtw; +mod jaccard; +mod levenshtein; + +pub use dtw::{dtw_distance, Complex, Dtw}; +pub use jaccard::Jaccard; +pub use levenshtein::Levenshtein; diff --git a/benches/utils/src/radio_ml/mod.rs b/benches/utils/src/radio_ml/mod.rs new file mode 100644 index 000000000..baeafd8ba --- /dev/null +++ b/benches/utils/src/radio_ml/mod.rs @@ -0,0 +1,7 @@ +//! Dealing with the `RadioML` dataset. + +mod modulation_mode; +mod reader; + +pub use modulation_mode::ModulationMode; +pub use reader::read_mod; diff --git a/benches/utils/src/radio_ml/modulation_mode.rs b/benches/utils/src/radio_ml/modulation_mode.rs new file mode 100644 index 000000000..6505e07fe --- /dev/null +++ b/benches/utils/src/radio_ml/modulation_mode.rs @@ -0,0 +1,101 @@ +//! The modulation modes in the `RadioML` dataset. + +/// The modulation modes in the `RadioML` dataset. +#[allow(non_camel_case_types, missing_docs)] +pub enum ModulationMode { + APSK_128, + APSK_64, + APSK_32, + APSK_16, + PSK_32, + PSK_16, + PSK_8, + ASK_8, + ASK_4, + QAM_256, + QAM_128, + QAM_64, + QAM_32, + QAM_16, + AM_DSB_SC, + AM_DSB_WC, + AM_SSB_SC, + AM_SSB_WC, + BPSK, + FM, + GMSK, + OOK, + OQPSK, + QPSK, +} + +impl ModulationMode { + /// Returns all the modulation modes in the `RadioML` dataset. + #[must_use] + pub const fn all() -> [Self; 24] { + [ + Self::APSK_128, + Self::APSK_64, + Self::APSK_32, + Self::APSK_16, + Self::PSK_32, + Self::PSK_16, + Self::PSK_8, + Self::ASK_8, + Self::ASK_4, + Self::QAM_256, + Self::QAM_128, + Self::QAM_64, + Self::QAM_32, + Self::QAM_16, + Self::AM_DSB_SC, + Self::AM_DSB_WC, + Self::AM_SSB_SC, + Self::AM_SSB_WC, + Self::BPSK, + Self::FM, + Self::GMSK, + Self::OOK, + Self::OQPSK, + Self::QPSK, + ] + } + + /// Returns the name of the modulation mode. + #[must_use] + pub const fn name(&self) -> &str { + match self { + Self::APSK_128 => "128APSK", + Self::APSK_64 => "64APSK", + Self::APSK_32 => "32APSK", + Self::APSK_16 => "16APSK", + Self::PSK_32 => "32PSK", + Self::PSK_16 => "16PSK", + Self::PSK_8 => "8PSK", + Self::ASK_8 => "8ASK", + Self::ASK_4 => "4ASK", + Self::QAM_256 => "256QAM", + Self::QAM_128 => "128QAM", + Self::QAM_64 => "64QAM", + Self::QAM_32 => "32QAM", + Self::QAM_16 => "16QAM", + Self::AM_DSB_SC => "AM-DSB-SC", + Self::AM_DSB_WC => "AM-DSB-WC", + Self::AM_SSB_SC => "AM-SSB-SC", + Self::AM_SSB_WC => "AM-SSB-WC", + Self::BPSK => "BPSK", + Self::FM => "FM", + Self::GMSK => "GMSK", + Self::OOK => "OOK", + Self::OQPSK => "OQPSK", + Self::QPSK => "QPSK", + } + } + + /// Returns the path to the `h5` file containing the modulation mode. 
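+    ///
+    /// For example, `ModulationMode::BPSK.h5_path(&"data")` yields
+    /// `data/mod_BPSK.h5`.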
+    pub fn h5_path<P: AsRef<std::path::Path>>(&self, inp_dir: &P) -> std::path::PathBuf {
+        let mut path = inp_dir.as_ref().to_path_buf();
+        path.push(format!("mod_{}.h5", self.name()));
+        path
+    }
+}
diff --git a/benches/utils/src/radio_ml/reader.rs b/benches/utils/src/radio_ml/reader.rs
new file mode 100644
index 000000000..4b1c55fb7
--- /dev/null
+++ b/benches/utils/src/radio_ml/reader.rs
@@ -0,0 +1,112 @@
+//! Reading the `RadioML` dataset.
+
+use ndarray::prelude::*;
+
+use crate::Complex;
+
+/// Returns the 26 SNR levels in the `RadioML` dataset.
+///
+/// These are (-20..=30) dB in steps of 2 dB.
+fn snr_levels() -> [i32; 26] {
+    (-20..=30)
+        .step_by(2)
+        .collect::<Vec<_>>()
+        .try_into()
+        .unwrap_or_else(|_| unreachable!("We have the correct number of iterations."))
+}
+
+/// Reads a single modulation mode from the `RadioML` dataset.
+///
+/// # Arguments
+///
+/// * `inp_dir` - The input directory containing the `RadioML` dataset.
+/// * `mode` - The modulation mode to read.
+/// * `snr` - The SNR level to read. If `None`, all SNR levels are read.
+///
+/// # Returns
+///
+/// A vector of samples for the given modulation mode and SNR level, each
+/// sample being a vector of 1_024 complex values.
+///
+/// # Errors
+///
+/// - If the file does not exist.
+/// - If the dataset does not exist.
+/// - If the dataset has the wrong shape.
+/// - If the dataset cannot be read.
+/// - If the SNR level is not found.
+pub fn read_mod<P: AsRef<std::path::Path>>(
+    inp_dir: &P,
+    mode: &super::ModulationMode,
+    snr: Option<i32>,
+) -> Result<Vec<Vec<Complex<f32>>>, String> {
+    let h5_path = mode.h5_path(inp_dir);
+    if !h5_path.exists() {
+        return Err(format!("File {h5_path:?} does not exist!"));
+    }
+
+    let file = hdf5::File::open(&h5_path).map_err(|e| format!("Error opening file: {e}"))?;
+    ftlog::debug!("Opened file {h5_path:?}: {file:?}");
+
+    let data_raw = file.dataset("X").map_err(|e| format!("Error opening dataset: {e}"))?;
+    ftlog::debug!("Opened dataset X: {data_raw:?}");
+
+    let train = data_raw
+        .read_dyn::<f32>()
+        .map_err(|e| format!("Error reading dataset: {e}"))?;
+    ftlog::debug!("Read dataset X with shape: {:?}", train.shape());
+
+    // `train` should be a 3D array with shape `[106_496, 1_024, 2]`. Each
+    // sample is a 2D array with shape `[1_024, 2]` containing the real and
+    // imaginary parts of the signal.
+    if train.ndim() != 3 {
+        return Err(format!("Expected 3D array, got {}D array!", train.ndim()));
+    }
+    if train.shape() != [106_496, 1_024, 2] {
+        return Err(format!("Expected shape [106_496, 1_024, 2], got {:?}!", train.shape()));
+    }
+
+    // Convert the array from a dynamic array to a 3D array.
+    let train = train
+        .into_shape_with_order([106_496, 1_024, 2])
+        .map_err(|e| format!("Error reshaping dataset: {e}"))?;
+
+    // The samples are stored in chunks of 4_096 samples, each corresponding to
+    // a different SNR level. We need to extract the samples corresponding to
+    // the given SNR level if it is provided.
+    let train = if let Some(snr) = snr {
+        ftlog::debug!("Extracting samples for SNR level {snr}.");
+        let snr_idx = snr_levels()
+            .into_iter()
+            .position(|x| x == snr)
+            .ok_or_else(|| format!("SNR level {snr} not found!"))?;
+        let start = snr_idx * 4_096;
+        let end = start + 4_096;
+        train.slice(s![start..end, .., ..]).to_owned()
+    } else {
+        train
+    };
+    ftlog::debug!("Extracted samples for SNR level: {:?}", train.shape());
+
+    // For each sample, pair the real and imaginary parts into `Complex` values,
+    // so that each sample becomes a vector of 1_024 complex numbers.
+ let train = train + .outer_iter() + .map(|sample| { + sample + .rows() + .into_iter() + .map(|r| Complex::::from((r[0], r[1]))) + .collect::>() + }) + .collect::>(); + ftlog::debug!( + "Converted dataset from complex to real: {:?}", + (train.len(), train[0].len()) + ); + + ftlog::info!("Read {} signals for {}.", train.len(), mode.name()); + + Ok(train) +} diff --git a/benches/utils/src/reports/cakes.rs b/benches/utils/src/reports/cakes.rs new file mode 100644 index 000000000..8f2f0467f --- /dev/null +++ b/benches/utils/src/reports/cakes.rs @@ -0,0 +1,299 @@ +//! Reports for Cakes results. + +use distances::Number; + +/// Reports for Cakes results search. +pub struct Results { + /// The dataset name. + dataset: String, + /// The cardinality of the dataset. + cardinality: usize, + /// The dimensionality of the dataset. + dimensionality: usize, + /// The metric used. + metric: String, + /// A vector of: + /// + /// - The name of the `Cluster` type. + /// - The name of the search algorithm. + /// - The value of radius. + /// - The mean time taken (seconds per query) to perform the search. + /// - The mean throughput (queries per second). + /// - The mean number of hits. + /// - The mean recall. + /// - The mean number of distance computations per query. + #[allow(clippy::type_complexity)] + radial_results: Vec<(String, String, T, f32, f32, f32, f32, f32)>, + /// A vector of: + /// + /// - The name of the `Cluster` type. + /// - The name of the search algorithm. + /// - The value of k. + /// - The mean time taken (seconds per query) to perform the search. + /// - The mean throughput (queries per second). + /// - The mean number of hits. + /// - The mean recall. + /// - The mean number of distance computations per query. + #[allow(clippy::type_complexity)] + k_results: Vec<(String, String, usize, f32, f32, f32, f32, f32)>, +} + +impl Results { + /// Create a new report. + #[must_use] + pub fn new(data_name: &str, cardinality: usize, dimensionality: usize, metric: &str) -> Self { + Self { + dataset: data_name.to_string(), + cardinality, + dimensionality, + metric: metric.to_string(), + radial_results: Vec::new(), + k_results: Vec::new(), + } + } + + /// Add a new result for radial search. + #[allow(clippy::too_many_arguments)] + pub fn append_radial_result( + &mut self, + cluster: &str, + algorithm: &str, + radius: T, + time: f32, + throughput: f32, + output_sizes: &[usize], + recalls: &[f32], + distance_count: f32, + ) { + let mean_output_size = abd_clam::utils::mean(output_sizes); + let mean_recall = abd_clam::utils::mean(recalls); + self.radial_results.push(( + cluster.to_string(), + algorithm.to_string(), + radius, + time, + throughput, + mean_output_size, + mean_recall, + distance_count, + )); + self.log_last_radial(); + } + + /// Add a new result for k-NN search. + #[allow(clippy::too_many_arguments)] + pub fn append_k_result( + &mut self, + cluster: &str, + algorithm: &str, + k: usize, + time: f32, + throughput: f32, + output_sizes: &[usize], + recalls: &[f32], + distance_count: f32, + ) { + let mean_output_size = abd_clam::utils::mean(output_sizes); + let mean_recall = abd_clam::utils::mean(recalls); + self.k_results.push(( + cluster.to_string(), + algorithm.to_string(), + k, + time, + throughput, + mean_output_size, + mean_recall, + distance_count, + )); + self.log_last_k(); + } + + /// Logs the last radial record. 
+ fn log_last_radial(&self) { + let mut parts = vec![ + format!("Dataset: {}", self.dataset), + format!("Cardinality: {}", self.cardinality), + format!("Dimensionality: {}", self.dimensionality), + format!("Metric: {}", self.metric), + ]; + if let Some((cluster, algorithm, radius, time, throughput, output_size, recall, distance_computations)) = + self.radial_results.last() + { + parts.push(format!("Cluster: {cluster}")); + parts.push(format!("Algorithm: {algorithm}")); + parts.push(format!("Radius: {radius}")); + parts.push(format!("Time: {time}")); + parts.push(format!("Throughput: {throughput}")); + parts.push(format!("Output size: {output_size}")); + parts.push(format!("Recall: {recall}")); + parts.push(format!("Distance computations: {distance_computations}")); + } + + ftlog::info!("{}", parts.join(", ")); + } + + /// Logs the last k record. + fn log_last_k(&self) { + let mut parts = vec![ + format!("Dataset: {}", self.dataset), + format!("Cardinality: {}", self.cardinality), + format!("Dimensionality: {}", self.dimensionality), + format!("Metric: {}", self.metric), + ]; + if let Some((cluster, algorithm, k, time, throughput, output_size, recall, distance_computations)) = + self.k_results.last() + { + parts.push(format!("Cluster: {cluster}")); + parts.push(format!("Algorithm: {algorithm}")); + parts.push(format!("k: {k}")); + parts.push(format!("Time: {time}")); + parts.push(format!("Throughput: {throughput}")); + parts.push(format!("Output size: {output_size}")); + parts.push(format!("Recall: {recall}")); + parts.push(format!("Distance computations: {distance_computations}")); + } + + ftlog::info!("{}", parts.join(", ")); + } + + /// Write the report to a csv file. + /// + /// # Errors + /// + /// - If the file cannot be created. + /// - If the header cannot be written. + /// - If a record cannot be written. + pub fn write_to_csv>(&self, dir: P) -> Result<(), String> { + let name = format!( + "{}_{}_{}_{}.csv", + self.dataset, self.cardinality, self.dimensionality, self.metric + ); + + let header = [ + "cluster", + "algorithm", + "radius", + "k", + "time", + "throughput", + "mean_output_size", + "mean_recall", + "mean_distance_computations", + ]; + + let path = dir.as_ref().join(name); + let mut writer = csv::Writer::from_path(path).map_err(|e| e.to_string())?; + writer.write_record(header).map_err(|e| e.to_string())?; + + for (cluster, algorithm, radius, time, throughput, output_size, recall, distance_computations) in + &self.radial_results + { + writer + .write_record([ + cluster, + algorithm, + &radius.to_string(), + "", + &time.to_string(), + &throughput.to_string(), + &output_size.to_string(), + &recall.to_string(), + &distance_computations.to_string(), + ]) + .map_err(|e| e.to_string())?; + } + + for (cluster, algorithm, k, time, throughput, output_size, recall, distance_computations) in &self.k_results { + writer + .write_record([ + cluster, + algorithm, + "", + &k.to_string(), + &time.to_string(), + &throughput.to_string(), + &output_size.to_string(), + &recall.to_string(), + &distance_computations.to_string(), + ]) + .map_err(|e| e.to_string())?; + } + + Ok(()) + } + + /// Read the report from a csv file. + /// + /// # Errors + /// + /// - If the file name is invalid. + /// - If the record length is invalid. + /// - If the record parts cannot be parsed. 
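+    ///
+    /// The file name is expected to follow the pattern produced by
+    /// `write_to_csv`, i.e. `{dataset}_{cardinality}_{dimensionality}_{metric}.csv`
+    /// (for example, `sift_1000000_128_euclidean.csv`).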
+    pub fn read_from_csv<P: AsRef<std::path::Path>>(path: P) -> Result<Self, String> {
+        let name = path.as_ref().file_stem().ok_or("Could not get file stem")?;
+        let name = name.to_string_lossy().to_string();
+        let name_parts = name.split('_').collect::<Vec<_>>();
+        if name_parts.len() != 4 {
+            return Err(format!(
+                "Invalid file name. Should have 4 parts separated by underscores: {name_parts:?}"
+            ));
+        }
+        let dataset = name_parts[0].to_string();
+        let cardinality = name_parts[1].parse::<usize>().map_err(|e| e.to_string())?;
+        let dimensionality = name_parts[2].parse::<usize>().map_err(|e| e.to_string())?;
+        let metric = name_parts[3].to_string();
+
+        let mut reader = csv::Reader::from_path(path).map_err(|e| e.to_string())?;
+        let mut radial_results = Vec::new();
+        let mut k_results = Vec::new();
+
+        for record in reader.records() {
+            let record = record.map_err(|e| e.to_string())?;
+            // Each record has the nine columns written by `write_to_csv`, with
+            // either the `radius` or the `k` column left empty.
+            if record.len() != 9 {
+                return Err(format!("Invalid record length. Should have 9 parts: {record:?}"));
+            }
+            let cluster = record[0].to_string();
+            let algorithm = record[1].to_string();
+            let time = record[4].parse::<f32>().map_err(|e| e.to_string())?;
+            let throughput = record[5].parse::<f32>().map_err(|e| e.to_string())?;
+            let output_size = record[6].parse::<f32>().map_err(|e| e.to_string())?;
+            let recall = record[7].parse::<f32>().map_err(|e| e.to_string())?;
+            let distance_computations = record[8].parse::<f32>().map_err(|e| e.to_string())?;
+
+            if let Ok(radius) = T::from_str(&record[2]) {
+                radial_results.push((
+                    cluster,
+                    algorithm,
+                    radius,
+                    time,
+                    throughput,
+                    output_size,
+                    recall,
+                    distance_computations,
+                ));
+            } else if let Ok(k) = record[3].parse::<usize>() {
+                k_results.push((
+                    cluster,
+                    algorithm,
+                    k,
+                    time,
+                    throughput,
+                    output_size,
+                    recall,
+                    distance_computations,
+                ));
+            } else {
+                return Err("Could not parse a radius or k from the record".to_string());
+            }
+        }
+
+        Ok(Self {
+            dataset,
+            cardinality,
+            dimensionality,
+            metric,
+            radial_results,
+            k_results,
+        })
+    }
+}
diff --git a/benches/utils/src/reports/mod.rs b/benches/utils/src/reports/mod.rs
new file mode 100644
index 000000000..df51272dc
--- /dev/null
+++ b/benches/utils/src/reports/mod.rs
@@ -0,0 +1,5 @@
+//! Serializable reports for results from benchmarks.
+
+mod cakes;
+
+pub use cakes::Results as CakesResults;
diff --git a/benches/utils/src/types.rs b/benches/utils/src/types.rs
new file mode 100644
index 000000000..46f58b03d
--- /dev/null
+++ b/benches/utils/src/types.rs
@@ -0,0 +1,64 @@
+//! Helper types for benchmarks.
+
+use abd_clam::pancakes::{Decodable, Encodable};
+use distances::Number;
+
+/// A wrapper around a vector for use in benchmarks.
+#[derive(Clone, bitcode::Encode, bitcode::Decode)]
+pub struct Row<F>(Vec<F>);
+
+impl<F> Row<F> {
+    /// Converts a `Row` to a vector.
+ #[allow(dead_code)] + #[must_use] + pub fn to_vec(v: Self) -> Vec { + v.0 + } +} + +impl From> for Row { + fn from(v: Vec) -> Self { + Self(v) + } +} + +impl FromIterator for Row { + fn from_iter>(iter: T) -> Self { + Self(iter.into_iter().collect()) + } +} + +impl AsRef<[F]> for Row { + fn as_ref(&self) -> &[F] { + &self.0 + } +} + +impl Encodable for Row { + fn as_bytes(&self) -> Box<[u8]> { + self.0 + .iter() + .flat_map(|v| v.to_le_bytes()) + .collect::>() + .into_boxed_slice() + } + + fn encode(&self, reference: &Self) -> Box<[u8]> { + let diffs = reference.0.iter().zip(self.0.iter()).map(|(&a, &b)| a - b).collect(); + Self::as_bytes(&diffs) + } +} + +impl Decodable for Row { + fn from_bytes(bytes: &[u8]) -> Self { + bytes + .chunks_exact(std::mem::size_of::()) + .map(F::from_le_bytes) + .collect() + } + + fn decode(reference: &Self, bytes: &[u8]) -> Self { + let diffs = Self::from_bytes(bytes); + reference.0.iter().zip(diffs.0.iter()).map(|(&a, &b)| a - b).collect() + } +} diff --git a/crates/abd-clam/.bumpversion.cfg b/crates/abd-clam/.bumpversion.cfg index 80ff46742..0a16acfb6 100644 --- a/crates/abd-clam/.bumpversion.cfg +++ b/crates/abd-clam/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.31.0 +current_version = 0.32.0 commit = False tag = False parse = (?P\d+)\.(?P\d+)\.(?P\d+)(\-(?P[a-z]+)(?P\d+))? diff --git a/crates/abd-clam/Cargo.toml b/crates/abd-clam/Cargo.toml index 27fca03d5..0893ef5f2 100644 --- a/crates/abd-clam/Cargo.toml +++ b/crates/abd-clam/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "abd-clam" -version = "0.31.0" +version = "0.32.0" authors = [ "Najib Ishaq ", "Tom Howard ", @@ -10,7 +10,7 @@ authors = [ "Oliver McLaughlin ", ] edition = "2021" -rust-version = "1.79" +rust-version = "1.81" # using: `total_cmp` description = "Clustering, Learning and Approximation with Manifolds" license = "MIT" readme = "./README.md" @@ -29,14 +29,7 @@ publish = true distances = { workspace = true } rayon = { workspace = true } rand = { workspace = true } -serde = { workspace = true } -flate2 = { workspace = true } ftlog = { workspace = true } -serde_arrays = "0.1.0" - -# For: -# - IO of datasets -csv = { version = "1.3.0" , optional = true } # For: # - IO from npy files @@ -46,35 +39,46 @@ ndarray-npy = { workspace = true, optional = true } # For: # - CHAODA -smartcore = { git = "https://github.com/smartcorelib/smartcore.git", rev = "239c00428f7448d30b78bf8653923f6bc0e2c29b", features = ["serde"], optional = true } -# linfa = { version = "0.7.0", features = ["serde"], optional = true } -# linfa-linear = { version = "0.7.0", features = ["serde"], optional = true } -ordered-float = { version = "4.2.2", optional = true } -bincode = { workspace = true, optional = true } +smartcore = { version = "0.4", features = ["serde"], optional = true } + +# For: +# - MSA +stringzilla = { workspace = true, optional = true } + +# For: +# - Disk I/O for most clusters and datasets +serde = { workspace = true, optional = true } +bitcode = { workspace = true, optional = true } +flate2 = { workspace = true, optional = true } +# - Reading and writing CSV files from clusters' properties and some datasets +csv = { workspace = true , optional = true } [dev-dependencies] symagen = { workspace = true } -bincode = { workspace = true } -criterion = { version = "0.5.1", features = ["html_reports"] } +bitcode = { workspace = true } +criterion = { version = "0.5", features = ["html_reports"] } tempdir = "0.3.7" -float-cmp = "0.9.0" +float-cmp = "0.10.0" test-case = "3.2.1" statistical = 
"1.0.0" [features] -csv = ["dep:csv"] # For writing trees to CSV files. -ndarray-bindings = ["dep:ndarray", "dep:ndarray-npy"] -chaoda = ["dep:smartcore", "dep:ordered-float", "dep:bincode"] +disk-io = ["dep:serde", "dep:csv", "dep:bitcode", "dep:flate2", "dep:ndarray", "dep:ndarray-npy"] +chaoda = ["dep:smartcore"] +mbed = ["chaoda"] +msa = ["dep:stringzilla"] +all = ["disk-io", "chaoda", "mbed", "msa"] [[bench]] -name = "genomic_search" +name = "vector_search" harness = false [[bench]] -name = "vector_search" +name = "genomic_search" harness = false +required-features = ["msa"] [[bench]] name = "ann_benchmarks" harness = false -required-features = ["ndarray-bindings"] +required-features = ["disk-io"] diff --git a/crates/abd-clam/README.md b/crates/abd-clam/README.md index 74e11aeb7..0d4884f88 100644 --- a/crates/abd-clam/README.md +++ b/crates/abd-clam/README.md @@ -1,4 +1,4 @@ -# CLAM: Clustering, Learning and Approximation with Manifolds (v0.31.0) +# CLAM: Clustering, Learning and Approximation with Manifolds (v0.32.0) The Rust implementation of CLAM. @@ -7,86 +7,82 @@ This means that the API is not yet stable and breaking changes may occur frequen ## Usage -CLAM is a library crate so you can add it to your crate using `cargo add abd_clam@0.31.0`. +CLAM is a library crate so you can add it to your crate using `cargo add abd_clam@0.32.0`. -### Cakes: Nearest Neighbor Search +## Features + +This crate provides the following features: + +- `disk-io`: Enables easy IO for several structs, primarily using `bitcode` and `serde`. +- `chaoda`: Enables anomaly detection using the CHAODA. +- `msa`: Enables multiple sequence alignment. +- `mbed`: Enables dimensionality reduction algorithms. +- `all`: Enables all features. + +### `Cakes`: Nearest Neighbor Search ```rust use abd_clam::{ - cakes::{cluster::Searchable, Algorithm}, - Ball, Cluster, FlatVec, Metric, Partition, + cakes::{self, SearchAlgorithm}, + cluster::Partition, + dataset::AssociatesMetadataMut, + Ball, Cluster, FlatVec, }; use rand::prelude::*; -/// The distance function with with to perform clustering and search. -/// -/// We use the `distances` crate for the distance function. -fn euclidean(x: &Vec, y: &Vec) -> f32 { - distances::simd::euclidean_f32(x, y) -} - -/// Generate some random data. You can use your own data here. -/// -/// CLAM can handle arbitrarily large datasets. We use a small one here for -/// demonstration. -/// -/// We use the `symagen` crate for generating interesting datasets for examples -/// and tests. +// Generate some random data. You can use your own data here. +// +// CLAM can handle arbitrarily large datasets. We use a small one here for +// demonstration. +// +// We use the `symagen` crate for generating interesting datasets for examples +// and tests. let seed = 42; let mut rng = rand::rngs::StdRng::seed_from_u64(seed); let (cardinality, dimensionality) = (1_000, 10); let (min_val, max_val) = (-1.0, 1.0); -let rows: Vec> = symagen::random_data::random_tabular( - cardinality, - dimensionality, - min_val, - max_val, - &mut rng, -); +let rows: Vec> = + symagen::random_data::random_tabular(cardinality, dimensionality, min_val, max_val, &mut rng); // We will generate some random labels for each point. let labels: Vec = rows.iter().map(|v| v[0] > 0.0).collect(); -// We have to create a `Metric` object to encapsulate the distance function and its properties. -let metric = Metric::new(euclidean, false); +// We use the `Euclidean` metric for this example. 
+let metric = abd_clam::metric::Euclidean; -// We can create a `Dataset` object. We make it mutable here so we can reorder it after building the tree. -let data = FlatVec::new(rows, metric).unwrap(); +// We can create a `Dataset` object and assign metadata. +let data = FlatVec::new(rows).unwrap().with_metadata(&labels).unwrap(); -// We can assign the labels as metadata to the dataset. -let data = data.with_metadata(labels).unwrap(); - -// We define the criteria for building the tree to partition the `Cluster`s until each contains a single point. -let criteria = |c: &Ball<_, _, _>| c.cardinality() > 1; +// We define the criteria for building the tree to partition the `Cluster`s +// until each contains a single point. +let criteria = |c: &Ball<_>| c.cardinality() > 1; // Now we create a tree. -let root = Ball::new_tree(&data, &criteria, Some(seed)); +let root = Ball::new_tree(&data, &metric, &criteria, Some(seed)); // We will use the origin as our query. -let query: Vec = vec![0.0; dimensionality]; +let query = vec![0_f32; dimensionality]; // We can now perform Ranged Nearest Neighbors search on the tree. let radius = 0.05; -let alg = Algorithm::RnnClustered(radius); -let rnn_results: Vec<(usize, f32)> = root.search(&data, &query, alg); +let alg = cakes::RnnClustered(radius); +let rnn_results: Vec<(usize, f32)> = alg.search(&data, &metric, &root, &query); // KNN search is also supported. let k = 10; -// The `KnnRepeatedRnn` algorithm starts RNN search with a small radius and increases it until it finds `k` neighbors. -let alg = Algorithm::KnnRepeatedRnn(k, 2.0); -let knn_results: Vec<(usize, f32)> = root.search(&data, &query, alg); +// The `KnnRepeatedRnn` algorithm starts RNN search with a small radius and +// increases it until it finds `k` neighbors. +let alg = cakes::KnnRepeatedRnn(k, 2.0); +let knn_results: Vec<(usize, f32)> = alg.search(&data, &metric, &root, &query); // The `KnnBreadthFirst` algorithm searches the tree in a breadth-first manner. -let alg = Algorithm::KnnBreadthFirst(k); -let knn_results: Vec<(usize, f32)> = root.search(&data, &query, alg); +let alg = cakes::KnnBreadthFirst(k); +let knn_results: Vec<(usize, f32)> = alg.search(&data, &metric, &root, &query); // The `KnnDepthFirst` algorithm searches the tree in a depth-first manner. -let alg = Algorithm::KnnDepthFirst(k); -let knn_results: Vec<(usize, f32)> = root.search(&data, &query, alg); - -// We can borrow the reordered labels from the model. -let labels: &[bool] = data.metadata(); +let alg = cakes::KnnDepthFirst(k); +let knn_results: Vec<(usize, f32)> = alg.search(&data, &metric, &root, &query); // We can use the results to get the labels of the points that are within the // radius of the query point. @@ -97,19 +93,28 @@ let rnn_labels: Vec = rnn_results.iter().map(|&(i, _)| labels[i]).collect( let knn_labels: Vec = knn_results.iter().map(|&(i, _)| labels[i]).collect(); ``` -### Compression and Search +### `PanCakes`: Compression and Compressive Search We also support compression of certain datasets and trees to reduce memory usage. We can then perform compressed search on the compressed dataset without having to decompress the whole dataset. 
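+
+The gist of the compression (a simplified sketch, not the exact `pancakes`
+API): each item is stored as a small delta against its cluster's center, so
+only the centers need to be kept in full. For numeric rows, the delta can be
+as simple as element-wise differences, mirroring the round-trip that the
+`Encodable` and `Decodable` traits provide:
+
+```rust
+// A minimal sketch of delta-encoding against a cluster center. The real
+// implementations serialize the deltas to bytes, but the round-trip idea
+// is the same: `decode(center, encode(center, item)) == item`.
+fn encode(center: &[f32], item: &[f32]) -> Vec<f32> {
+    center.iter().zip(item).map(|(c, x)| c - x).collect()
+}
+
+fn decode(center: &[f32], deltas: &[f32]) -> Vec<f32> {
+    center.iter().zip(deltas).map(|(c, d)| c - d).collect()
+}
+
+let center = vec![1.0_f32, 2.0, 3.0];
+let item = vec![1.0_f32, 2.5, 2.75];
+assert_eq!(decode(&center, &encode(&center, &item)), item);
+```
+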
- ```rust use abd_clam::{ - adapter::ParBallAdapter, - cakes::{cluster::ParSearchable, Algorithm, CodecData, SquishyBall}, - partition::ParPartition, - Ball, Cluster, FlatVec, Metric, MetricSpace, Permutable, + cakes::{self, ParSearchAlgorithm}, + cluster::{adapter::ParBallAdapter, ClusterIO, ParPartition}, + dataset::{AssociatesMetadataMut, DatasetIO}, + metric::Levenshtein, + msa::{Aligner, CostMatrix, Sequence}, + pancakes::{CodecData, SquishyBall}, + Ball, Cluster, Dataset, FlatVec, }; +// We need an aligner to align the sequences for compression and decompression. + +// We will be generating DNA/RNA sequence data for this example so we will use +// the default cost matrix for DNA sequences. +let cost_matrix = CostMatrix::::default(); +let aligner = Aligner::new(&cost_matrix, b'-'); + // We will generate some random string data using the `symagen` crate. let alphabet = "ACTGN".chars().collect::>(); let seed_length = 100; @@ -121,109 +126,89 @@ let clump_radius = 3_u32; let inter_clump_distance_range = (clump_radius * 5, clump_radius * 7); let len_delta = seed_length / 10; let (metadata, data) = symagen::random_edits::generate_clumped_data( - &seed_string, - penalties, - &alphabet, - num_clumps, - clump_size, - clump_radius, - inter_clump_distance_range, - len_delta, -) -.into_iter() -.unzip::<_, _, Vec<_>, Vec<_>>(); - -// The dataset will use the `levenshtein` distance function from the `distances` crate. -let distance_fn = |a: &String, b: &String| distances::strings::levenshtein::(a, b); -let metric = Metric::new(distance_fn, true); -let data = FlatVec::new(data, metric.clone()) - .unwrap() - .with_metadata(metadata.clone()) - .unwrap(); + &seed_string, + penalties, + &alphabet, + num_clumps, + clump_size, + clump_radius, + inter_clump_distance_range, + len_delta, + ) + .into_iter() + .map(|(m, d)| (m, Sequence::new(d, Some(&aligner)))) + .unzip::<_, _, Vec<_>, Vec<_>>(); + +// We create a `FlatVec` dataset from the sequence data and assign metadata. +let data = FlatVec::new(data).unwrap().with_metadata(&metadata).unwrap(); + +// The dataset will use the `levenshtein` distance metric. +let metric = Levenshtein; // We can serialize the dataset to disk without compression. let temp_dir = tempdir::TempDir::new("readme-tests").unwrap(); let flat_path = temp_dir.path().join("strings.flat_vec"); -let mut file = std::fs::File::create(&flat_path).unwrap(); -bincode::serialize_into(&mut file, &data).unwrap(); +data.write_to(&flat_path).unwrap(); // We build a tree from the dataset. -let criteria = |c: &Ball<_, _, _>| c.cardinality() > 1; +let criteria = |c: &Ball<_>| c.cardinality() > 1; let seed = Some(42); -let ball = Ball::par_new_tree(&data, &criteria, seed); +let ball = Ball::par_new_tree(&data, &metric, &criteria, seed); // We can serialize the tree to disk. let ball_path = temp_dir.path().join("strings.ball"); -let mut file = std::fs::File::create(&ball_path).unwrap(); -bincode::serialize_into(&mut file, &ball).unwrap(); +ball.write_to(&ball_path).unwrap(); // We can adapt the tree and dataset to allow for compression and compressed search. -let (squishy_ball, codec_data) = SquishyBall::par_from_ball_tree(ball, data); +let (squishy_ball, codec_data) = SquishyBall::par_from_ball_tree(ball, data, &metric); -// The metadata types still need to be adjusted manually. We are working on a solution for this. -let squishy_ball = squishy_ball.with_metadata_type::(); -let codec_data = codec_data.with_metadata(metadata).unwrap(); +// The metadata type still need to be adjusted manually. 
We are working on a solution for this. +let codec_data = codec_data.with_metadata(&metadata).unwrap(); // We can serialize the compressed dataset to disk. let codec_path = temp_dir.path().join("strings.codec_data"); -let mut file = std::fs::File::create(&codec_path).unwrap(); -bincode::serialize_into(&mut file, &codec_data).unwrap(); +codec_data.write_to(&codec_path).unwrap(); +// Note that serialization of `Sequence` types does not store the `Aligner`. // We can serialize the compressed tree to disk. let squishy_ball_path = temp_dir.path().join("strings.squishy_ball"); -let mut file = std::fs::File::create(&squishy_ball_path).unwrap(); -bincode::serialize_into(&mut file, &squishy_ball).unwrap(); +squishy_ball.write_to(&squishy_ball_path).unwrap(); -// We can perform compressed search on the compressed dataset. -let query = &seed_string; +// We can perform compressive search on the compressed dataset. +let query = &Sequence::new(seed_string, Some(&aligner)); let radius = 2; -let alg = Algorithm::RnnClustered(radius); -let results: Vec<(usize, u16)> = squishy_ball.par_search(&codec_data, query, alg); +let k = 10; + +let alg = cakes::RnnClustered(radius); +let results: Vec<(usize, u16)> = alg.par_search(&codec_data, &metric, &squishy_ball, query); assert!(!results.is_empty()); -let k = 10; -let alg = Algorithm::KnnRepeatedRnn(k, 2); -let results: Vec<(usize, u16)> = squishy_ball.par_search(&codec_data, query, alg); +let alg = cakes::KnnRepeatedRnn(k, 2); +let results: Vec<(usize, u16)> = alg.par_search(&codec_data, &metric, &squishy_ball, query); assert_eq!(results.len(), k); -let alg = Algorithm::KnnBreadthFirst(k); -let results: Vec<(usize, u16)> = squishy_ball.par_search(&codec_data, query, alg); +let alg = cakes::KnnBreadthFirst(k); +let results: Vec<(usize, u16)> = alg.par_search(&codec_data, &metric, &squishy_ball, query); assert_eq!(results.len(), k); -let alg = Algorithm::KnnDepthFirst(k); -let results: Vec<(usize, u16)> = squishy_ball.par_search(&codec_data, query, alg); +let alg = cakes::KnnDepthFirst(k); +let results: Vec<(usize, u16)> = alg.par_search(&codec_data, &metric, &squishy_ball, query); assert_eq!(results.len(), k); // The dataset can be deserialized from disk. -let mut flat_data: FlatVec = - bincode::deserialize_from(std::fs::File::open(&flat_path).unwrap()).unwrap(); -// Since functions cannot be serialized, we have to set the metric manually. -flat_data.set_metric(metric.clone()); +let flat_data: FlatVec, String> = FlatVec::read_from(&flat_path).unwrap(); // The tree can be deserialized from disk. -let ball: Ball> = - bincode::deserialize_from(std::fs::File::open(&ball_path).unwrap()).unwrap(); +let ball: Ball = Ball::read_from(&ball_path).unwrap(); // The compressed dataset can be deserialized from disk. -let mut codec_data: CodecData = - bincode::deserialize_from(std::fs::File::open(&codec_path).unwrap()).unwrap(); -// The metric has to be set manually. -codec_data.set_metric(metric.clone()); +let codec_data: CodecData, String> = CodecData::read_from(&codec_path).unwrap(); +// Since the serialization of `Sequence` types does not store the `Aligner`, we +// need to manually set it. +let codec_data = codec_data.transform_centers(|s| s.with_aligner(&aligner)); // The compressed tree can be deserialized from disk. -// You will forgive the long type signature. 
-let squishy_ball: SquishyBall< - String, - u16, - FlatVec, - CodecData, - Ball>, -> = bincode::deserialize_from( - std::fs::File::open(&squishy_ball_path) - .map_err(|e| e.to_string()) - .unwrap(), -) -.unwrap(); +let squishy_ball: SquishyBall> = SquishyBall::read_from(&squishy_ball_path).unwrap(); ``` ### Chaoda: Anomaly Detection diff --git a/crates/abd-clam/VERSION b/crates/abd-clam/VERSION index 26bea73e8..9eb2aa3f1 100644 --- a/crates/abd-clam/VERSION +++ b/crates/abd-clam/VERSION @@ -1 +1 @@ -0.31.0 +0.32.0 diff --git a/crates/abd-clam/benches/ann_benchmarks.rs b/crates/abd-clam/benches/ann_benchmarks.rs index 2fda8f8a9..fcca622c7 100644 --- a/crates/abd-clam/benches/ann_benchmarks.rs +++ b/crates/abd-clam/benches/ann_benchmarks.rs @@ -1,30 +1,25 @@ //! Benchmarks for the suite of ANN-Benchmarks datasets. use std::{ + collections::HashMap, ffi::OsStr, path::{Path, PathBuf}, }; use abd_clam::{ - adapter::{Adapter, ParBallAdapter}, - cakes::OffBall, - partition::ParPartition, - BalancedBall, Ball, Cluster, Dataset, FlatVec, Metric, Permutable, + cakes::{HintedDataset, PermutedBall}, + cluster::{adapter::ParBallAdapter, BalancedBall, ParPartition}, + dataset::AssociatesMetadataMut, + metric::{self, ParMetric}, + Ball, Cluster, Dataset, FlatVec, }; use criterion::*; +use utils::Row; mod utils; -/// Reads the training and query data of the given dataset from the directory. -pub fn read_ann_data_npy( - name: &str, - root: &Path, - metric: Metric, f32>, -) -> (FlatVec, f32, usize>, Vec>) { - AnnDataset::from_name(name).read(root, "npy", metric) -} - /// The datasets available in the `ann-benchmarks` repository. +#[allow(dead_code)] enum AnnDataset { DeepImage, FashionMnist, @@ -42,6 +37,7 @@ enum AnnDataset { impl AnnDataset { /// Returns the value of the `AnnDataset` enum from the given name. + #[allow(dead_code)] fn from_name(name: &str) -> Self { match name { "deep-image" => Self::DeepImage, @@ -78,13 +74,11 @@ impl AnnDataset { } /// Returns the path to the train and test files of the dataset. - fn paths, S: AsRef>(&self, root: P, ext: S) -> [PathBuf; 2] { + fn paths, S: AsRef>(&self, root: &P, ext: S) -> [PathBuf; 2] { let name = self.to_name(); [ - root.as_ref() - .join(format!("{}-train", name)) - .with_extension(ext.as_ref()), - root.as_ref().join(format!("{}-test", name)).with_extension(ext), + root.as_ref().join(format!("{name}-train")).with_extension(ext.as_ref()), + root.as_ref().join(format!("{name}-test")).with_extension(ext), ] } @@ -100,79 +94,181 @@ impl AnnDataset { /// /// - The training data on which to build the trees. /// - The query data to search for the nearest neighbors. 
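+    ///
+    /// For example, with `ext = "npy"`, the files are expected at
+    /// `{root}/{name}-train.npy` and `{root}/{name}-test.npy`.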
- fn read, S: AsRef>( - &self, - root: P, - ext: S, - metric: Metric, f32>, - ) -> (FlatVec, f32, usize>, Vec>) { + fn read, S: AsRef>(&self, root: &P, ext: S) -> (FlatVec, usize>, Vec>) { let [train, test] = self.paths(root, ext); - println!("Reading train data from {:?}, {:?}", train, test); - let test = FlatVec::read_npy(test, metric).unwrap(); - let (metric, test, _, _, _, _) = test.deconstruct(); - let train = FlatVec::read_npy(train, metric).unwrap().with_name(self.to_name()); + println!("Reading train data from {train:?}, {test:?}"); + let test = FlatVec::read_npy(&test) + .unwrap() + .transform_items(Row::from) + .items() + .to_vec(); + println!("Finished reading test data with {} items.", test.len()); + let train = FlatVec::read_npy(&train) + .unwrap() + .with_name(self.to_name()) + .transform_items(Row::from); + println!("Finished reading train data with {} items.", train.cardinality()); (train, test) } } +fn run_search, f32>>( + c: &mut Criterion, + data: FlatVec, usize>, + metric: &M, + queries: &[Row], + radii_fractions: &[f32], + ks: &[usize], + seed: Option, + multiplier_error: Option<(usize, f32)>, +) { + let data = if let Some((multiplier, error)) = multiplier_error { + println!("Augmenting data to {multiplier}x with an error rate of {error:.2}."); + + let mut rows = data.items().iter().map(|r| Row::::to_vec(r)).collect::>(); + let new_rows = symagen::augmentation::augment_data(&rows, multiplier, error); + rows.extend(new_rows); + let rows = rows.into_iter().map(Row::from).collect(); + + let name = format!("augmented-{}-{multiplier}", data.name()); + FlatVec::new(rows) + .unwrap_or_else(|e| unreachable!("{e}")) + .with_name(&name) + } else { + data + }; + + println!("Creating ball ..."); + let criteria = |c: &Ball<_>| c.cardinality() > 1; + let ball = Ball::par_new_tree(&data, metric, &criteria, seed); + let radii = radii_fractions.iter().map(|&r| r * ball.radius()).collect::>(); + + println!("Creating balanced ball ..."); + let criteria = |c: &BalancedBall<_>| c.cardinality() > 1; + let balanced_ball = BalancedBall::par_new_tree(&data, metric, &criteria, seed).into_ball(); + + println!("Adding hints to data ..."); + // let (_, max_radius) = abd_clam::utils::arg_max(&radii).unwrap(); + // let (_, max_k) = abd_clam::utils::arg_max(ks).unwrap(); + let data = data + .transform_metadata(|&i| (i, HashMap::new())) + .with_hints_from_tree(&ball, metric) + .with_hints_from_tree(&balanced_ball, metric); + // .with_hints_from(metric, &balanced_ball, max_radius, max_k); + + println!("Creating permuted ball ..."); + let (perm_ball, perm_data) = PermutedBall::par_from_ball_tree(ball.clone(), data.clone(), metric); + println!("Creating permuted balanced ball ..."); + let (perm_balanced_ball, perm_balanced_data) = + PermutedBall::par_from_ball_tree(balanced_ball.clone(), data.clone(), metric); + + utils::compare_permuted( + c, + metric, + (&ball, &data), + (&balanced_ball, &perm_balanced_data), + (&perm_ball, &perm_data), + (&perm_balanced_ball, &perm_balanced_data), + None, + None, + queries, + &radii, + ks, + true, + ); +} + fn ann_benchmarks(c: &mut Criterion) { let root_str = std::env::var("ANN_DATA_ROOT").unwrap(); + println!("ANN data root: {root_str}"); let ann_data_root = std::path::Path::new(&root_str).canonicalize().unwrap(); - println!("ANN data root: {:?}", ann_data_root); - - let euclidean = |x: &Vec<_>, y: &Vec<_>| distances::vectors::euclidean(x, y); - let cosine = |x: &Vec<_>, y: &Vec<_>| distances::vectors::cosine(x, y); - let data_names: Vec<(&str, &str, Metric, f32>)> 
= vec![ - ("fashion-mnist", "euclidean", Metric::new(euclidean, false)), - ("glove-25", "cosine", Metric::new(cosine, false)), - ("sift", "euclidean", Metric::new(euclidean, false)), - ]; - - let data_pairs = data_names.into_iter().map(|(data_name, metric_name, metric)| { - ( - data_name, - metric_name, - read_ann_data_npy(data_name, &ann_data_root, metric), - ) - }); + println!("ANN data root: {ann_data_root:?}"); let seed = Some(42); - let radii = vec![]; - let ks = vec![10, 100]; + let radii_fractions = vec![0.001, 0.01, 0.1]; + let ks = vec![1, 10, 100]; let num_queries = 100; - for (data_name, metric_name, (data, queries)) in data_pairs { - let queries = &queries[0..num_queries]; - - let criteria = |c: &Ball<_, _, _>| c.cardinality() > 1; - let ball = Ball::par_new_tree(&data, &criteria, seed); - let (off_ball, perm_data) = OffBall::par_from_ball_tree(ball.clone(), data.clone()); - - let criteria = |c: &BalancedBall<_, _, _>| c.cardinality() > 1; - let balanced_ball = BalancedBall::par_new_tree(&data, &criteria, seed); - let (balanced_off_ball, balanced_perm_data) = { - let balanced_off_ball = OffBall::adapt_tree(balanced_ball.clone(), None); - let mut balanced_perm_data = data.clone(); - let permutation = balanced_off_ball.source().indices().collect::>(); - balanced_perm_data.permute(&permutation); - (balanced_off_ball, balanced_perm_data) - }; - - utils::compare_permuted( - c, - data_name, - metric_name, - (&ball, &data), - (&off_ball, &perm_data), - None, - (&balanced_ball, &data), - (&balanced_off_ball, &balanced_perm_data), - None, - &queries, - &radii, - &ks, - true, - ); - } + + let (f_mnist, queries) = AnnDataset::FashionMnist.read(&ann_data_root, "npy"); + let queries = &queries[..num_queries]; + run_search( + c, + f_mnist.clone(), + &metric::Euclidean, + queries, + &radii_fractions, + &ks, + seed, + None, + ); + run_search( + c, + f_mnist.clone(), + &metric::Euclidean, + queries, + &radii_fractions, + &ks, + seed, + Some((2, 1.0)), + ); + run_search( + c, + f_mnist.clone(), + &metric::Euclidean, + queries, + &radii_fractions, + &ks, + seed, + Some((4, 1.0)), + ); + run_search( + c, + f_mnist.clone(), + &metric::Euclidean, + queries, + &radii_fractions, + &ks, + seed, + Some((8, 1.0)), + ); + run_search( + c, + f_mnist.clone(), + &metric::Euclidean, + queries, + &radii_fractions, + &ks, + seed, + Some((16, 1.0)), + ); + run_search( + c, + f_mnist.clone(), + &metric::Euclidean, + queries, + &radii_fractions, + &ks, + seed, + Some((32, 1.0)), + ); + run_search( + c, + f_mnist, + &metric::Euclidean, + queries, + &radii_fractions, + &ks, + seed, + Some((64, 1.0)), + ); + + let (sift, queries) = AnnDataset::Sift.read(&ann_data_root, "npy"); + let queries = &queries[..num_queries]; + run_search(c, sift, &metric::Euclidean, queries, &radii_fractions, &ks, seed, None); + + let (glove_25, queries) = AnnDataset::Glove25.read(&ann_data_root, "npy"); + let queries = &queries[..num_queries]; + run_search(c, glove_25, &metric::Cosine, queries, &radii_fractions, &ks, seed, None); } criterion_group!(benches, ann_benchmarks); diff --git a/crates/abd-clam/benches/genomic_search.rs b/crates/abd-clam/benches/genomic_search.rs index 1afb63a56..f61bf8100 100644 --- a/crates/abd-clam/benches/genomic_search.rs +++ b/crates/abd-clam/benches/genomic_search.rs @@ -1,26 +1,25 @@ //! Benchmark for genomic search. 
-mod utils; +use std::collections::HashMap; use abd_clam::{ - adapter::{Adapter, ParAdapter, ParBallAdapter}, - cakes::{OffBall, SquishyBall}, - partition::ParPartition, - BalancedBall, Ball, Cluster, FlatVec, Metric, Permutable, + cakes::{HintedDataset, PermutedBall}, + cluster::{adapter::ParBallAdapter, BalancedBall, ParPartition}, + dataset::{AssociatesMetadata, AssociatesMetadataMut}, + metric::Levenshtein, + msa::{Aligner, CostMatrix, Sequence}, + pancakes::SquishyBall, + Ball, Cluster, Dataset, FlatVec, }; use criterion::*; use rand::prelude::*; -const METRICS: &[(&str, fn(&String, &String) -> u64)] = &[ - ("levenshtein", |x: &String, y: &String| { - distances::strings::levenshtein(x, y) - }), - // ("needleman-wunsch", |x: &String, y: &String| { - // distances::strings::nw_distance(x, y) - // }), -]; +mod utils; fn genomic_search(c: &mut Criterion) { + let matrix = CostMatrix::::default(); + let aligner = Aligner::new(&matrix, b'-'); + let seed_length = 250; let alphabet = "ACTGN".chars().collect::>(); let seed_string = symagen::random_edits::generate_random_string(seed_length, &alphabet); @@ -30,7 +29,7 @@ fn genomic_search(c: &mut Criterion) { let clump_radius = 10_u16; let inter_clump_distance_range = (50_u16, 80_u16); let len_delta = 10; - let (_, genomes) = symagen::random_edits::generate_clumped_data( + let (metadata, genomes) = symagen::random_edits::generate_clumped_data( &seed_string, penalties, &alphabet, @@ -41,6 +40,7 @@ fn genomic_search(c: &mut Criterion) { len_delta, ) .into_iter() + .map(|(m, seq)| (m, Sequence::new(seq, Some(&aligner)))) .unzip::<_, _, Vec<_>, Vec<_>>(); let seed = 42; @@ -57,43 +57,53 @@ fn genomic_search(c: &mut Criterion) { }; let seed = Some(seed); - let radii = vec![]; - let ks = vec![1, 10, 20]; - for &(metric_name, distance_fn) in METRICS { - let metric = Metric::new(distance_fn, true); - let data = FlatVec::new(genomes.clone(), metric).unwrap(); + let radii = vec![1_u32, 5, 10]; + let ks = vec![1, 10, 100]; + + let data = FlatVec::new(genomes) + .unwrap() + .with_metadata(&metadata) + .unwrap() + .with_name("genomic-search"); + let metric = Levenshtein; + + let criteria = |c: &Ball<_>| c.cardinality() > 1; + let ball = Ball::par_new_tree(&data, &metric, &criteria, seed); + + let criteria = |c: &BalancedBall<_>| c.cardinality() > 1; + let balanced_ball = BalancedBall::par_new_tree(&data, &metric, &criteria, seed).into_ball(); + + let (_, max_radius) = abd_clam::utils::arg_max(&radii).unwrap(); + let (_, max_k) = abd_clam::utils::arg_max(&ks).unwrap(); + let data = data + .transform_metadata(|s| (s.clone(), HashMap::new())) + .with_hints_from_tree(&ball, &metric) + .with_hints_from(&metric, &balanced_ball, max_radius, max_k); - let criteria = |c: &Ball<_, _, _>| c.cardinality() > 1; - let ball = Ball::par_new_tree(&data, &criteria, seed); - let (off_ball, perm_data) = OffBall::par_from_ball_tree(ball.clone(), data.clone()); - let (squishy_ball, dec_data) = SquishyBall::par_from_ball_tree(ball.clone(), data.clone()); + let (perm_ball, perm_data) = PermutedBall::par_from_ball_tree(ball.clone(), data.clone(), &metric); + let (squishy_ball, dec_data) = SquishyBall::par_from_ball_tree(ball.clone(), data.clone(), &metric); + let dec_data = dec_data.with_metadata(perm_data.metadata()).unwrap(); - let criteria = |c: &BalancedBall<_, _, _>| c.cardinality() > 1; - let balanced_ball = BalancedBall::par_new_tree(&data, &criteria, seed); - let (balanced_off_ball, balanced_perm_data) = { - let balanced_off_ball = 
OffBall::par_adapt_tree(balanced_ball.clone(), None); - let mut balanced_perm_data = data.clone(); - let permutation = balanced_off_ball.source().indices().collect::>(); - balanced_perm_data.permute(&permutation); - (balanced_off_ball, balanced_perm_data) - }; + let (perm_balanced_ball, perm_balanced_data) = + PermutedBall::par_from_ball_tree(balanced_ball.clone(), data.clone(), &metric); + let (squishy_balanced_ball, dec_balanced_data) = + SquishyBall::par_from_ball_tree(balanced_ball.clone(), data.clone(), &metric); + let dec_balanced_data = dec_balanced_data.with_metadata(perm_balanced_data.metadata()).unwrap(); - utils::compare_permuted( - c, - "genomic-search", - metric_name, - (&ball, &data), - (&off_ball, &perm_data), - Some((&squishy_ball, &dec_data)), - (&balanced_ball, &data), - (&balanced_off_ball, &balanced_perm_data), - None, - &queries, - &radii, - &ks, - true, - ); - } + utils::compare_permuted( + c, + &metric, + (&ball, &data), + (&balanced_ball, &perm_balanced_data), + (&perm_ball, &perm_data), + (&perm_balanced_ball, &perm_balanced_data), + Some((&squishy_ball, &dec_data)), + Some((&squishy_balanced_ball, &dec_balanced_data)), + &queries, + &radii, + &ks, + true, + ); } criterion_group!(benches, genomic_search); diff --git a/crates/abd-clam/benches/utils/compare_permuted.rs b/crates/abd-clam/benches/utils/compare_permuted.rs index 1aa9ef938..8af78788d 100644 --- a/crates/abd-clam/benches/utils/compare_permuted.rs +++ b/crates/abd-clam/benches/utils/compare_permuted.rs @@ -1,197 +1,271 @@ //! Benchmarking utilities for comparing the performance of different algorithms //! on permuted datasets. +#![allow(unused_imports, unused_variables)] + +use std::collections::HashMap; + use abd_clam::{ cakes::{ - cluster::{ParSearchable, Searchable}, - Algorithm, CodecData, Decodable, Encodable, OffBall, ParCompressible, SquishyBall, + KnnBreadthFirst, KnnDepthFirst, KnnHinted, KnnLinear, KnnRepeatedRnn, ParSearchAlgorithm, ParSearchable, + PermutedBall, RnnClustered, RnnLinear, SearchAlgorithm, Searchable, }, - BalancedBall, Ball, + dataset::ParDataset, + metric::ParMetric, + pancakes::{CodecData, Decodable, Encodable, ParCompressible, SquishyBall}, + Ball, Cluster, Dataset, FlatVec, Metric, }; use criterion::*; use distances::Number; +use measurement::WallTime; -/// Compare the performance of different algorithms on permuted datasets. +/// Compare the performance of different algorithms on datasets using different +/// cluster types. /// /// # Parameters /// /// - `c`: The criterion context. -/// - `metric_name`: The name of the metric used to measure distances. -/// - `data`: The original dataset. -/// - `root`: The root of tree on the original dataset. -/// - `perm_data`: The permuted dataset. -/// - `perm_root`: The root of the tree on the permuted dataset. +/// - `metric`: The metric used to measure distances. +/// - `ball_data`: The original dataset and its ball. +/// - `balanced_ball_data`: The original dataset and its balanced ball. +/// - `perm_ball_data`: The permuted dataset and its permuted ball. +/// - `perm_balanced_ball_data`: The permuted dataset and its permuted balanced +/// ball. +/// - `dec_ball_data`: The permuted dataset and its squishy ball, if any. +/// - `dec_balanced_ball_data`: The permuted dataset and its balanced squishy +/// ball, if any. /// - `queries`: The queries to search for. -/// - `radii`: The radii to use for RNN algorithms. -/// - `ks`: The values of `k` to use for kNN algorithms. +/// - `radii`: The radii to use for the RNN algorithm. 
+/// - `ks`: The numbers of neighbors to search for. +/// - `par_only`: Whether to only benchmark the parallel algorithms. /// /// # Type Parameters /// /// - `I`: The type of the items in the dataset. /// - `U`: The type of the scalars used to measure distances. -/// - `Co`: The type of the original dataset. -pub fn compare_permuted( +/// - `Co`: The type of the compressible dataset. +/// - `M`: The type of the metric used to measure distances. +#[allow(clippy::too_many_arguments, clippy::type_complexity)] +pub fn compare_permuted( c: &mut Criterion, - data_name: &str, - metric_name: &str, - ball_data: (&Ball, &Co), - off_ball_data: (&OffBall>, &Co), - dec_ball_data: Option<( - &SquishyBall, Ball>, - &CodecData, - )>, - bal_ball_data: (&BalancedBall, &Co), - bal_off_ball_data: (&OffBall>, &Co), - bal_dec_ball_data: Option<( - &SquishyBall, BalancedBall>, - &CodecData, - )>, + metric: &M, + ball_data: (&Ball, &FlatVec)>), + balanced_ball_data: (&Ball, &FlatVec)>), + perm_ball_data: (&PermutedBall>, &FlatVec)>), + perm_balanced_ball_data: (&PermutedBall>, &FlatVec)>), + dec_ball_data: Option<(&SquishyBall>, &CodecData)>)>, + dec_balanced_ball_data: Option<(&SquishyBall>, &CodecData)>)>, queries: &[I], - radii: &[U], + radii: &[T], ks: &[usize], par_only: bool, ) where I: Encodable + Decodable + Send + Sync, - U: Number, - Co: ParCompressible, + T: Number + 'static, + M: ParMetric, + Me: Send + Sync, { - let algs = vec![ - Algorithm::KnnRepeatedRnn(ks[0], U::ONE.double()), - Algorithm::KnnBreadthFirst(ks[0]), - Algorithm::KnnDepthFirst(ks[0]), - ]; - - let mut group = c.benchmark_group(format!("{}-{}-RnnClustered", data_name, metric_name)); - group - .sample_size(10) - .sampling_mode(SamplingMode::Flat) - .throughput(Throughput::Elements(queries.len().as_u64())); - - let (ball, data) = ball_data; - let (off_ball, perm_data) = off_ball_data; + let mut algs: Vec<( + Box, M, FlatVec)>>>, + Box>, M, FlatVec)>>>, + Option>, M, CodecData)>>>>, + )> = Vec::new(); - let (bal_ball, bal_data) = bal_ball_data; - let (bal_off_ball, bal_perm_data) = bal_off_ball_data; - - for &radius in radii { - let alg = Algorithm::RnnClustered(radius); - - if !par_only { - group.bench_with_input(BenchmarkId::new("Ball", radius), &radius, |b, _| { - b.iter_with_large_drop(|| ball.batch_search(data, queries, alg)); - }); - group.bench_with_input(BenchmarkId::new("BalancedBall", radius), &radius, |b, _| { - b.iter_with_large_drop(|| bal_ball.batch_search(bal_data, queries, alg)); - }); - - group.bench_with_input(BenchmarkId::new("OffBall", radius), &radius, |b, _| { - b.iter_with_large_drop(|| off_ball.batch_search(perm_data, queries, alg)); - }); - group.bench_with_input(BenchmarkId::new("BalancedOffBall", radius), &radius, |b, _| { - b.iter_with_large_drop(|| bal_off_ball.batch_search(bal_perm_data, queries, alg)); - }); - - if let Some((dec_root, dec_data)) = dec_ball_data { - group.bench_with_input(BenchmarkId::new("SquishyBall", radius), &radius, |b, _| { - b.iter_with_large_drop(|| dec_root.batch_search(dec_data, queries, alg)); - }); - } - if let Some((dec_root, dec_data)) = bal_dec_ball_data { - group.bench_with_input(BenchmarkId::new("BalancedSquishyBall", radius), &radius, |b, _| { - b.iter_with_large_drop(|| dec_root.batch_search(dec_data, queries, alg)); - }); - } + for (i, &radius) in radii.iter().enumerate() { + if i == 0 { + algs.push(( + Box::new(RnnLinear(radius)), + Box::new(RnnLinear(radius)), + Some(Box::new(RnnLinear(radius))), + )); } - - group.bench_with_input(BenchmarkId::new("ParBall", radius), 
&radius, |b, _| { - b.iter_with_large_drop(|| ball.par_batch_search(data, queries, alg)); - }); - group.bench_with_input(BenchmarkId::new("ParBalancedBall", radius), &radius, |b, _| { - b.iter_with_large_drop(|| bal_ball.par_batch_search(bal_data, queries, alg)); - }); - - group.bench_with_input(BenchmarkId::new("ParOffBall", radius), &radius, |b, _| { - b.iter_with_large_drop(|| off_ball.par_batch_search(perm_data, queries, alg)); - }); - group.bench_with_input(BenchmarkId::new("ParBalancedOffBall", radius), &radius, |b, _| { - b.iter_with_large_drop(|| bal_off_ball.par_batch_search(bal_perm_data, queries, alg)); - }); - - if let Some((dec_root, dec_data)) = dec_ball_data { - group.bench_with_input(BenchmarkId::new("ParSquishyBall", radius), &radius, |b, _| { - b.iter_with_large_drop(|| dec_root.par_batch_search(dec_data, queries, alg)); - }); - } - if let Some((dec_root, dec_data)) = bal_dec_ball_data { - group.bench_with_input(BenchmarkId::new("ParBalancedSquishyBall", radius), &radius, |b, _| { - b.iter_with_large_drop(|| dec_root.par_batch_search(dec_data, queries, alg)); - }); + algs.push(( + Box::new(RnnClustered(radius)), + Box::new(RnnClustered(radius)), + Some(Box::new(RnnClustered(radius))), + )); + } + for (i, &k) in ks.iter().enumerate() { + if i == 0 { + algs.push(( + Box::new(KnnLinear(k)), + Box::new(KnnLinear(k)), + Some(Box::new(KnnLinear(k))), + )); } + algs.push(( + Box::new(KnnRepeatedRnn(k, T::ONE.double())), + Box::new(KnnRepeatedRnn(k, T::ONE.double())), + Some(Box::new(KnnRepeatedRnn(k, T::ONE.double()))), + )); + algs.push(( + Box::new(KnnBreadthFirst(k)), + Box::new(KnnBreadthFirst(k)), + Some(Box::new(KnnBreadthFirst(k))), + )); + algs.push(( + Box::new(KnnDepthFirst(k)), + Box::new(KnnDepthFirst(k)), + Some(Box::new(KnnDepthFirst(k))), + )); + algs.push((Box::new(KnnHinted(k)), Box::new(KnnHinted(k)), None)); } - group.finish(); - for alg in &algs { - let mut group = c.benchmark_group(format!("{}-{}-{}", data_name, metric_name, alg.variant_name())); + for (alg_1, alg_2, alg_3) in &algs { + let mut group = c.benchmark_group(format!("{}-{}", ball_data.1.name(), alg_1.name())); group .sample_size(10) .sampling_mode(SamplingMode::Flat) .throughput(Throughput::Elements(queries.len().as_u64())); - for &k in ks { - let alg = alg.with_params(U::ZERO, k); - - if !par_only { - group.bench_with_input(BenchmarkId::new("Ball", k), &k, |b, _| { - b.iter_with_large_drop(|| ball.batch_search(data, queries, alg)); - }); - group.bench_with_input(BenchmarkId::new("BalancedBall", k), &k, |b, _| { - b.iter_with_large_drop(|| bal_ball.batch_search(bal_data, queries, alg)); - }); - - group.bench_with_input(BenchmarkId::new("OffBall", k), &k, |b, _| { - b.iter_with_large_drop(|| off_ball.batch_search(perm_data, queries, alg)); - }); - group.bench_with_input(BenchmarkId::new("BalancedOffBall", k), &k, |b, _| { - b.iter_with_large_drop(|| bal_off_ball.batch_search(bal_perm_data, queries, alg)); - }); - - if let Some((dec_root, dec_data)) = dec_ball_data { - group.bench_with_input(BenchmarkId::new("SquishyBall", k), &k, |b, _| { - b.iter_with_large_drop(|| dec_root.batch_search(dec_data, queries, alg)); - }); - } - if let Some((dec_root, dec_data)) = bal_dec_ball_data { - group.bench_with_input(BenchmarkId::new("BalancedSquishyBall", k), &k, |b, _| { - b.iter_with_large_drop(|| dec_root.batch_search(dec_data, queries, alg)); - }); - } - } - - group.bench_with_input(BenchmarkId::new("ParBall", k), &k, |b, _| { - b.iter_with_large_drop(|| ball.par_batch_search(data, queries, alg)); - }); - 
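Each entry in the new `algs` vector carries the same logical algorithm boxed three times because a boxed search-algorithm object is tied to one concrete cluster-and-dataset pairing: the `Ball`, `PermutedBall`, and `SquishyBall` trees each need their own trait-object type. A self-contained illustration of that dispatch pattern, with stand-in traits and types rather than the abd-clam API:

// Stand-in types: not the abd-clam API, just the dispatch pattern.
trait Search<Tree> {
    fn run(&self, tree: &Tree) -> usize;
}

struct Knn(usize);
struct BallTree;
struct PermutedTree;

impl Search<BallTree> for Knn {
    fn run(&self, _: &BallTree) -> usize { self.0 }
}
impl Search<PermutedTree> for Knn {
    fn run(&self, _: &PermutedTree) -> usize { self.0 }
}

fn main() {
    // One box per tree type, even though both wrap the same `Knn(10)` logic.
    let algs: Vec<(Box<dyn Search<BallTree>>, Box<dyn Search<PermutedTree>>)> =
        vec![(Box::new(Knn(10)), Box::new(Knn(10)))];
    for (a, b) in &algs {
        assert_eq!(a.run(&BallTree), b.run(&PermutedTree));
    }
}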
group.bench_with_input(BenchmarkId::new("ParBalancedBall", k), &k, |b, _| { - b.iter_with_large_drop(|| bal_ball.par_batch_search(bal_data, queries, alg)); - }); - - group.bench_with_input(BenchmarkId::new("ParOffBall", k), &k, |b, _| { - b.iter_with_large_drop(|| off_ball.par_batch_search(perm_data, queries, alg)); - }); - group.bench_with_input(BenchmarkId::new("ParBalancedOffBall", k), &k, |b, _| { - b.iter_with_large_drop(|| bal_off_ball.par_batch_search(bal_perm_data, queries, alg)); - }); - - if let Some((dec_root, dec_data)) = dec_ball_data { - group.bench_with_input(BenchmarkId::new("ParSquishyBall", k), &k, |b, _| { - b.iter_with_large_drop(|| dec_root.par_batch_search(dec_data, queries, alg)); - }); - } - if let Some((dec_root, dec_data)) = bal_dec_ball_data { - group.bench_with_input(BenchmarkId::new("ParBalancedSquishyBall", k), &k, |b, _| { - b.iter_with_large_drop(|| dec_root.par_batch_search(dec_data, queries, alg)); - }); - } + if !par_only { + bench_cakes( + &mut group, + alg_1, + alg_2, + alg_3.as_ref(), + metric, + queries, + ball_data, + balanced_ball_data, + perm_ball_data, + perm_balanced_ball_data, + dec_ball_data, + dec_balanced_ball_data, + ); } + + par_bench_cakes( + &mut group, + alg_1, + alg_2, + alg_3.as_ref(), + metric, + queries, + ball_data, + balanced_ball_data, + perm_ball_data, + perm_balanced_ball_data, + dec_ball_data, + dec_balanced_ball_data, + ); group.finish(); } } + +fn bench_cakes( + group: &mut BenchmarkGroup, + alg_1: &A1, + alg_2: &A2, + alg_3: Option<&A3>, + metric: &M, + queries: &[I], + (ball, data): (&Ball, &FlatVec)>), + (balanced_ball, balanced_data): (&Ball, &FlatVec)>), + (perm_ball, perm_data): (&PermutedBall>, &FlatVec)>), + (perm_balanced_ball, perm_balanced_data): (&PermutedBall>, &FlatVec)>), + dec_ball_data: Option<(&SquishyBall>, &CodecData)>)>, + dec_balanced_ball_data: Option<(&SquishyBall>, &CodecData)>)>, +) where + I: Encodable + Decodable, + T: Number, + M: Metric, + A1: SearchAlgorithm, M, FlatVec)>>, + A2: SearchAlgorithm>, M, FlatVec)>>, + A3: SearchAlgorithm>, M, CodecData)>>, +{ + let parameter = if let Some(k) = alg_1.k() { + k + } else if let Some(radius) = alg_1.radius() { + (ball.radius() / radius).as_usize() + } else { + 0 + }; + + group.bench_with_input(BenchmarkId::new("Ball", parameter), &0, |b, _| { + b.iter_with_large_drop(|| alg_1.batch_search(data, metric, ball, queries)); + }); + + group.bench_with_input(BenchmarkId::new("BalancedBall", parameter), &0, |b, _| { + b.iter_with_large_drop(|| alg_1.batch_search(balanced_data, metric, balanced_ball, queries)); + }); + + group.bench_with_input(BenchmarkId::new("PermBall", parameter), &0, |b, _| { + b.iter_with_large_drop(|| alg_2.batch_search(perm_data, metric, perm_ball, queries)); + }); + + group.bench_with_input(BenchmarkId::new("PermBalancedBall", parameter), &0, |b, _| { + b.iter_with_large_drop(|| alg_2.batch_search(perm_balanced_data, metric, perm_balanced_ball, queries)); + }); + + if let Some((dec_root, dec_data)) = dec_ball_data { + group.bench_with_input(BenchmarkId::new("SquishyBall", parameter), &0, |b, _| { + b.iter_with_large_drop(|| alg_3.map(|alg_3| alg_3.batch_search(dec_data, metric, dec_root, queries))); + }); + } + + if let Some((dec_root, dec_data)) = dec_balanced_ball_data { + group.bench_with_input(BenchmarkId::new("BalancedSquishyBall", parameter), &0, |b, _| { + b.iter_with_large_drop(|| alg_3.map(|alg_3| alg_3.batch_search(dec_data, metric, dec_root, queries))); + }); + } +} + +fn par_bench_cakes( + group: &mut BenchmarkGroup, + alg_1: 
&A1, + alg_2: &A2, + alg_3: Option<&A3>, + metric: &M, + queries: &[I], + (ball, data): (&Ball, &FlatVec)>), + (balanced_ball, balanced_data): (&Ball, &FlatVec)>), + (perm_ball, perm_data): (&PermutedBall>, &FlatVec)>), + (perm_balanced_ball, perm_balanced_data): (&PermutedBall>, &FlatVec)>), + dec_ball_data: Option<(&SquishyBall>, &CodecData)>)>, + dec_balanced_ball_data: Option<(&SquishyBall>, &CodecData)>)>, +) where + I: Encodable + Decodable + Send + Sync, + T: Number, + M: ParMetric, + A1: ParSearchAlgorithm, M, FlatVec)>>, + A2: ParSearchAlgorithm>, M, FlatVec)>>, + A3: ParSearchAlgorithm>, M, CodecData)>>, + Me: Send + Sync, +{ + let parameter = if let Some(k) = alg_1.k() { + k + } else if let Some(radius) = alg_1.radius() { + (ball.radius() / radius).as_usize() + } else { + 0 + }; + + group.bench_with_input(BenchmarkId::new("ParBall", parameter), &0, |b, _| { + b.iter_with_large_drop(|| alg_1.par_batch_search(data, metric, ball, queries)); + }); + + group.bench_with_input(BenchmarkId::new("ParBalancedBall", parameter), &0, |b, _| { + b.iter_with_large_drop(|| alg_1.par_batch_search(balanced_data, metric, balanced_ball, queries)); + }); + + group.bench_with_input(BenchmarkId::new("ParPermBall", parameter), &0, |b, _| { + b.iter_with_large_drop(|| alg_2.par_batch_search(perm_data, metric, perm_ball, queries)); + }); + + group.bench_with_input(BenchmarkId::new("ParPermBalancedBall", parameter), &0, |b, _| { + b.iter_with_large_drop(|| alg_2.par_batch_search(perm_balanced_data, metric, perm_balanced_ball, queries)); + }); + + if let Some((dec_root, dec_data)) = dec_ball_data { + group.bench_with_input(BenchmarkId::new("ParSquishyBall", parameter), &0, |b, _| { + b.iter_with_large_drop(|| alg_3.map(|alg_3| alg_3.par_batch_search(dec_data, metric, dec_root, queries))); + }); + } + + if let Some((dec_root, dec_data)) = dec_balanced_ball_data { + group.bench_with_input(BenchmarkId::new("ParBalancedSquishyBall", parameter), &0, |b, _| { + b.iter_with_large_drop(|| alg_3.map(|alg_3| alg_3.par_batch_search(dec_data, metric, dec_root, queries))); + }); + } +} diff --git a/crates/abd-clam/benches/utils/mod.rs b/crates/abd-clam/benches/utils/mod.rs index 385ddd285..a6d8e2657 100644 --- a/crates/abd-clam/benches/utils/mod.rs +++ b/crates/abd-clam/benches/utils/mod.rs @@ -2,4 +2,65 @@ mod compare_permuted; +use abd_clam::pancakes::{Decodable, Encodable}; pub use compare_permuted::compare_permuted; +use distances::number::Float; + +/// A row of a tabular dataset. +#[derive(Clone)] +pub struct Row(Vec); + +impl Row { + /// Converts a `Row` to a vector. 
+ #[allow(dead_code)] + pub fn to_vec(v: &Self) -> Vec { + v.0.clone() + } +} + +impl From> for Row { + fn from(v: Vec) -> Self { + Self(v) + } +} + +impl FromIterator for Row { + fn from_iter>(iter: T) -> Self { + Self(iter.into_iter().collect()) + } +} + +impl AsRef<[F]> for Row { + fn as_ref(&self) -> &[F] { + &self.0 + } +} + +impl Encodable for Row { + fn as_bytes(&self) -> Box<[u8]> { + self.0 + .iter() + .flat_map(|v| v.to_le_bytes()) + .collect::>() + .into_boxed_slice() + } + + fn encode(&self, reference: &Self) -> Box<[u8]> { + let diffs = reference.0.iter().zip(self.0.iter()).map(|(&a, &b)| a - b).collect(); + Self::as_bytes(&diffs) + } +} + +impl Decodable for Row { + fn from_bytes(bytes: &[u8]) -> Self { + bytes + .chunks_exact(std::mem::size_of::()) + .map(F::from_le_bytes) + .collect() + } + + fn decode(reference: &Self, bytes: &[u8]) -> Self { + let diffs = Self::from_bytes(bytes); + reference.0.iter().zip(diffs.0.iter()).map(|(&a, &b)| a - b).collect() + } +} diff --git a/crates/abd-clam/benches/vector_search.rs b/crates/abd-clam/benches/vector_search.rs index 559f4a06d..2fe637415 100644 --- a/crates/abd-clam/benches/vector_search.rs +++ b/crates/abd-clam/benches/vector_search.rs @@ -1,22 +1,97 @@ //! Benchmark for vector search. -mod utils; +use std::collections::HashMap; use abd_clam::{ - adapter::{Adapter, ParAdapter, ParBallAdapter}, - cakes::OffBall, - partition::ParPartition, - BalancedBall, Ball, Cluster, FlatVec, Metric, Permutable, + cakes::{HintedDataset, PermutedBall}, + cluster::{adapter::ParBallAdapter, BalancedBall, ParPartition}, + dataset::AssociatesMetadataMut, + metric::{Euclidean, ParMetric}, + Ball, Cluster, Dataset, FlatVec, Metric, }; use criterion::*; use rand::prelude::*; +use utils::Row; + +mod utils; + +/// The Euclidean metric using SIMD instructions. 
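The `Row` codec defined above in `benches/utils/mod.rs` stores a row as its element-wise difference from a reference row: `encode` emits `reference - self`, and `decode` recovers the original as `reference - diff`. A standalone check of that round trip, using plain vectors in place of the bench types:

fn main() {
    let reference = vec![1.0_f32, 2.0, 3.0];
    let row = vec![0.5_f32, 2.5, 3.0];
    // encode: diff = reference - row, element-wise
    let diff: Vec<f32> = reference.iter().zip(&row).map(|(&a, &b)| a - b).collect();
    // decode: row = reference - diff, element-wise
    let decoded: Vec<f32> = reference.iter().zip(&diff).map(|(&a, &b)| a - b).collect();
    assert_eq!(decoded, row);
}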
+pub struct EuclideanSimd; + +impl> Metric for EuclideanSimd { + fn distance(&self, a: &I, b: &I) -> f32 { + distances::simd::euclidean_f32(a.as_ref(), b.as_ref()) + } + + fn name(&self) -> &str { + "euclidean-simd" + } + + fn has_identity(&self) -> bool { + true + } -const METRICS: &[(&str, fn(&Vec, &Vec) -> f32)] = &[ - ("euclidean", |x: &Vec<_>, y: &Vec<_>| { - distances::vectors::euclidean(x, y) - }), - ("cosine", |x: &Vec<_>, y: &Vec<_>| distances::vectors::cosine(x, y)), -]; + fn has_non_negativity(&self) -> bool { + true + } + + fn has_symmetry(&self) -> bool { + true + } + + fn obeys_triangle_inequality(&self) -> bool { + true + } + + fn is_expensive(&self) -> bool { + false + } +} + +impl + Send + Sync> ParMetric for EuclideanSimd {} + +fn run_search, f32>>( + c: &mut Criterion, + data: FlatVec, usize>, + metric: &M, + queries: &[Row], + radii: &[f32], + ks: &[usize], + seed: Option, +) { + let criteria = |c: &Ball<_>| c.cardinality() > 1; + let ball = Ball::par_new_tree(&data, metric, &criteria, seed); + let radii = radii.iter().map(|&r| r * ball.radius()).collect::>(); + + let criteria = |c: &BalancedBall<_>| c.cardinality() > 1; + let balanced_ball = BalancedBall::par_new_tree(&data, metric, &criteria, seed).into_ball(); + + let (_, max_radius) = abd_clam::utils::arg_max(&radii).unwrap(); + let (_, max_k) = abd_clam::utils::arg_max(ks).unwrap(); + let data = data + .transform_metadata(|&i| (i, HashMap::new())) + .with_hints_from_tree(&ball, metric) + .with_hints_from(metric, &balanced_ball, max_radius, max_k); + + let (perm_ball, perm_data) = PermutedBall::par_from_ball_tree(ball.clone(), data.clone(), metric); + let (perm_balanced_ball, perm_balanced_data) = + PermutedBall::par_from_ball_tree(balanced_ball.clone(), data.clone(), metric); + + utils::compare_permuted( + c, + metric, + (&ball, &data), + (&balanced_ball, &perm_balanced_data), + (&perm_ball, &perm_data), + (&perm_balanced_ball, &perm_balanced_data), + None, + None, + queries, + &radii, + ks, + true, + ); +} fn vector_search(c: &mut Criterion) { let cardinality = 1_000_000; @@ -24,7 +99,10 @@ fn vector_search(c: &mut Criterion) { let max_val = 1.0; let min_val = -max_val; let seed = 42; - let rows = symagen::random_data::random_tabular_seedable(cardinality, dimensionality, min_val, max_val, seed); + let rows = symagen::random_data::random_tabular_seedable(cardinality, dimensionality, min_val, max_val, seed) + .into_iter() + .map(Row::from) + .collect::>(); let num_queries = 30; let queries = { @@ -38,43 +116,15 @@ fn vector_search(c: &mut Criterion) { .collect::>() }; + let data = FlatVec::new(rows) + .unwrap_or_else(|e| unreachable!("{e}")) + .with_name("vector-search"); let seed = Some(seed); - let radii = vec![0.001, 0.005, 0.01, 0.1]; + let radii = vec![0.001, 0.01]; let ks = vec![1, 10, 100]; - for &(metric_name, distance_fn) in METRICS { - let metric = Metric::new(distance_fn, true); - let data = FlatVec::new(rows.clone(), metric).unwrap(); - - let criteria = |c: &Ball<_, _, _>| c.cardinality() > 1; - let ball = Ball::par_new_tree(&data, &criteria, seed); - let (off_ball, perm_data) = OffBall::par_from_ball_tree(ball.clone(), data.clone()); - - let criteria = |c: &BalancedBall<_, _, _>| c.cardinality() > 1; - let balanced_ball = BalancedBall::par_new_tree(&data, &criteria, seed); - let (balanced_off_ball, balanced_perm_data) = { - let balanced_off_ball = OffBall::par_adapt_tree(balanced_ball.clone(), None); - let mut balanced_perm_data = data.clone(); - let permutation = 
balanced_off_ball.source().indices().collect::>(); - balanced_perm_data.permute(&permutation); - (balanced_off_ball, balanced_perm_data) - }; - - utils::compare_permuted( - c, - "vector-search", - metric_name, - (&ball, &data), - (&off_ball, &perm_data), - None, - (&balanced_ball, &data), - (&balanced_off_ball, &balanced_perm_data), - None, - &queries, - &radii, - &ks, - false, - ); - } + + run_search(c, data, &Euclidean, &queries, &radii, &ks, seed); + // run_search(c, &data, &EuclideanSimd, &queries, &radii, &ks, seed); } criterion_group!(benches, vector_search); diff --git a/crates/abd-clam/src/cakes/cluster/mod.rs b/crates/abd-clam/src/cakes/cluster/mod.rs index ddb3cad1a..c5d8ee1cb 100644 --- a/crates/abd-clam/src/cakes/cluster/mod.rs +++ b/crates/abd-clam/src/cakes/cluster/mod.rs @@ -1,21 +1,5 @@ //! An adaptation of `Ball` that stores indices after reordering the dataset. -mod offset_ball; -mod searchable; +mod permuted_ball; -use distances::Number; -pub use offset_ball::OffBall; -pub use searchable::{ParSearchable, Searchable}; - -use crate::{cluster::ParCluster, dataset::ParDataset, BalancedBall, Ball, Cluster, Dataset}; - -impl> Searchable for Ball {} -impl> Searchable for BalancedBall {} -impl, S: Cluster> Searchable for OffBall {} - -impl> ParSearchable for Ball {} -impl> ParSearchable for BalancedBall {} -impl, S: ParCluster> ParSearchable - for OffBall -{ -} +pub use permuted_ball::{Offset, PermutedBall}; diff --git a/crates/abd-clam/src/cakes/cluster/offset_ball.rs b/crates/abd-clam/src/cakes/cluster/offset_ball.rs deleted file mode 100644 index ced66af0d..000000000 --- a/crates/abd-clam/src/cakes/cluster/offset_ball.rs +++ /dev/null @@ -1,329 +0,0 @@ -//! An adaptation of `Ball` that stores indices after reordering the dataset. - -use core::fmt::Debug; -use std::marker::PhantomData; - -use distances::Number; -use serde::{Deserialize, Serialize}; - -use crate::{ - adapter::{Adapter, BallAdapter, ParAdapter, ParBallAdapter, ParParams, Params}, - cluster::ParCluster, - dataset::ParDataset, - Ball, Cluster, Dataset, Permutable, -}; - -/// A variant of `Ball` that stores indices after reordering the dataset. -#[derive(Clone, Serialize, Deserialize)] -pub struct OffBall, S: Cluster> { - /// The `Cluster` type that the `OffsetBall` is based on. - source: S, - /// The children of the `Cluster`. - children: Vec<(usize, U, Box)>, - /// The parameters of the `Cluster`. - params: Offset, - /// Phantom data to satisfy the compiler. - _id: PhantomData<(I, D)>, -} - -impl, S: Cluster + Debug> Debug for OffBall { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("OffBall") - .field("source", &self.source) - .field("children", &!self.children.is_empty()) - .field("offset", &self.params.offset) - .finish() - } -} - -impl, S: Cluster> OffBall { - /// Returns the offset of the `Cluster`. - #[must_use] - pub const fn offset(&self) -> usize { - self.params.offset - } -} - -impl + Permutable> BallAdapter for OffBall> { - /// Creates a new `OffsetBall` tree from a `Ball` tree. - fn from_ball_tree(ball: Ball, mut data: D) -> (Self, D) { - let mut root = Self::adapt_tree_iterative(ball, None); - data.permute(&root.source.indices); - root.source.clear_indices(); - (root, data) - } -} - -impl + Permutable> ParBallAdapter - for OffBall> -{ - /// Creates a new `OffsetBall` tree from a `Ball` tree. 
- fn par_from_ball_tree(ball: Ball, mut data: D) -> (Self, D) { - let mut root = Self::par_adapt_tree_iterative(ball, None); - data.permute(&root.source.indices); - root.source.clear_indices(); - (root, data) - } -} - -impl + Permutable, S: Cluster> Adapter - for OffBall -{ - fn new_adapted(source: S, children: Vec<(usize, U, Box)>, params: Offset) -> Self { - Self { - source, - children, - params, - _id: PhantomData, - } - } - - fn post_traversal(&mut self) { - // Update the indices of the important instances in the `Cluster`. - let offset = self.params.offset; - let indices = self.source.indices().collect::>(); - self.set_arg_center(new_index(self.source.arg_center(), &indices, offset)); - self.set_arg_radial(new_index(self.source.arg_radial(), &indices, offset)); - for (p, _, _) in self.children_mut() { - *p = new_index(*p, &indices, offset); - } - } - - fn source(&self) -> &S { - &self.source - } - - fn source_mut(&mut self) -> &mut S { - &mut self.source - } - - fn take_source(self) -> S { - self.source - } - - fn params(&self) -> &Offset { - &self.params - } -} - -/// Helper for computing a new index after permutation of data. -fn new_index(i: usize, indices: &[usize], offset: usize) -> usize { - offset - + indices - .iter() - .position(|x| *x == i) - .unwrap_or_else(|| unreachable!("This is a private function and we always pass a valid item.")) -} - -impl + Permutable, S: ParCluster> - ParAdapter for OffBall -{ - fn par_post_traversal(&mut self) { - self.post_traversal(); - } -} - -/// Parameters for the `OffsetBall`. -#[derive(Debug, Default, Copy, Clone, Serialize, Deserialize)] -pub struct Offset { - /// The offset of the slice of indices of the `Cluster` in the reordered - /// dataset. - offset: usize, -} - -impl, S: Cluster> Params for Offset { - fn child_params(&self, children: &[S]) -> Vec { - let mut offset = self.offset; - children - .iter() - .map(|child| { - let params = Self { offset }; - offset += child.cardinality(); - params - }) - .collect() - } -} - -impl, S: ParCluster> ParParams for Offset { - fn par_child_params(&self, children: &[S]) -> Vec { - // Since we need to keep track of the offset, we cannot parallelize this. 
- self.child_params(children) - } -} - -impl, S: Cluster> Cluster for OffBall { - fn depth(&self) -> usize { - self.source.depth() - } - - fn cardinality(&self) -> usize { - self.source.cardinality() - } - - fn arg_center(&self) -> usize { - self.source.arg_center() - } - - fn set_arg_center(&mut self, arg_center: usize) { - self.source.set_arg_center(arg_center); - } - - fn radius(&self) -> U { - self.source.radius() - } - - fn arg_radial(&self) -> usize { - self.source.arg_radial() - } - - fn set_arg_radial(&mut self, arg_radial: usize) { - self.source.set_arg_radial(arg_radial); - } - - fn lfd(&self) -> f32 { - self.source.lfd() - } - - fn indices(&self) -> impl Iterator + '_ { - self.params.offset..(self.params.offset + self.cardinality()) - } - - fn set_indices(&mut self, indices: Vec) { - let offset = indices[0]; - self.params.offset = offset; - } - - fn children(&self) -> &[(usize, U, Box)] { - self.children.as_slice() - } - - fn children_mut(&mut self) -> &mut [(usize, U, Box)] { - self.children.as_mut_slice() - } - - fn set_children(&mut self, children: Vec<(usize, U, Box)>) { - self.children = children; - } - - fn take_children(&mut self) -> Vec<(usize, U, Box)> { - std::mem::take(&mut self.children) - } - - fn distances_to_query(&self, data: &D, query: &I) -> Vec<(usize, U)> { - data.query_to_many(query, &self.indices().collect::>()) - } - - fn is_descendant_of(&self, other: &Self) -> bool { - let range = other.params.offset..(other.params.offset + other.cardinality()); - range.contains(&self.offset()) && self.cardinality() <= other.cardinality() - } -} - -impl, S: Cluster> PartialEq for OffBall { - fn eq(&self, other: &Self) -> bool { - self.params.offset == other.params.offset && self.cardinality() == other.cardinality() - } -} - -impl, S: Cluster> Eq for OffBall {} - -impl, S: Cluster> PartialOrd for OffBall { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl, S: Cluster> Ord for OffBall { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - self.params - .offset - .cmp(&other.params.offset) - .then_with(|| other.cardinality().cmp(&self.cardinality())) - } -} - -impl, S: Cluster> std::hash::Hash for OffBall { - fn hash(&self, state: &mut H) { - (self.params.offset, self.cardinality()).hash(state); - } -} - -impl, S: ParCluster> ParCluster - for OffBall -{ - fn par_distances_to_query(&self, data: &D, query: &I) -> Vec<(usize, U)> { - data.par_query_to_many(query, &self.indices().collect::>()) - } -} - -#[cfg(feature = "csv")] -impl, S: crate::cluster::WriteCsv> crate::cluster::WriteCsv - for OffBall -{ - fn header(&self) -> Vec { - let mut header = self.source.header(); - header.push("offset".to_string()); - header - } - - fn row(&self) -> Vec { - let mut row = self.source.row(); - row.pop(); - row.extend(vec![ - self.children.is_empty().to_string(), - self.params.offset.to_string(), - ]); - row - } -} - -#[cfg(test)] -mod tests { - use crate::{ - adapter::{BallAdapter, ParBallAdapter}, - Ball, Cluster, Dataset, FlatVec, Metric, Partition, - }; - - use super::OffBall; - - type Fv = FlatVec, i32, usize>; - type B = Ball, i32, Fv>; - type Ob = OffBall, i32, Fv, B>; - - fn gen_tiny_data() -> Result, i32, usize>, String> { - let instances = vec![vec![1, 2], vec![3, 4], vec![5, 6], vec![7, 8], vec![11, 12]]; - let distance_function = |a: &Vec, b: &Vec| distances::vectors::manhattan(a, b); - let metric = Metric::new(distance_function, false); - FlatVec::new_array(instances.clone(), metric) - } - - fn check_permutation(root: &Ob, data: 
&FlatVec, i32, usize>) -> bool { - assert!(!root.children().is_empty()); - - for cluster in root.subtree() { - let radius = data.one_to_one(cluster.arg_center(), cluster.arg_radial()); - assert_eq!(cluster.radius(), radius); - } - - true - } - - #[test] - fn permutation() -> Result<(), String> { - let data = gen_tiny_data()?; - - let seed = Some(42); - let criteria = |c: &B| c.depth() < 1; - - let ball = Ball::new_tree(&data, &criteria, seed); - - let (root, perm_data) = OffBall::from_ball_tree(ball.clone(), data.clone()); - assert!(check_permutation(&root, &perm_data)); - - let (root, perm_data) = OffBall::par_from_ball_tree(ball, data); - assert!(check_permutation(&root, &perm_data)); - - Ok(()) - } -} diff --git a/crates/abd-clam/src/cakes/cluster/permuted_ball.rs b/crates/abd-clam/src/cakes/cluster/permuted_ball.rs new file mode 100644 index 000000000..1ee4d8285 --- /dev/null +++ b/crates/abd-clam/src/cakes/cluster/permuted_ball.rs @@ -0,0 +1,333 @@ +//! An adaptation of `Ball` that stores indices after reordering the dataset. + +use distances::Number; +use rayon::prelude::*; + +use crate::{ + cluster::{ + adapter::{Adapter, BallAdapter, ParAdapter, ParBallAdapter, ParParams, Params}, + ParCluster, + }, + dataset::{ParDataset, Permutable}, + metric::ParMetric, + Ball, Cluster, Dataset, Metric, +}; + +/// A `Cluster` that stores indices after reordering the dataset. +/// +/// # Type parameters +/// +/// - `T`: The type of the distance values. +/// - `S`: The `Cluster` type that the `PermutedBall` is based on. +#[derive(Clone)] +#[cfg_attr( + feature = "disk-io", + derive(bitcode::Encode, bitcode::Decode, serde::Serialize, serde::Deserialize) +)] +#[cfg_attr(feature = "disk-io", bitcode(recursive))] +pub struct PermutedBall> { + /// The `Cluster` type that the `PermutedBall` is based on. + source: S, + /// The children of the `Cluster`. + children: Vec>, + /// The parameters of the `Cluster`. + params: Offset, + /// Ghosts in the machine. + phantom: core::marker::PhantomData, +} + +impl> PermutedBall { + /// Clears the indices of the source `Cluster` and its children. + pub fn clear_source_indices(&mut self) { + self.source.clear_indices(); + if !self.is_leaf() { + self.children_mut().into_iter().for_each(Self::clear_source_indices); + } + } + + /// Returns an iterator over the indices. + pub fn iter_indices(&self) -> impl Iterator { + self.params.offset..(self.params.offset + self.cardinality()) + } +} + +impl + core::fmt::Debug> core::fmt::Debug for PermutedBall { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("PermutedBall") + .field("source", &self.source) + .field("offset", &self.params.offset) + .field("children", &!self.children.is_empty()) + .finish() + } +} + +impl> PartialEq for PermutedBall { + fn eq(&self, other: &Self) -> bool { + self.params.offset == other.params.offset && self.cardinality() == other.cardinality() + } +} + +impl> Eq for PermutedBall {} + +impl> PartialOrd for PermutedBall { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl> Ord for PermutedBall { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.params + .offset + .cmp(&other.params.offset) + .then_with(|| other.cardinality().cmp(&self.cardinality())) + } +} + +impl> std::hash::Hash for PermutedBall { + fn hash(&self, state: &mut H) { + (self.params.offset, self.cardinality()).hash(state); + } +} + +impl> PermutedBall { + /// Returns the offset of the `Cluster`. 
+ #[must_use] + pub const fn offset(&self) -> usize { + self.params.offset + } +} + +impl> Cluster for PermutedBall { + fn depth(&self) -> usize { + self.source.depth() + } + + fn cardinality(&self) -> usize { + self.source.cardinality() + } + + fn arg_center(&self) -> usize { + self.source.arg_center() + } + + fn set_arg_center(&mut self, arg_center: usize) { + self.source.set_arg_center(arg_center); + } + + fn radius(&self) -> T { + self.source.radius() + } + + fn arg_radial(&self) -> usize { + self.source.arg_radial() + } + + fn set_arg_radial(&mut self, arg_radial: usize) { + self.source.set_arg_radial(arg_radial); + } + + fn lfd(&self) -> f32 { + self.source.lfd() + } + + fn contains(&self, index: usize) -> bool { + (self.params.offset..(self.params.offset + self.cardinality())).contains(&index) + } + + fn indices(&self) -> Vec { + self.iter_indices().collect() + } + + fn set_indices(&mut self, indices: &[usize]) { + let offset = indices[0]; + self.params.offset = offset; + } + + fn extents(&self) -> &[(usize, T)] { + self.source.extents() + } + + fn extents_mut(&mut self) -> &mut [(usize, T)] { + self.source.extents_mut() + } + + fn add_extent(&mut self, idx: usize, extent: T) { + self.source.add_extent(idx, extent); + } + + fn take_extents(&mut self) -> Vec<(usize, T)> { + self.source.take_extents() + } + + fn children(&self) -> Vec<&Self> { + self.children.iter().map(AsRef::as_ref).collect() + } + + fn children_mut(&mut self) -> Vec<&mut Self> { + self.children.iter_mut().map(AsMut::as_mut).collect() + } + + fn set_children(&mut self, children: Vec>) { + self.children = children; + } + + fn take_children(&mut self) -> Vec> { + std::mem::take(&mut self.children) + } + + fn is_descendant_of(&self, other: &Self) -> bool { + let range = other.params.offset..(other.params.offset + other.cardinality()); + range.contains(&self.offset()) && self.cardinality() <= other.cardinality() + } +} + +impl> ParCluster for PermutedBall { + fn par_indices(&self) -> impl ParallelIterator { + (self.params.offset..(self.params.offset + self.cardinality())).into_par_iter() + } +} + +/// Parameters for adapting the `PermutedBall`. +#[derive(Debug, Default, Copy, Clone)] +#[cfg_attr( + feature = "disk-io", + derive(bitcode::Encode, bitcode::Decode, serde::Serialize, serde::Deserialize) +)] +pub struct Offset { + /// The offset of the slice of indices of the `Cluster` in the reordered + /// dataset. + offset: usize, +} + +impl, S: Cluster> Params for Offset { + fn child_params>(&self, children: &[S], _: &D, _: &M) -> Vec { + let mut offset = self.offset; + children + .iter() + .map(|child| { + let params = Self { offset }; + offset += child.cardinality(); + params + }) + .collect() + } +} + +impl, S: ParCluster> ParParams for Offset { + fn par_child_params>(&self, children: &[S], data: &D, metric: &M) -> Vec { + // Since we need to keep track of the offset, we cannot parallelize this. + self.child_params(children, data, metric) + } +} + +impl + Permutable> BallAdapter for PermutedBall> { + /// Creates a new `PermutedBall` tree from a `Ball` tree. + fn from_ball_tree>(ball: Ball, mut data: D, metric: &M) -> (Self, D) { + let mut root = Self::adapt_tree_iterative(ball, None, &data, metric); + data.permute(&root.source.indices()); + root.clear_source_indices(); + (root, data) + } +} + +impl + Permutable> ParBallAdapter + for PermutedBall> +{ + /// Creates a new `PermutedBall` tree from a `Ball` tree. 
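A `PermutedBall` rests on one invariant: after `data.permute(...)`, the items of each cluster occupy the contiguous range `offset..offset + cardinality`, and children partition their parent's range. That is why `contains`, `indices`, and `is_descendant_of` above reduce to range arithmetic. A standalone illustration of the invariant, with a stand-in struct rather than the abd-clam type:

// Stand-in for the offset/cardinality bookkeeping, not the real cluster type.
struct Node { offset: usize, cardinality: usize }

impl Node {
    fn contains(&self, i: usize) -> bool {
        (self.offset..self.offset + self.cardinality).contains(&i)
    }
    fn is_descendant_of(&self, other: &Node) -> bool {
        other.contains(self.offset) && self.cardinality <= other.cardinality
    }
}

fn main() {
    let root = Node { offset: 0, cardinality: 8 };
    // Children partition the parent's range: [0, 5) and [5, 8).
    let left = Node { offset: 0, cardinality: 5 };
    let right = Node { offset: 5, cardinality: 3 };
    assert!(left.is_descendant_of(&root) && right.is_descendant_of(&root));
    assert!(!right.is_descendant_of(&left));
}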
+ fn par_from_ball_tree>(ball: Ball, mut data: D, metric: &M) -> (Self, D) { + let mut root = Self::par_adapt_tree_iterative(ball, None, &data, metric); + data.permute(&root.source.indices()); + root.clear_source_indices(); + (root, data) + } +} + +impl + Permutable, S: Cluster> Adapter for PermutedBall { + fn new_adapted>(source: S, children: Vec>, params: Offset, _: &D, _: &M) -> Self { + Self { + source, + params, + children, + phantom: core::marker::PhantomData, + } + } + + fn post_traversal(&mut self) { + // Update the indices of the important items in the `Cluster`. + let offset = self.params.offset; + let indices = self.source.indices(); + self.set_arg_center(new_index(self.source.arg_center(), &indices, offset)); + self.set_arg_radial(new_index(self.source.arg_radial(), &indices, offset)); + for (i, _) in self.extents_mut() { + *i = new_index(*i, &indices, offset); + } + } + + fn source(&self) -> &S { + &self.source + } + + fn source_mut(&mut self) -> &mut S { + &mut self.source + } + + fn take_source(self) -> S { + self.source + } + + fn params(&self) -> &Offset { + &self.params + } +} + +/// Helper for computing a new index after permutation of data. +fn new_index(i: usize, indices: &[usize], offset: usize) -> usize { + offset + + indices + .iter() + .position(|x| *x == i) + .unwrap_or_else(|| unreachable!("This is a private function and we always pass a valid item.")) +} + +impl + Permutable, S: ParCluster> ParAdapter + for PermutedBall +{ + fn par_new_adapted>( + source: S, + children: Vec>, + params: Offset, + data: &D, + metric: &M, + ) -> Self { + Self::new_adapted(source, children, params, data, metric) + } +} + +#[cfg(feature = "disk-io")] +impl> crate::cluster::Csv for PermutedBall { + fn header(&self) -> Vec { + let mut header = self.source.header(); + header.push("offset".to_string()); + header + } + + fn row(&self) -> Vec { + let mut row = self.source.row(); + row.pop(); + row.extend(vec![ + self.children.is_empty().to_string(), + self.params.offset.to_string(), + ]); + row + } +} + +#[cfg(feature = "disk-io")] +impl> crate::cluster::ParCsv for PermutedBall {} + +#[cfg(feature = "disk-io")] +impl> crate::cluster::ClusterIO for PermutedBall {} + +#[cfg(feature = "disk-io")] +impl> crate::cluster::ParClusterIO for PermutedBall {} diff --git a/crates/abd-clam/src/cakes/cluster/searchable.rs b/crates/abd-clam/src/cakes/cluster/searchable.rs deleted file mode 100644 index 268ef0816..000000000 --- a/crates/abd-clam/src/cakes/cluster/searchable.rs +++ /dev/null @@ -1,49 +0,0 @@ -//! Searchable dataset. - -use distances::Number; -use rayon::prelude::*; - -use crate::{cakes::Algorithm, cluster::ParCluster, dataset::ParDataset, Cluster, Dataset}; - -/// A dataset that can be searched with entropy-scaling algorithms. -pub trait Searchable>: Cluster { - /// Searches the dataset for the `query` instance and returns the - /// indices of and distances to the nearest neighbors. - fn search(&self, data: &D, query: &I, alg: Algorithm) -> Vec<(usize, U)> { - alg.search(data, self, query) - } - - /// Batch version of the `search` method, to search for multiple queries. - fn batch_search(&self, data: &D, queries: &[I], alg: Algorithm) -> Vec> { - queries.iter().map(|query| self.search(data, query, alg)).collect() - } -} - -/// Parallel version of the `Searchable` trait. -#[allow(clippy::module_name_repetitions)] -pub trait ParSearchable>: - Searchable + ParCluster -{ - /// Parallel version of the `search` method. 
- fn par_search(&self, data: &D, query: &I, alg: Algorithm) -> Vec<(usize, U)> { - alg.par_search(data, self, query) - } - - /// Batch version of the `par_search` method. - fn batch_par_search(&self, data: &D, queries: &[I], alg: Algorithm) -> Vec> { - queries.iter().map(|query| self.par_search(data, query, alg)).collect() - } - - /// Parallel version of the `batch_search` method. - fn par_batch_search(&self, data: &D, queries: &[I], alg: Algorithm) -> Vec> { - queries.par_iter().map(|query| self.search(data, query, alg)).collect() - } - - /// Parallel version of the `batch_par_search` method. - fn par_batch_par_search(&self, data: &D, queries: &[I], alg: Algorithm) -> Vec> { - queries - .par_iter() - .map(|query| self.par_search(data, query, alg)) - .collect() - } -} diff --git a/crates/abd-clam/src/cakes/codec/codec_data.rs b/crates/abd-clam/src/cakes/codec/codec_data.rs deleted file mode 100644 index 252277d14..000000000 --- a/crates/abd-clam/src/cakes/codec/codec_data.rs +++ /dev/null @@ -1,525 +0,0 @@ -//! An implementation of the Compression and Decompression traits. - -use core::marker::PhantomData; -use std::{ - collections::HashMap, - io::{Read, Write}, -}; - -use distances::Number; -use flate2::write::GzEncoder; -use rayon::prelude::*; -use serde::{Deserialize, Serialize}; - -use crate::{ - cluster::ParCluster, - dataset::{metric_space::ParMetricSpace, ParDataset, SizedHeap}, - Cluster, Dataset, FlatVec, Metric, MetricSpace, Permutable, -}; - -use super::{ - compression::ParCompressible, decompression::ParDecompressible, Compressible, Decodable, Decompressible, Encodable, - SquishyBall, -}; - -/// A compressed dataset, that can be partially decompressed for search and -/// other applications. -/// -/// A `CodecData` may only be built from a `Permutable` dataset, after the tree -/// has been built and the instances in the dataset have been permuted. This is -/// necessary for the `get` method to work correctly. Further, it is discouraged -/// to use the `get` method because it can be expensive if the instance being -/// retrieved is not the center of a cluster. -/// -/// # Type Parameters -/// -/// - `I`: The type of the instances in the dataset. -/// - `U`: The type of the numbers in the dataset. -/// - `M`: The type of the metadata associated with the instances. -#[derive(Clone)] -pub struct CodecData { - /// The metric space of the dataset. - pub(crate) metric: Metric, - /// The cardinality of the dataset. - pub(crate) cardinality: usize, - /// A hint for the dimensionality of the dataset. - pub(crate) dimensionality_hint: (usize, Option), - /// The metadata associated with the instances. - pub(crate) metadata: Vec, - /// The permutation of the original dataset. - pub(crate) permutation: Vec, - /// The name of the dataset. - pub(crate) name: String, - /// The centers of the clusters in the dataset. - pub(crate) center_map: HashMap, - /// The byte-slices representing the leaf clusters. - pub(crate) leaf_bytes: Vec<(usize, Box<[u8]>)>, -} - -impl CodecData { - /// Creates a `CodecData` from a compressible dataset and a `SquishyBall` tree. 
- pub fn from_compressible + Permutable, S: Cluster>( - data: &Co, - root: &SquishyBall, - ) -> Self { - let center_map = root - .subtree() - .into_iter() - .map(Cluster::arg_center) - .map(|i| (i, data.get(i).clone())) - .collect(); - - let leaf_bytes = data - .encode_leaves(root) - .into_iter() - .map(|(leaf, bytes)| (leaf.offset(), bytes)) - .collect(); - - let cardinality = data.cardinality(); - let metric = data.metric().clone(); - let dimensionality_hint = data.dimensionality_hint(); - Self { - metric, - cardinality, - dimensionality_hint, - metadata: (0..cardinality).collect(), - permutation: data.permutation(), - name: format!("CodecData({})", data.name()), - center_map, - leaf_bytes, - } - } -} - -impl CodecData { - /// Creates a `CodecData` from a compressible dataset and a `SquishyBall` tree. - pub fn par_from_compressible + Permutable, S: ParCluster + core::fmt::Debug>( - data: &D, - root: &SquishyBall, - ) -> Self { - let center_map = root - .subtree() - .into_iter() - .map(Cluster::arg_center) - .map(|i| (i, data.get(i).clone())) - .collect(); - - let leaf_bytes = data - .par_encode_leaves(root) - .into_iter() - .map(|(leaf, bytes)| (leaf.offset(), bytes)) - .collect(); - - let cardinality = data.cardinality(); - let metric = data.metric().clone(); - let dimensionality_hint = data.dimensionality_hint(); - Self { - metric, - cardinality, - dimensionality_hint, - metadata: (0..cardinality).collect(), - permutation: data.permutation(), - name: format!("CodecData({})", data.name()), - center_map, - leaf_bytes, - } - } -} - -impl CodecData { - /// Changes the metadata of the dataset. - /// - /// # Parameters - /// - /// - `metadata`: The new metadata to associate with the instances. - /// - /// # Type Parameters - /// - /// - `Me`: The type of the new metadata. - /// - /// # Returns - /// - /// A `CodecData` with the new metadata. - /// - /// # Errors - /// - /// If the length of the metadata vector does not match the cardinality of - /// the dataset. - pub fn with_metadata(self, mut metadata: Vec) -> Result, String> { - if metadata.len() == self.cardinality { - metadata.permute(&self.permutation); - Ok(CodecData { - metric: self.metric, - cardinality: self.cardinality, - dimensionality_hint: self.dimensionality_hint, - metadata, - permutation: self.permutation, - name: self.name, - center_map: self.center_map, - leaf_bytes: self.leaf_bytes, - }) - } else { - Err(format!( - "The length of the metadata vector ({}) does not match the cardinality of the dataset ({}).", - metadata.len(), - self.cardinality - )) - } - } - - /// Returns the metadata associated with the instances in the dataset. - #[must_use] - pub fn metadata(&self) -> &[M] { - &self.metadata - } - - /// Returns the permutation of the original dataset. - #[must_use] - pub fn permutation(&self) -> &[usize] { - &self.permutation - } -} - -impl CodecData { - /// Decompresses the dataset into a `FlatVec`. 
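`CodecData` keeps two things decoded and everything else compressed: `center_map` holds every cluster center (so tree navigation never touches the codec), while `leaf_bytes` holds one opaque blob per leaf, keyed by the leaf's offset, decoded only when search actually visits that leaf. A toy sketch of that layout, with a trivial stand-in codec instead of the real delta encoding:

use std::collections::HashMap;

// Stand-in layout: decoded centers by index, encoded leaves by (offset, bytes).
struct Compressed {
    centers: HashMap<usize, Vec<f32>>,
    leaves: Vec<(usize, Vec<u8>)>,
}

fn decode_leaf(bytes: &[u8]) -> Vec<f32> {
    // Stand-in decoder: one little-endian f32 per 4 bytes.
    bytes
        .chunks_exact(4)
        .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
        .collect()
}

fn main() {
    let mut centers = HashMap::new();
    centers.insert(0, vec![1.0_f32, 2.0]);
    let leaf: Vec<u8> = [1.0_f32, 2.0, 0.5, 2.5].iter().flat_map(|v| v.to_le_bytes()).collect();
    let data = Compressed { centers, leaves: vec![(0, leaf)] };
    // Centers are always available; leaves are decoded on demand.
    assert_eq!(data.centers[&0], vec![1.0, 2.0]);
    assert_eq!(decode_leaf(&data.leaves[0].1), vec![1.0, 2.0, 0.5, 2.5]);
}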
- #[must_use] - pub fn to_flat_vec(&self) -> FlatVec { - let instances = self - .leaf_bytes - .iter() - .flat_map(|(_, bytes)| self.decode_leaf(bytes.as_ref())) - .collect::>(); - FlatVec { - metric: self.metric.clone(), - instances, - dimensionality_hint: self.dimensionality_hint, - permutation: self.permutation.clone(), - metadata: self.metadata.clone(), - name: format!("FlatVec({})", self.name), - } - } -} - -impl Decompressible for CodecData { - fn centers(&self) -> &HashMap { - &self.center_map - } - - fn leaf_bytes(&self) -> &[(usize, Box<[u8]>)] { - &self.leaf_bytes - } -} - -impl ParDecompressible for CodecData {} - -impl Dataset for CodecData { - fn name(&self) -> &str { - &self.name - } - - fn with_name(mut self, name: &str) -> Self { - self.name = format!("CodecData({name})"); - self - } - - fn cardinality(&self) -> usize { - self.cardinality - } - - fn dimensionality_hint(&self) -> (usize, Option) { - self.dimensionality_hint - } - - #[allow(clippy::panic)] - fn get(&self, index: usize) -> &I { - self.center_map.get(&index).map_or_else( - || panic!("For CodecData, the `get` method may only be used for cluster centers."), - |center| center, - ) - } - - fn knn(&self, query: &I, k: usize) -> Vec<(usize, U)> { - let mut knn = SizedHeap::new(Some(k)); - self.leaf_bytes - .iter() - .map(|(o, bytes)| (*o, self.decode_leaf(bytes.as_ref()))) - .flat_map(|(o, instances)| { - let instances = instances - .iter() - .enumerate() - .map(|(i, p)| (o + i, p)) - .collect::>(); - MetricSpace::one_to_many(self, query, &instances) - }) - .for_each(|(i, d)| knn.push((d, i))); - knn.items().map(|(d, i)| (i, d)).collect() - } - - fn rnn(&self, query: &I, radius: U) -> Vec<(usize, U)> { - self.leaf_bytes - .iter() - .map(|(o, bytes)| (*o, self.decode_leaf(bytes.as_ref()))) - .flat_map(|(o, instances)| { - let instances = instances - .iter() - .enumerate() - .map(|(i, p)| (o + i, p)) - .collect::>(); - MetricSpace::one_to_many(self, query, &instances) - }) - .filter(|&(_, d)| d <= radius) - .collect() - } -} - -impl MetricSpace for CodecData { - fn metric(&self) -> &Metric { - &self.metric - } - - fn set_metric(&mut self, metric: Metric) { - self.metric = metric; - } -} - -impl ParMetricSpace for CodecData {} - -impl ParDataset for CodecData { - fn par_knn(&self, query: &I, k: usize) -> Vec<(usize, U)> { - let mut knn = SizedHeap::new(Some(k)); - self.leaf_bytes - .par_iter() - .map(|(o, bytes)| (*o, self.decode_leaf(bytes.as_ref()))) - .flat_map(|(o, instances)| { - let instances = instances - .iter() - .enumerate() - .map(|(i, p)| (o + i, p)) - .collect::>(); - ParMetricSpace::par_one_to_many(self, query, &instances) - }) - .collect::>() - .into_iter() - .for_each(|(i, d)| knn.push((d, i))); - knn.items().map(|(d, i)| (i, d)).collect() - } - - fn par_rnn(&self, query: &I, radius: U) -> Vec<(usize, U)> { - self.leaf_bytes - .par_iter() - .map(|(o, bytes)| (*o, self.decode_leaf(bytes.as_ref()))) - .flat_map(|(o, instances)| { - let instances = instances - .iter() - .enumerate() - .map(|(i, p)| (o + i, p)) - .collect::>(); - ParMetricSpace::par_one_to_many(self, query, &instances) - }) - .filter(|&(_, d)| d <= radius) - .collect() - } -} - -/// A private helper struct for serializing and deserializing `CodecData`. -#[derive(Serialize, Deserialize)] -struct CodecDataSerde { - /// The cardinality of the dataset. - cardinality: usize, - /// A hint for the dimensionality of the dataset. - dimensionality_hint: (usize, Option), - /// The name of the dataset. 
- name: String, - /// The bytes for the metadata - metadata: Box<[u8]>, - /// The bytes for the permutation. - permutation: Box<[u8]>, - /// The bytes for the center map. - center_map: Box<[u8]>, - /// The bytes for the leaf bytes. - leaf_bytes: Box<[u8]>, - /// Phantom data. - _p: PhantomData<(I, U, M)>, -} - -impl CodecDataSerde { - /// Creates a `CodecDataSerde` from a `CodecData`. - fn from_codec_data(data: &CodecData) -> Result { - let metadata = data - .metadata - .par_iter() - .flat_map(|m| { - let mut bytes = Vec::new(); - let encoding = m.as_bytes(); - bytes.extend_from_slice(&encoding.len().to_le_bytes()); - bytes.extend_from_slice(&encoding); - bytes - }) - .collect::>(); - - let permutation = data - .permutation - .iter() - .flat_map(|i| i.to_le_bytes()) - .collect::>(); - - let center_map = data - .center_map - .par_iter() - .flat_map(|(i, p)| { - let mut bytes = Vec::new(); - bytes.extend_from_slice(&i.to_le_bytes()); - let encoding = p.as_bytes(); - bytes.extend_from_slice(&encoding.len().to_le_bytes()); - bytes.extend_from_slice(&encoding); - bytes - }) - .collect::>(); - - let leaf_bytes = data - .leaf_bytes - .par_iter() - .flat_map(|(i, encodings)| { - let mut bytes = Vec::new(); - bytes.extend_from_slice(&i.to_le_bytes()); - bytes.extend_from_slice(&encodings.len().to_le_bytes()); - bytes.extend_from_slice(encodings); - bytes - }) - .collect::>(); - - let bytes = [metadata, permutation, center_map, leaf_bytes] - .par_iter() - .map(|bytes| { - let mut encoder = GzEncoder::new(Vec::new(), flate2::Compression::default()); - encoder.write_all(bytes).map_err(|e| e.to_string())?; - encoder.finish().map_err(|e| e.to_string()).map(Vec::into_boxed_slice) - }) - .collect::, _>>()?; - - let [metadata, permutation, center_map, leaf_bytes] = bytes - .try_into() - .map_err(|_| "Failed to convert bytes into array.".to_string())?; - - Ok(Self { - cardinality: data.cardinality, - dimensionality_hint: data.dimensionality_hint, - name: data.name.clone(), - metadata, - permutation, - center_map, - leaf_bytes, - _p: PhantomData, - }) - } -} - -impl CodecData { - /// Creates a `CodecData` from a `CodecDataSerde`. - fn from_serde(data: CodecDataSerde) -> Result { - let bytes = [data.metadata, data.permutation, data.center_map, data.leaf_bytes] - .par_iter() - .map(|bytes| { - let mut decoder = flate2::read::GzDecoder::new(&bytes[..]); - let mut dec_bytes = Vec::new(); - decoder.read_to_end(&mut dec_bytes).map_err(|e| e.to_string())?; - Ok(dec_bytes) - }) - .collect::, String>>()?; - - let [metadata, permutation, center_map, leaf_bytes]: [Vec; 4] = bytes - .try_into() - .map_err(|_| "Failed to convert bytes into array.".to_string())?; - - let (metadata, (permutation, (center_map, leaf_bytes))) = rayon::join( - || Self::decode_metadata(&metadata), - || { - rayon::join( - || Self::decode_permutation(&permutation), - || { - rayon::join( - || Self::decode_center_map(¢er_map), - || Self::decode_leaf_bytes(&leaf_bytes), - ) - }, - ) - }, - ); - - Ok(Self { - metric: Metric::default(), - cardinality: data.cardinality, - dimensionality_hint: data.dimensionality_hint, - metadata, - permutation, - name: data.name, - center_map, - leaf_bytes, - }) - } - - /// Decodes the metadata from the compressed bytes. 
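`CodecDataSerde` flattens each variable-length field with a simple framing scheme: write the payload length as little-endian `usize` bytes, then the payload, and read everything back by walking a cursor, which is what the `decode_*` helpers below do via `read_encoding` and `read_number`. A minimal standalone version of that framing, leaving out the gzip pass the real code layers on top:

const W: usize = std::mem::size_of::<usize>();

fn write_framed(buf: &mut Vec<u8>, payload: &[u8]) {
    // Length prefix, then payload.
    buf.extend_from_slice(&payload.len().to_le_bytes());
    buf.extend_from_slice(payload);
}

fn read_framed(bytes: &[u8], cursor: &mut usize) -> Vec<u8> {
    let mut len_bytes = [0u8; W];
    len_bytes.copy_from_slice(&bytes[*cursor..*cursor + W]);
    let len = usize::from_le_bytes(len_bytes);
    *cursor += W;
    let payload = bytes[*cursor..*cursor + len].to_vec();
    *cursor += len;
    payload
}

fn main() {
    let mut buf = Vec::new();
    write_framed(&mut buf, b"hello");
    write_framed(&mut buf, b"world!");
    let mut cursor = 0;
    assert_eq!(read_framed(&buf, &mut cursor), b"hello");
    assert_eq!(read_framed(&buf, &mut cursor), b"world!");
    assert_eq!(cursor, buf.len());
}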
- fn decode_metadata(bytes: &[u8]) -> Vec { - let mut metadata = Vec::new(); - let mut offset = 0; - while offset < bytes.len() { - let encoding = crate::utils::read_encoding(bytes, &mut offset); - metadata.push(M::from_bytes(&encoding)); - } - metadata - } - - /// Decodes the permutation from the compressed bytes. - fn decode_permutation(bytes: &[u8]) -> Vec { - bytes - .chunks_exact(core::mem::size_of::()) - .map(|chunk| { - let mut array = [0; std::mem::size_of::()]; - array.copy_from_slice(&chunk[..std::mem::size_of::()]); - usize::from_le_bytes(array) - }) - .collect::>() - } - - /// Decodes the center map from the compressed bytes. - fn decode_center_map(bytes: &[u8]) -> HashMap { - let mut center_map = HashMap::new(); - let mut offset = 0; - while offset < bytes.len() { - let i = crate::utils::read_number::(bytes, &mut offset); - let encoding = crate::utils::read_encoding(bytes, &mut offset); - center_map.insert(i, I::from_bytes(&encoding)); - } - center_map - } - - /// Decodes the leaf bytes from the compressed bytes. - fn decode_leaf_bytes(bytes: &[u8]) -> Vec<(usize, Box<[u8]>)> { - let mut leaf_bytes = Vec::new(); - let mut offset = 0; - while offset < bytes.len() { - let i = crate::utils::read_number::(bytes, &mut offset); - let encoding = crate::utils::read_encoding(bytes, &mut offset); - leaf_bytes.push((i, encoding)); - } - leaf_bytes - } -} - -impl Serialize for CodecData { - fn serialize(&self, serializer: S) -> Result { - CodecDataSerde::from_codec_data(self) - .map_err(serde::ser::Error::custom)? - .serialize(serializer) - } -} - -impl<'de, I: Decodable + Send + Sync, U, M: Decodable + Send + Sync> Deserialize<'de> for CodecData { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - CodecDataSerde::deserialize(deserializer) - .and_then(|serde| Self::from_serde(serde).map_err(serde::de::Error::custom)) - } -} diff --git a/crates/abd-clam/src/cakes/codec/mod.rs b/crates/abd-clam/src/cakes/codec/mod.rs deleted file mode 100644 index 00df1243e..000000000 --- a/crates/abd-clam/src/cakes/codec/mod.rs +++ /dev/null @@ -1,265 +0,0 @@ -//! 
Compression and Decompression with CLAM - -mod codec_data; -mod compression; -mod decompression; -mod squishy_ball; - -use distances::{number::Float, strings::needleman_wunsch, Number}; - -#[allow(clippy::module_name_repetitions)] -pub use codec_data::CodecData; -pub use compression::{Compressible, Encodable, ParCompressible}; -pub use decompression::{Decodable, Decompressible, ParDecompressible}; -pub use squishy_ball::SquishyBall; - -use crate::{cluster::ParCluster, Cluster, FlatVec}; - -impl Compressible for FlatVec {} -impl ParCompressible for FlatVec {} - -impl, Dec: Decompressible, S: Cluster> - super::cluster::Searchable for SquishyBall -{ -} - -impl< - I: Encodable + Decodable + Send + Sync, - U: Number, - Co: ParCompressible, - Dec: ParDecompressible, - S: ParCluster + core::fmt::Debug, - > super::cluster::ParSearchable for SquishyBall -{ -} - -impl Encodable for usize { - fn as_bytes(&self) -> Box<[u8]> { - self.to_le_bytes().to_vec().into_boxed_slice() - } - - fn encode(&self, _: &Self) -> Box<[u8]> { - self.as_bytes() - } -} - -impl Decodable for usize { - fn from_bytes(bytes: &[u8]) -> Self { - let mut array = [0; std::mem::size_of::()]; - array.copy_from_slice(&bytes[..std::mem::size_of::()]); - Self::from_le_bytes(array) - } - - fn decode(_: &Self, bytes: &[u8]) -> Self { - Self::from_bytes(bytes) - } -} - -impl Encodable for Vec { - fn as_bytes(&self) -> Box<[u8]> { - self.iter() - .flat_map(|v| v.to_le_bytes()) - .collect::>() - .into_boxed_slice() - } - - fn encode(&self, reference: &Self) -> Box<[u8]> { - let diffs = reference.iter().zip(self.iter()).map(|(&a, &b)| a - b).collect(); - Self::as_bytes(&diffs) - } -} - -impl Decodable for Vec { - fn from_bytes(bytes: &[u8]) -> Self { - bytes - .chunks_exact(std::mem::size_of::()) - .map(F::from_le_bytes) - .collect() - } - - fn decode(reference: &Self, bytes: &[u8]) -> Self { - let diffs = Self::from_bytes(bytes); - reference.iter().zip(diffs).map(|(&a, b)| a - b).collect() - } -} - -/// This uses the Needleman-Wunsch algorithm to encode strings. -impl Encodable for String { - fn as_bytes(&self) -> Box<[u8]> { - self.as_bytes().to_vec().into_boxed_slice() - } - - fn encode(&self, reference: &Self) -> Box<[u8]> { - let penalties = distances::strings::Penalties::default(); - let table = needleman_wunsch::compute_table::(reference, self, penalties); - let (aligned_ref, aligned_tar) = needleman_wunsch::trace_back_recursive(&table, [reference, self]); - let edits = needleman_wunsch::unaligned_x_to_y(&aligned_ref, &aligned_tar); - serialize_edits(&edits) - } -} - -/// This uses the Needleman-Wunsch algorithm to decode strings. -impl Decodable for String { - fn from_bytes(bytes: &[u8]) -> Self { - Self::from_utf8(bytes.to_vec()).unwrap_or_else(|e| unreachable!("Could not cast back to string: {e:?}")) - } - - fn decode(reference: &Self, bytes: &[u8]) -> Self { - let edits = deserialize_edits(bytes); - needleman_wunsch::apply_edits(reference, &edits) - } -} - -/// Serializes a vector of edit operations into a byte array. -fn serialize_edits(edits: &[needleman_wunsch::Edit]) -> Box<[u8]> { - let bytes = edits.iter().flat_map(edit_to_bin).collect::>(); - bytes.into_boxed_slice() -} - -/// Encodes an edit operation into a byte array. -/// -/// A `Del` edit is encoded as `10` followed by the index of the edit in 14 bits. -/// A `Ins` edit is encoded as `01` followed by the index of the edit in 14 bits and the character in 8 bits. 
-/// A `Sub` edit is encoded as `11` followed by the index of the edit in 14 bits and the character in 8 bits. -/// -/// # Arguments -/// -/// * `edit`: The edit operation. -/// -/// # Returns -/// -/// A byte array encoding the edit operation. -#[allow(clippy::cast_possible_truncation)] -fn edit_to_bin(edit: &needleman_wunsch::Edit) -> Vec { - let mask_idx = 0b00_111111; - let mask_del = 0b10_000000; - let mask_ins = 0b01_000000; - let mask_sub = 0b11_000000; - - match edit { - needleman_wunsch::Edit::Del(i) => { - let mut bytes = (*i as u16).to_be_bytes().to_vec(); - // First 2 bits for the type of edit, 14 bits for the index. - bytes[0] &= mask_idx; - bytes[0] |= mask_del; - bytes - } - needleman_wunsch::Edit::Ins(i, c) => { - let mut bytes = (*i as u16).to_be_bytes().to_vec(); - // First 2 bits for the type of edit, 14 bits for the index. - bytes[0] &= mask_idx; - bytes[0] |= mask_ins; - // 8 bits for the character. - bytes.push(*c as u8); - bytes - } - needleman_wunsch::Edit::Sub(i, c) => { - let mut bytes = (*i as u16).to_be_bytes().to_vec(); - // First 2 bits for the type of edit, 14 bits for the index. - bytes[0] &= mask_idx; - bytes[0] |= mask_sub; - // 8 bits for the character. - bytes.push(*c as u8); - bytes - } - } -} - -/// Deserializes a byte array into a vector of edit operations. -/// -/// A `Del` edit is encoded as `10` followed by the index of the edit in 14 bits. -/// A `Ins` edit is encoded as `01` followed by the index of the edit in 14 bits and the character in 8 bits. -/// A `Sub` edit is encoded as `11` followed by the index of the edit in 14 bits and the character in 8 bits. -/// -/// # Arguments -/// -/// * `bytes`: The byte array encoding the edit operations. -/// -/// # Errors -/// -/// * If the byte array is not a valid encoding of edit operations. -/// * If the edit type is not recognized. -/// -/// # Returns -/// -/// A vector of edit operations. -fn deserialize_edits(bytes: &[u8]) -> Vec { - let mut edits = Vec::new(); - let mut offset = 0; - let mask_idx = 0b00_111111; - - while offset < bytes.len() { - let edit_bits = bytes[offset] & !mask_idx; - let i = u16::from_be_bytes([bytes[offset] & mask_idx, bytes[offset + 1]]) as usize; - let edit = match edit_bits { - 0b10_000000 => { - offset += 2; - needleman_wunsch::Edit::Del(i) - } - 0b01_000000 => { - let c = bytes[offset + 2] as char; - offset += 3; - needleman_wunsch::Edit::Ins(i, c) - } - 0b11_000000 => { - let c = bytes[offset + 2] as char; - offset += 3; - needleman_wunsch::Edit::Sub(i, c) - } - _ => unreachable!("Invalid edit type: {edit_bits:b}."), - }; - edits.push(edit); - } - edits -} - -#[cfg(test)] -pub mod tests { - use crate::{adapter::BallAdapter, cakes::CodecData, Ball, Cluster, FlatVec, MetricSpace, Partition}; - - use super::SquishyBall; - - use crate::cakes::tests::gen_random_data; - - #[test] - fn ser_de() -> Result<(), String> { - // The instances. - type I = Vec; - // The distance values. - type U = f64; - // The compressible dataset - type Co = FlatVec; - // The ball for the compressible dataset. 
- type B = Ball; - // The decompressible dataset - type Dec = CodecData; - // The squishy ball - type Sb = SquishyBall; - - let seed = 42; - let car = 1_000; - let dim = 10; - - let data: Co = gen_random_data(car, dim, 10.0, seed)?; - let metric = data.metric().clone(); - let metadata = data.metadata().to_vec(); - - let criteria = |c: &B| c.cardinality() > 1; - let ball = B::new_tree(&data, &criteria, Some(seed)); - let (_, co_data) = Sb::from_ball_tree(ball, data); - let co_data = co_data.with_metadata(metadata.clone())?; - - let serialized = bincode::serialize(&co_data).unwrap(); - let mut deserialized = bincode::deserialize::(&serialized).unwrap(); - deserialized.set_metric(metric.clone()); - - assert_eq!(co_data.cardinality, deserialized.cardinality); - assert_eq!(co_data.dimensionality_hint, deserialized.dimensionality_hint); - assert_eq!(co_data.metadata, deserialized.metadata); - assert_eq!(co_data.permutation, deserialized.permutation); - assert_eq!(co_data.center_map, deserialized.center_map); - assert_eq!(co_data.leaf_bytes, deserialized.leaf_bytes); - - Ok(()) - } -} diff --git a/crates/abd-clam/src/cakes/codec/squishy_ball.rs b/crates/abd-clam/src/cakes/codec/squishy_ball.rs deleted file mode 100644 index 801625e0f..000000000 --- a/crates/abd-clam/src/cakes/codec/squishy_ball.rs +++ /dev/null @@ -1,490 +0,0 @@ -//! An adaptation of `Ball` that allows for compression of the dataset and -//! search in the compressed space. - -use core::fmt::Debug; - -use std::marker::PhantomData; - -use distances::Number; -use rayon::prelude::*; -use serde::{Deserialize, Serialize}; - -use crate::{ - adapter::{Adapter, BallAdapter, ParAdapter, ParBallAdapter, ParParams, Params}, - cakes::OffBall, - cluster::ParCluster, - dataset::{metric_space::ParMetricSpace, ParDataset}, - Ball, Cluster, Dataset, MetricSpace, Permutable, -}; - -use super::{ - compression::ParCompressible, decompression::ParDecompressible, CodecData, Compressible, Decodable, Decompressible, - Encodable, -}; - -/// A variant of `Ball` that stores indices after reordering the dataset. -#[derive(Clone, Serialize, Deserialize)] -pub struct SquishyBall< - I: Encodable + Decodable, - U: Number, - Co: Compressible, - Dec: Decompressible, - S: Cluster, -> { - /// The `Cluster` type that the `OffsetBall` is based on. - source: OffBall, - /// The children of the `Cluster`. - children: Vec<(usize, U, Box)>, - /// Parameters for the `OffsetBall`. - costs: SquishCosts, - /// Phantom data to satisfy the compiler. - _dc: PhantomData, -} - -impl< - I: Encodable + Decodable, - U: Number, - Co: Compressible, - Dec: Decompressible, - S: Cluster + Debug, - > Debug for SquishyBall -{ - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("SquishyBall") - .field("source", &self.source) - .field("children", &!self.children.is_empty()) - .field("recursive_cost", &self.costs.recursive) - .field("unitary_cost", &self.costs.unitary) - .field("minimum_cost", &self.costs.minimum) - .finish() - } -} - -impl, S: Cluster, M> - SquishyBall, S> -{ - /// Allows for the `SquishyBall` to be use with the same compressed dataset under different metadata types. 
- pub fn with_metadata_type(self) -> SquishyBall, S> { - SquishyBall { - source: self.source, - children: self - .children - .into_iter() - .map(|(i, r, c)| (i, r, Box::new(c.with_metadata_type()))) - .collect(), - costs: self.costs, - _dc: PhantomData, - } - } -} - -impl + Permutable> - BallAdapter, SquishCosts> - for SquishyBall, Ball> -{ - fn from_ball_tree(ball: Ball, data: D) -> (Self, CodecData) { - let (off_ball, data) = OffBall::from_ball_tree(ball, data); - let mut root = Self::adapt_tree_iterative(off_ball, None); - root.set_costs(&data); - root.trim(4); - let data = CodecData::from_compressible(&data, &root); - (root, data) - } -} - -impl + Permutable> - ParBallAdapter, SquishCosts> - for SquishyBall, Ball> -{ - fn par_from_ball_tree(ball: Ball, data: D) -> (Self, CodecData) { - let (off_ball, data) = OffBall::par_from_ball_tree(ball, data); - let mut root = Self::par_adapt_tree_iterative(off_ball, None); - root.par_set_costs(&data); - root.trim(4); - let data = CodecData::par_from_compressible(&data, &root); - (root, data) - } -} - -impl, Dec: Decompressible, S: Cluster> - SquishyBall -{ - /// Get the unitary cost of the `SquishyBall`. - pub const fn unitary_cost(&self) -> U { - self.costs.unitary - } - - /// Get the recursive cost of the `SquishyBall`. - pub const fn recursive_cost(&self) -> U { - self.costs.recursive - } - - /// Gets the offset of the cluster's indices in its dataset. - pub const fn offset(&self) -> usize { - self.source.offset() - } - - /// Trims the tree by removing empty children of clusters whose unitary cost - /// is greater than the recursive cost. - pub fn trim(&mut self, min_depth: usize) { - if !self.children.is_empty() { - if (self.costs.unitary <= self.costs.recursive) && (self.depth() >= min_depth) { - self.children.clear(); - } else { - self.children.iter_mut().for_each(|(_, _, c)| c.trim(min_depth)); - } - } - } - - /// Sets the costs for the tree. - pub fn set_costs(&mut self, data: &Co) { - self.set_unitary_cost(data); - if self.children.is_empty() { - self.costs.recursive = U::ZERO; - } else { - self.children.iter_mut().for_each(|(_, _, c)| c.set_costs(data)); - self.set_recursive_cost(data); - } - self.set_min_cost(); - } - - /// Calculates the unitary cost of the `Cluster`. - fn set_unitary_cost(&mut self, data: &Co) { - self.costs.unitary = Dataset::one_to_many(data, self.arg_center(), &self.indices().collect::>()) - .into_iter() - .map(|(_, d)| d) - .sum(); - } - - /// Calculates the recursive cost of the `Cluster`. - fn set_recursive_cost(&mut self, data: &Co) { - if self.children.is_empty() { - self.costs.recursive = U::ZERO; - } else { - let children = self.children.iter().map(|(_, _, c)| c.as_ref()).collect::>(); - let child_costs = children.iter().map(|c| c.costs.minimum).sum::(); - let child_centers = children.iter().map(|c| c.arg_center()).collect::>(); - let distances = Dataset::one_to_many(data, self.arg_center(), &child_centers); - - self.costs.recursive = child_costs + distances.into_iter().map(|(_, d)| d).sum::(); - } - } - - /// Sets the minimum cost of the `Cluster`. - fn set_min_cost(&mut self) { - self.costs.minimum = if self.costs.recursive < self.costs.unitary { - self.costs.recursive - } else { - self.costs.unitary - }; - } -} - -impl< - I: Encodable + Decodable + Send + Sync, - U: Number, - Co: ParCompressible, - Dec: ParDecompressible, - S: ParCluster, - > SquishyBall -{ - /// Sets the costs for the tree. 
- pub fn par_set_costs(&mut self, data: &Co) { - self.par_set_unitary_cost(data); - if self.children.is_empty() { - self.costs.recursive = U::ZERO; - } else { - self.children.par_iter_mut().for_each(|(_, _, c)| c.par_set_costs(data)); - self.par_set_recursive_cost(data); - } - self.set_min_cost(); - } - - /// Calculates the unitary cost of the `Cluster`. - fn par_set_unitary_cost(&mut self, data: &Co) { - self.costs.unitary = ParDataset::par_one_to_many(data, self.arg_center(), &self.indices().collect::>()) - .into_iter() - .map(|(_, d)| d) - .sum(); - } - - /// Calculates the recursive cost of the `Cluster`. - fn par_set_recursive_cost(&mut self, data: &Co) { - if self.children.is_empty() { - self.costs.recursive = U::ZERO; - } else { - let children = self.children.iter().map(|(_, _, c)| c.as_ref()).collect::>(); - let child_costs = children.iter().map(|c| c.costs.minimum).sum::(); - let child_centers = children.iter().map(|c| c.arg_center()).collect::>(); - let distances = ParDataset::par_one_to_many(data, self.arg_center(), &child_centers); - - self.costs.recursive = child_costs + distances.into_iter().map(|(_, d)| d).sum::(); - } - } -} - -impl, Dec: Decompressible, S: Cluster> - Adapter, SquishCosts> for SquishyBall -{ - fn new_adapted(source: OffBall, children: Vec<(usize, U, Box)>, params: SquishCosts) -> Self { - Self { - source, - children, - costs: params, - _dc: PhantomData, - } - } - - fn post_traversal(&mut self) {} - - fn source(&self) -> &OffBall { - &self.source - } - - fn source_mut(&mut self) -> &mut OffBall { - &mut self.source - } - - fn take_source(self) -> OffBall { - self.source - } - - fn params(&self) -> &SquishCosts { - &self.costs - } -} - -/// Parameters for the `OffsetBall`. -#[derive(Debug, Default, Copy, Clone, Serialize, Deserialize)] -pub struct SquishCosts { - /// Expected memory cost of recursive compression. - recursive: U, - /// Expected memory cost of unitary compression. - unitary: U, - /// The minimum expected memory cost of compression. 
- minimum: U, -} - -impl, Dec: Decompressible, S: Cluster> - Params for SquishCosts -{ - fn child_params(&self, children: &[S]) -> Vec { - children.iter().map(|_| Self::default()).collect() - } -} - -impl< - I: Encodable + Decodable + Send + Sync, - U: Number, - Co: ParCompressible, - Dec: ParDecompressible, - S: ParCluster + Debug, - > ParAdapter, SquishCosts> for SquishyBall -{ - fn par_post_traversal(&mut self) {} -} - -impl< - I: Encodable + Decodable + Send + Sync, - U: Number, - Co: ParCompressible, - Dec: ParDecompressible, - S: ParCluster, - > ParParams for SquishCosts -{ - fn par_child_params(&self, children: &[S]) -> Vec { - Params::::child_params(self, children) - } -} - -impl, Dec: Decompressible, S: Cluster> - Cluster for SquishyBall -{ - fn depth(&self) -> usize { - self.source.depth() - } - - fn cardinality(&self) -> usize { - self.source.cardinality() - } - - fn arg_center(&self) -> usize { - self.source.arg_center() - } - - fn set_arg_center(&mut self, arg_center: usize) { - self.source.set_arg_center(arg_center); - } - - fn radius(&self) -> U { - self.source.radius() - } - - fn arg_radial(&self) -> usize { - self.source.arg_radial() - } - - fn set_arg_radial(&mut self, arg_radial: usize) { - self.source.set_arg_radial(arg_radial); - } - - fn lfd(&self) -> f32 { - self.source.lfd() - } - - fn indices(&self) -> impl Iterator + '_ { - self.source.indices() - } - - fn set_indices(&mut self, indices: Vec) { - self.source.set_indices(indices); - } - - fn children(&self) -> &[(usize, U, Box)] { - &self.children - } - - fn children_mut(&mut self) -> &mut [(usize, U, Box)] { - &mut self.children - } - - fn set_children(&mut self, children: Vec<(usize, U, Box)>) { - self.children = children; - } - - fn take_children(&mut self) -> Vec<(usize, U, Box)> { - std::mem::take(&mut self.children) - } - - fn distances_to_query(&self, data: &Dec, query: &I) -> Vec<(usize, U)> { - let leaf_bytes = data.leaf_bytes(); - - let instances = - self.leaves() - .into_iter() - .map(Self::offset) - .map(|o| { - leaf_bytes.iter().position(|(off, _)| *off == o).unwrap_or_else(|| { - unreachable!("Offset not found in leaf offsets: {}, {:?}", o, data.leaf_bytes()) - }) - }) - .map(|pos| &leaf_bytes[pos]) - .flat_map(|(o, bytes)| { - data.decode_leaf(bytes) - .into_iter() - .enumerate() - .map(|(i, p)| (i + *o, p)) - }) - .collect::>(); - - let instances = instances.iter().map(|(i, p)| (*i, p)).collect::>(); - MetricSpace::one_to_many(data, query, &instances) - } - - fn is_descendant_of(&self, other: &Self) -> bool { - self.source.is_descendant_of(&other.source) - } -} - -impl< - I: Encodable + Decodable + Send + Sync, - U: Number, - Co: ParCompressible, - Dec: ParDecompressible, - S: ParCluster + Debug, - > ParCluster for SquishyBall -{ - fn par_distances_to_query(&self, data: &Dec, query: &I) -> Vec<(usize, U)> { - let leaf_bytes = data.leaf_bytes(); - - let instances = - self.leaves() - .into_par_iter() - .map(Self::offset) - .map(|o| { - leaf_bytes.iter().position(|(off, _)| *off == o).unwrap_or_else(|| { - unreachable!("Offset not found in leaf offsets: {}, {:?}", o, data.leaf_bytes()) - }) - }) - .map(|pos| &leaf_bytes[pos]) - .flat_map(|(o, bytes)| { - data.decode_leaf(bytes) - .into_par_iter() - .enumerate() - .map(|(i, p)| (i + *o, p)) - }) - .collect::>(); - - let instances = instances.iter().map(|(i, p)| (*i, p)).collect::>(); - ParMetricSpace::par_one_to_many(data, query, &instances) - } -} - -impl, Dec: Decompressible, S: Cluster> - PartialEq for SquishyBall -{ - fn eq(&self, other: &Self) -> 
bool { - self.source == other.source - } -} - -impl, Dec: Decompressible, S: Cluster> Eq - for SquishyBall -{ -} - -impl, Dec: Decompressible, S: Cluster> - PartialOrd for SquishyBall -{ - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl, Dec: Decompressible, S: Cluster> Ord - for SquishyBall -{ - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - self.source.cmp(&other.source) - } -} - -impl, Dec: Decompressible, S: Cluster> - std::hash::Hash for SquishyBall -{ - fn hash(&self, state: &mut H) { - self.source.hash(state); - } -} - -#[cfg(feature = "csv")] -impl< - I: Encodable + Decodable, - U: Number, - Co: Compressible, - Dec: Decompressible, - S: crate::cluster::WriteCsv, - > crate::cluster::WriteCsv for SquishyBall -{ - fn header(&self) -> Vec { - let mut header = self.source.header(); - header.extend(vec![ - "recursive_cost".to_string(), - "unitary_cost".to_string(), - "minimum_cost".to_string(), - ]); - header - } - - fn row(&self) -> Vec { - let mut row = self.source.row(); - row.pop(); - row.extend(vec![ - self.children.is_empty().to_string(), - self.costs.recursive.to_string(), - self.costs.unitary.to_string(), - self.costs.minimum.to_string(), - ]); - row - } -} diff --git a/crates/abd-clam/src/cakes/dataset/hinted.rs b/crates/abd-clam/src/cakes/dataset/hinted.rs new file mode 100644 index 000000000..336fd13a4 --- /dev/null +++ b/crates/abd-clam/src/cakes/dataset/hinted.rs @@ -0,0 +1,127 @@ +//! `Dataset`s which store extra information to improve search performance. + +use std::collections::HashMap; + +use distances::Number; +use rayon::prelude::*; + +use crate::{ + cakes::{KnnDepthFirst, ParSearchAlgorithm, RnnClustered, SearchAlgorithm}, + cluster::ParCluster, + metric::ParMetric, + Cluster, Metric, +}; + +use super::{ParSearchable, Searchable}; + +/// An extension of the `Dataset` trait to support hinting for search. +/// +/// Each hint is a tuple of the number of known neighbors and the distance to +/// the farthest known neighbor. Each item may have multiple hints, stored in a +/// map where the key is the number of known neighbors and the value is the +/// distance to the farthest known neighbor. +#[allow(clippy::module_name_repetitions)] +pub trait HintedDataset, M: Metric>: Searchable + Sized { + /// Get the search hints for a specific item by index. + fn hints_for(&self, i: usize) -> &HashMap; + + /// Get the search hints for a specific item by index as mutable. + fn hints_for_mut(&mut self, i: usize) -> &mut HashMap; + + /// Deletes the hints for the indexed item. + fn clear_hints_for(&mut self, i: usize) { + self.hints_for_mut(i).clear(); + } + + /// Deletes all hints for all items. + fn clear_all_hints(&mut self) { + (0..self.cardinality()).for_each(|i| self.clear_hints_for(i)); + } + + /// Add a hint for the indexed item. + #[must_use] + fn with_hint_for(mut self, i: usize, k: usize, d: T) -> Self { + self.hints_for_mut(i).insert(k, d); + self + } + + /// Add hints from a tree. + /// + /// For each item in the tree, the number of known neighbors and the + /// distance to the farthest known neighbor are added as hints. + #[must_use] + fn with_hints_from_tree(self, root: &C, _: &M) -> Self { + root.subtree() + .into_iter() + .filter(|c| c.radius() > T::ZERO) + .map(|c| (c.arg_center(), c.cardinality(), c.radius())) + .fold(self, |data, (i, k, d)| data.with_hint_for(i, k, d)) + } + + /// Add hints using a search algorithm. + /// + /// # Arguments + /// + /// * `metric` - The metric to use for the search. 
+    /// * `root` - The root of the search tree.
+    /// * `alg` - The search algorithm to use.
+    #[must_use]
+    fn with_hints_by_search<A: SearchAlgorithm<I, T, C, M, Self>>(self, metric: &M, root: &C, alg: A) -> Self {
+        (0..self.cardinality())
+            .flat_map(|i| {
+                let mut hits = alg
+                    .search(&self, metric, root, self.get(i))
+                    .into_iter()
+                    .filter(|&(_, d)| d > T::ZERO)
+                    .collect::<Vec<_>>();
+                hits.sort_by(|(_, a), (_, b)| a.total_cmp(b));
+                hits.into_iter().enumerate().map(move |(k, (_, d))| (i, k, d))
+            })
+            .collect::<Vec<_>>()
+            .into_iter()
+            .fold(self, |data, (i, k, d)| data.with_hint_for(i, k, d))
+    }
+
+    /// Add hints from a tree and several search algorithms.
+    #[must_use]
+    fn with_hints_from(self, metric: &M, root: &C, radius: T, k: usize) -> Self {
+        self.with_hints_from_tree(root, metric)
+            .with_hints_by_search(metric, root, RnnClustered(radius))
+            .with_hints_by_search(metric, root, KnnDepthFirst(k))
+    }
+}
+
+/// Parallel version of [`HintedDataset`](crate::cakes::dataset::hinted::HintedDataset).
+#[allow(clippy::module_name_repetitions)]
+pub trait ParHintedDataset<I: Send + Sync, T: Number, C: ParCluster<T>, M: ParMetric<I, T>>:
+    HintedDataset<I, T, C, M> + ParSearchable<I, T, C, M>
+{
+    /// Parallel version of [`HintedDataset::with_hints_by_search`](crate::cakes::dataset::hinted::HintedDataset::with_hints_by_search).
+    #[must_use]
+    fn par_with_hints_by_search<A: ParSearchAlgorithm<I, T, C, M, Self>>(self, metric: &M, root: &C, alg: A) -> Self {
+        (0..self.cardinality())
+            .into_par_iter()
+            .flat_map(|i| {
+                let mut hits = alg
+                    .par_search(&self, metric, root, self.get(i))
+                    .into_par_iter()
+                    .filter(|&(_, d)| d > T::ZERO)
+                    .collect::<Vec<_>>();
+                hits.sort_by(|(_, a), (_, b)| a.total_cmp(b));
+                hits.into_par_iter().enumerate().map(move |(k, (_, d))| (i, k, d))
+            })
+            .collect::<Vec<_>>()
+            .into_iter()
+            .fold(self, |data, (i, k, d)| data.with_hint_for(i, k, d))
+    }
+
+    /// Parallel version of [`HintedDataset::with_hints_from`](crate::cakes::dataset::hinted::HintedDataset::with_hints_from).
+    #[must_use]
+    fn par_with_hints_from(self, metric: &M, root: &C, radius: T, k: usize) -> Self {
+        self.with_hints_from_tree(root, metric)
+            .par_with_hints_by_search(metric, root, RnnClustered(radius))
+            .par_with_hints_by_search(metric, root, KnnDepthFirst(k))
+    }
+}
diff --git a/crates/abd-clam/src/cakes/dataset/mod.rs b/crates/abd-clam/src/cakes/dataset/mod.rs
index b788296ad..277589d2e 100644
--- a/crates/abd-clam/src/cakes/dataset/mod.rs
+++ b/crates/abd-clam/src/cakes/dataset/mod.rs
@@ -1,14 +1,67 @@
-//! Extension traits of `Dataset` for specific search applications.
+//! Dataset extensions for search.
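[Editor's note: illustrative sketch, not part of the patch. The `HintedDataset` trait above stores, per item, a map from a neighbor count `k` to the distance of the farthest of those `k` known neighbors; hints come either from tree clusters (a cluster center has `cardinality` neighbors within `radius`) or from prior searches. A minimal, std-only toy of that layout, with all names hypothetical:

    use std::collections::HashMap;

    /// One item's hints: k known neighbors -> distance to the farthest of them.
    type Hints = HashMap<usize, f32>;

    fn main() {
        let mut hints: Vec<Hints> = vec![HashMap::new(); 3];
        // From the tree: item 0 is the center of a cluster with cardinality 10
        // and radius 2.5, so it has at least 10 neighbors within distance 2.5.
        hints[0].insert(10, 2.5);
        // From an earlier 5-NN search whose farthest hit was at distance 1.2.
        hints[0].insert(5, 1.2);
        assert_eq!(hints[0][&5], 1.2);
    }

End of editor's note.]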
-mod searchable; -mod shardable; +use std::collections::HashMap; use distances::Number; +use rayon::prelude::*; + +use crate::{cluster::ParCluster, metric::ParMetric, Cluster, Dataset, FlatVec, Metric}; + +mod hinted; +mod searchable; + +#[allow(clippy::module_name_repetitions)] +pub use hinted::{HintedDataset, ParHintedDataset}; pub use searchable::{ParSearchable, Searchable}; -pub use shardable::Shardable; -use crate::{cluster::ParCluster, Cluster, FlatVec}; +impl, M: Metric, Me> Searchable for FlatVec { + fn query_to_center(&self, metric: &M, query: &I, cluster: &C) -> T { + metric.distance(query, self.get(cluster.arg_center())) + } + + #[inline(never)] + fn query_to_all(&self, metric: &M, query: &I, cluster: &C) -> impl Iterator { + cluster + .indices() + .into_iter() + .map(|i| (i, metric.distance(query, self.get(i)))) + } +} + +impl, M: ParMetric, Me: Send + Sync> ParSearchable + for FlatVec +{ + fn par_query_to_center(&self, metric: &M, query: &I, cluster: &C) -> T { + metric.par_distance(query, self.get(cluster.arg_center())) + } + + fn par_query_to_all( + &self, + metric: &M, + query: &I, + cluster: &C, + ) -> impl rayon::prelude::ParallelIterator { + cluster + .par_indices() + .map(|i| (i, metric.par_distance(query, self.get(i)))) + } +} + +#[allow(clippy::implicit_hasher)] +impl, M: Metric, Me> HintedDataset + for FlatVec)> +{ + fn hints_for(&self, i: usize) -> &HashMap { + &self.metadata[i].1 + } -impl, M> Searchable for FlatVec {} + fn hints_for_mut(&mut self, i: usize) -> &mut HashMap { + &mut self.metadata[i].1 + } +} -impl, M: Send + Sync> ParSearchable for FlatVec {} +#[allow(clippy::implicit_hasher)] +impl, M: ParMetric, Me: Send + Sync> ParHintedDataset + for FlatVec)> +{ +} diff --git a/crates/abd-clam/src/cakes/dataset/searchable.rs b/crates/abd-clam/src/cakes/dataset/searchable.rs index 022352130..d5f81da25 100644 --- a/crates/abd-clam/src/cakes/dataset/searchable.rs +++ b/crates/abd-clam/src/cakes/dataset/searchable.rs @@ -1,49 +1,27 @@ -//! Searchable dataset. +//! An extension of `Dataset` that supports search operations. use distances::Number; -use rayon::prelude::*; +use rayon::iter::ParallelIterator; -use crate::{cakes::Algorithm, cluster::ParCluster, dataset::ParDataset, Cluster, Dataset}; +use crate::{cluster::ParCluster, dataset::ParDataset, metric::ParMetric, Cluster, Dataset, Metric}; -/// A dataset that can be searched with entropy-scaling algorithms. -pub trait Searchable>: Dataset + Sized { - /// Searches the dataset for the `query` instance and returns the - /// indices of and distances to the nearest neighbors. - fn search(&self, root: &C, query: &I, alg: Algorithm) -> Vec<(usize, U)> { - alg.search(self, root, query) - } +/// A dataset that supports search operations. +pub trait Searchable, M: Metric>: Dataset { + /// Returns the distance from a query to the center of the given cluster. + fn query_to_center(&self, metric: &M, query: &I, cluster: &C) -> T; - /// Batch version of the `search` method, to search for multiple queries. - fn batch_search(&self, root: &C, queries: &[I], alg: Algorithm) -> Vec> { - queries.iter().map(|query| self.search(root, query, alg)).collect() - } + /// Returns the distances from a query to all items in the given cluster. + fn query_to_all(&self, metric: &M, query: &I, cluster: &C) -> impl Iterator; } -/// Parallel version of the `Searchable` trait. +/// A parallel version of [`Searchable`](crate::cakes::dataset::searchable::Searchable). 
#[allow(clippy::module_name_repetitions)] -pub trait ParSearchable>: - Searchable + ParDataset +pub trait ParSearchable, M: ParMetric>: + ParDataset + Searchable { - /// Parallel version of the `search` method. - fn par_search(&self, root: &C, query: &I, alg: Algorithm) -> Vec<(usize, U)> { - alg.par_search(self, root, query) - } + /// Parallel version of [`Searchable::query_to_center`](crate::cakes::dataset::searchable::Searchable::query_to_center). + fn par_query_to_center(&self, metric: &M, query: &I, cluster: &C) -> T; - /// Batch version of the `par_search` method. - fn batch_par_search(&self, root: &C, queries: &[I], alg: Algorithm) -> Vec> { - queries.iter().map(|query| self.par_search(root, query, alg)).collect() - } - - /// Parallel version of the `batch_search` method. - fn par_batch_search(&self, root: &C, queries: &[I], alg: Algorithm) -> Vec> { - queries.par_iter().map(|query| self.search(root, query, alg)).collect() - } - - /// Parallel version of the `batch_par_search` method. - fn par_batch_par_search(&self, root: &C, queries: &[I], alg: Algorithm) -> Vec> { - queries - .par_iter() - .map(|query| self.par_search(root, query, alg)) - .collect() - } + /// Parallel version of [`Searchable::query_to_all`](crate::cakes::dataset::searchable::Searchable::query_to_all). + fn par_query_to_all(&self, metric: &M, query: &I, cluster: &C) -> impl ParallelIterator; } diff --git a/crates/abd-clam/src/cakes/dataset/shardable.rs b/crates/abd-clam/src/cakes/dataset/shardable.rs deleted file mode 100644 index 62adc91b5..000000000 --- a/crates/abd-clam/src/cakes/dataset/shardable.rs +++ /dev/null @@ -1,103 +0,0 @@ -//! A `Dataset` that can be sharded into multiple smaller datasets. - -use distances::Number; - -use crate::core::{Dataset, FlatVec}; - -/// A dataset that can be sharded into multiple smaller datasets. -pub trait Shardable: Dataset { - /// Sets the permutation of the dataset to the identity permutation. - /// - /// This should not change the order of the instances. - #[must_use] - fn reset_permutation(self) -> Self - where - Self: Sized; - - /// Split the dataset into two smaller datasets, at the given index. - /// - /// If the `Dataset` is `Permutable`, the permutation should be ignored. - /// - /// # Arguments - /// - /// * `at` - The index at which to split the dataset. - /// - /// # Returns - /// - /// - The dataset containing instances in the range `[0, at)`. - /// - The dataset containing instances in the range `[at, cardinality)`. - fn split_off(self, at: usize) -> [Self; 2] - where - Self: Sized; - - /// Shard the dataset into a number of smaller datasets. - /// - /// This will erase the permutation of the dataset. - /// - /// # Arguments - /// - /// * `at` - The indices at which to shard the dataset. - /// - /// # Returns - /// - /// A vector of datasets, each containing instances in the range `[at[i], at[i+1])`. - fn shard(mut self, at: &[usize]) -> Vec - where - Self: Sized, - { - let mut shards = Vec::with_capacity(at.len() + 1); - // Iterate over `at` in reverse order - for &i in at.iter().rev() { - let [left, right] = self.split_off(i); - shards.push(right); - self = left; - } - shards.push(self); - shards.reverse(); - shards.into_iter().map(Self::reset_permutation).collect() - } - - /// Shard the dataset into a number of smaller datasets, with each shard - /// containing an equal number of instances, except possibly the last shard. - /// - /// # Arguments - /// - /// * `size` - The number of instances in each shard. 
- /// - /// # Returns - /// - /// A vector of datasets, each containing `size` instances, except possibly - /// the last shard. - fn shard_evenly(self, size: usize) -> Vec - where - Self: Sized, - { - let at = (0..self.cardinality()).step_by(size).collect::>(); - self.shard(&at) - } -} - -impl Shardable for FlatVec { - #[must_use] - fn reset_permutation(mut self) -> Self { - self.permutation = (0..self.instances.len()).collect(); - self - } - - fn split_off(mut self, at: usize) -> [Self; 2] { - #[allow(clippy::unnecessary_struct_initialization)] - let metric = self.metric.clone(); - let instances = self.instances.split_off(at); - let permutation = Vec::new(); - let metadata = self.metadata.split_off(at); - let right_data = Self { - metric, - instances, - dimensionality_hint: self.dimensionality_hint, - permutation, - metadata, - name: self.name.clone(), - }; - [self, right_data] - } -} diff --git a/crates/abd-clam/src/cakes/mod.rs b/crates/abd-clam/src/cakes/mod.rs index 8ae0ad0e9..41124537f 100644 --- a/crates/abd-clam/src/cakes/mod.rs +++ b/crates/abd-clam/src/cakes/mod.rs @@ -1,270 +1,12 @@ //! Entropy Scaling Search -pub mod cluster; -pub(crate) mod codec; -pub mod dataset; +mod cluster; +mod dataset; mod search; -pub use cluster::OffBall; -pub use codec::{ - CodecData, Compressible, Decodable, Decompressible, Encodable, ParCompressible, ParDecompressible, SquishyBall, +pub use cluster::{Offset, PermutedBall}; +pub use dataset::{HintedDataset, ParHintedDataset, ParSearchable, Searchable}; +pub use search::{ + KnnBreadthFirst, KnnDepthFirst, KnnHinted, KnnLinear, KnnRepeatedRnn, ParSearchAlgorithm, RnnClustered, RnnLinear, + SearchAlgorithm, }; -pub use dataset::Shardable; -pub use search::Algorithm; - -#[cfg(test)] -pub mod tests { - use core::fmt::Debug; - - use distances::{number::Float, Number}; - use rand::prelude::*; - use test_case::test_case; - - pub type Algs = Vec<( - super::Algorithm, - fn(Vec<(usize, U)>, Vec<(usize, U)>, &str, &FlatVec) -> bool, - )>; - - use crate::{ - adapter::{BallAdapter, ParBallAdapter}, - cakes::{OffBall, SquishyBall}, - cluster::ParCluster, - dataset::ParDataset, - Ball, Cluster, Dataset, FlatVec, Metric, Partition, - }; - - pub fn gen_line_data(max: i32) -> Result, String> { - let data = (-max..=max).collect::>(); - let distance_fn = |a: &i32, b: &i32| a.abs_diff(*b); - let metric = Metric::new(distance_fn, false); - FlatVec::new(data, metric) - } - - pub fn gen_grid_data(max: i32) -> Result, String> { - let data = (-max..=max) - .flat_map(|x| (-max..=max).map(move |y| (x.as_f32(), y.as_f32()))) - .collect::>(); - let distance_fn = |(x1, y1): &(f32, f32), (x2, y2): &(f32, f32)| (x1 - x2).hypot(y1 - y2); - let metric = Metric::new(distance_fn, false); - FlatVec::new(data, metric) - } - - pub fn check_search_by_index( - mut true_hits: Vec<(usize, U)>, - mut pred_hits: Vec<(usize, U)>, - name: &str, - data: &FlatVec, - ) -> bool { - // true_hits.sort_by(|(i, p), (j, q)| p.partial_cmp(q).unwrap_or(core::cmp::Ordering::Greater).then(i.cmp(j))); - // pred_hits.sort_by(|(i, p), (j, q)| p.partial_cmp(q).unwrap_or(core::cmp::Ordering::Greater).then(i.cmp(j))); - - true_hits.sort_by_key(|(i, _)| *i); - pred_hits.sort_by_key(|(i, _)| *i); - - let rest = format!("\n{true_hits:?}\nvs\n{pred_hits:?}"); - assert_eq!(true_hits.len(), pred_hits.len(), "{name}: {rest}"); - - for ((i, p), (j, q)) in true_hits.into_iter().zip(pred_hits) { - let msg = format!("Failed {name} i: {i}, j: {j}, p: {p}, q: {q}"); - assert_eq!(i, j, "{msg} {rest}"); - let (l, r) = 
(data.get(i), data.get(j)); - assert!(p.abs_diff(q) <= U::EPSILON, "{msg} in {rest}.\n{l:?} vs \n{r:?}"); - } - - true - } - - pub fn check_search_by_distance( - mut true_hits: Vec<(usize, U)>, - mut pred_hits: Vec<(usize, U)>, - name: &str, - data: &FlatVec, - ) -> bool { - true_hits.sort_by(|(_, p), (_, q)| p.partial_cmp(q).unwrap_or(core::cmp::Ordering::Greater)); - pred_hits.sort_by(|(_, p), (_, q)| p.partial_cmp(q).unwrap_or(core::cmp::Ordering::Greater)); - - assert_eq!( - true_hits.len(), - pred_hits.len(), - "{name}: {true_hits:?} vs {pred_hits:?}" - ); - - for (i, (&(l, p), &(r, q))) in true_hits.iter().zip(pred_hits.iter()).enumerate() { - let (l, r) = (data.get(l), data.get(r)); - assert!( - p.abs_diff(q) <= U::EPSILON, - "Failed {name} i-th: {i}, p: {p}, q: {q} in {true_hits:?} vs {pred_hits:?}.\n{l:?} vs \n{r:?}" - ); - } - - true - } - - pub fn gen_random_data( - car: usize, - dim: usize, - max: F, - seed: u64, - ) -> Result, F, usize>, String> { - let mut rng = rand::rngs::StdRng::seed_from_u64(seed); - let data = symagen::random_data::random_tabular(car, dim, -max, max, &mut rng); - let distance_fn = |a: &Vec, b: &Vec| distances::vectors::euclidean(a, b); - let metric = Metric::new(distance_fn, false); - FlatVec::new(data, metric) - } - - pub fn check_search( - algs: &Algs, - data: &D, - root: &C, - query: &I, - name: &str, - fv_data: &FlatVec, - ) -> bool - where - I: Send + Sync, - U: Number, - D: ParDataset, - C: ParCluster, - { - for (alg, checker) in algs { - let true_hits = alg.linear_variant().par_search(data, root, query); - let pred_hits = alg.par_search(data, root, query); - let alg_name = format!("{name}-{}", alg.name()); - checker(true_hits.clone(), pred_hits, &alg_name, fv_data); - } - - true - } - - #[test_case(1_000, 10)] - #[test_case(10_000, 10)] - #[test_case(1_000, 100)] - #[test_case(10_000, 100)] - fn vectors(car: usize, dim: usize) -> Result<(), String> { - let mut algs: Algs, f32, usize> = vec![]; - for radius in [0.1, 1.0] { - algs.push((super::Algorithm::RnnClustered(radius), check_search_by_index)); - } - for k in [1, 10, 100] { - algs.push((super::Algorithm::KnnRepeatedRnn(k, 2.0), check_search_by_distance)); - algs.push((super::Algorithm::KnnBreadthFirst(k), check_search_by_distance)); - algs.push((super::Algorithm::KnnDepthFirst(k), check_search_by_distance)); - } - - let seed = 42; - let data = gen_random_data(car, dim, 10.0, seed)?; - let criteria = |c: &Ball<_, _, _>| c.cardinality() > 1; - let seed = Some(seed); - let query = &vec![0.0; dim]; - - let ball = Ball::new_tree(&data, &criteria, seed); - check_search(&algs, &data, &ball, query, "ball", &data); - - let (off_ball, perm_data) = OffBall::from_ball_tree(ball.clone(), data.clone()); - check_search(&algs, &perm_data, &off_ball, query, "off_ball", &perm_data); - - let (par_off_ball, per_perm_data) = OffBall::par_from_ball_tree(ball, data); - check_search( - &algs, - &per_perm_data, - &par_off_ball, - query, - "par_off_ball", - &per_perm_data, - ); - - Ok(()) - } - - #[test_case::test_case(16, 16, 2)] - fn strings(num_clumps: usize, clump_size: usize, clump_radius: u16) -> Result<(), String> { - let pool = rayon::ThreadPoolBuilder::new().num_threads(1).build().unwrap(); - - pool.install(|| { - let mut algs: Algs = vec![]; - for radius in [4, 8, 16] { - algs.push((super::Algorithm::RnnClustered(radius), check_search_by_index)); - } - for k in [1, 10, 20] { - algs.push((super::Algorithm::KnnRepeatedRnn(k, 2), check_search_by_distance)); - algs.push((super::Algorithm::KnnBreadthFirst(k), 
check_search_by_distance)); - algs.push((super::Algorithm::KnnDepthFirst(k), check_search_by_distance)); - } - - let seed_length = 30; - let alphabet = "ACTGN".chars().collect::>(); - let seed_string = symagen::random_edits::generate_random_string(seed_length, &alphabet); - let penalties = distances::strings::Penalties::default(); - let inter_clump_distance_range = (clump_radius * 5, clump_radius * 7); - let len_delta = seed_length / 10; - let (metadata, data) = symagen::random_edits::generate_clumped_data( - &seed_string, - penalties, - &alphabet, - num_clumps, - clump_size, - clump_radius, - inter_clump_distance_range, - len_delta, - ) - .into_iter() - .unzip::<_, _, Vec<_>, Vec<_>>(); - let query = &seed_string; - - let distance_fn = |a: &String, b: &String| distances::strings::levenshtein::(a, b); - let metric = Metric::new(distance_fn, true); - let data = FlatVec::new(data, metric)?.with_metadata(metadata.clone())?; - - let criteria = |c: &Ball<_, _, _>| c.cardinality() > 1; - let seed = Some(42); - - let ball = Ball::new_tree(&data, &criteria, seed); - check_search(&algs, &data, &ball, query, "ball", &data); - - let (off_ball, perm_data) = OffBall::from_ball_tree(ball.clone(), data.clone()); - check_search(&algs, &perm_data, &off_ball, query, "off_ball", &perm_data); - - let (par_off_ball, par_perm_data) = OffBall::par_from_ball_tree(ball.clone(), data.clone()); - check_search( - &algs, - &par_perm_data, - &par_off_ball, - query, - "par_off_ball", - &par_perm_data, - ); - - let (squishy_ball, co_data) = SquishyBall::from_ball_tree(ball.clone(), data.clone()); - let (squishy_ball, co_data) = { - ( - squishy_ball.with_metadata_type::(), - co_data.with_metadata(metadata.clone())?, - ) - }; - let co_fv_data = co_data.to_flat_vec(); - check_search(&algs, &co_data, &squishy_ball, query, "squishy_ball", &co_fv_data); - - let (par_squishy_ball, par_co_data) = SquishyBall::par_from_ball_tree(ball, data); - let (par_squishy_ball, par_co_data) = { - ( - par_squishy_ball.with_metadata_type::(), - par_co_data.with_metadata(metadata.clone())?, - ) - }; - let par_co_fv_data = par_co_data.to_flat_vec(); - check_search( - &algs, - &par_co_data, - &par_squishy_ball, - query, - "par_squishy_ball", - &par_co_fv_data, - ); - - Ok::<_, String>(()) - })?; - - Ok(()) - } -} diff --git a/crates/abd-clam/src/cakes/search/knn_breadth_first.rs b/crates/abd-clam/src/cakes/search/knn_breadth_first.rs index 4d4f1beac..baae12f1d 100644 --- a/crates/abd-clam/src/cakes/search/knn_breadth_first.rs +++ b/crates/abd-clam/src/cakes/search/knn_breadth_first.rs @@ -6,134 +6,125 @@ use distances::Number; use rayon::prelude::*; use crate::{ + cakes::{ParSearchable, Searchable}, cluster::ParCluster, - dataset::{ParDataset, SizedHeap}, - Cluster, Dataset, + metric::ParMetric, + Cluster, Metric, SizedHeap, }; +use super::{ParSearchAlgorithm, SearchAlgorithm}; + /// K-Nearest Neighbors search using a Breadth First sieve. 
-pub fn search(data: &D, root: &C, query: &I, k: usize) -> Vec<(usize, U)> -where - U: Number, - D: Dataset, - C: Cluster, -{ - let mut candidates = Vec::new(); - let mut hits = SizedHeap::<(U, usize)>::new(Some(k)); +pub struct KnnBreadthFirst(pub usize); - let d = root.distance_to_center(data, query); - candidates.push((d_max(root, d), root)); +impl, M: Metric, D: Searchable> SearchAlgorithm + for KnnBreadthFirst +{ + fn name(&self) -> &str { + "KnnBreadthFirst" + } - while !candidates.is_empty() { - let [needed, maybe_needed, _] = split_candidates(&mut candidates, k); + fn radius(&self) -> Option { + None + } - let (leaves, parents) = needed - .into_iter() - .chain(maybe_needed) - .partition::, _>(|(_, c)| c.is_leaf()); + fn k(&self) -> Option { + Some(self.0) + } - for (d, c) in leaves { - if c.is_singleton() { - c.indices().for_each(|i| hits.push((d, i))); - } else { - c.distances_to_query(data, query) - .into_iter() - .for_each(|(i, d)| hits.push((d, i))); - } + fn search(&self, data: &D, metric: &M, root: &C, query: &I) -> Vec<(usize, T)> { + let mut candidates = Vec::new(); + let mut hits = SizedHeap::<(T, usize)>::new(Some(self.0)); + + let d = data.query_to_center(metric, query, root); + candidates.push((d_max(root, d), root)); + + while !candidates.is_empty() { + candidates = filter_candidates(candidates, self.0) + .into_iter() + .fold(Vec::new(), |mut acc, (d, c)| { + if (c.cardinality() < (self.0 - acc.len())) || c.is_leaf() { + if c.is_singleton() { + c.indices().into_iter().for_each(|i| hits.push((d, i))); + } else { + data.query_to_all(metric, query, c).for_each(|(i, d)| hits.push((d, i))); + } + } else { + acc.extend( + c.children() + .into_iter() + .map(|c| (d_max(c, data.query_to_center(metric, query, c)), c)), + ); + } + acc + }); } - candidates = Vec::new(); - for (_, p) in parents { - p.child_clusters() - .map(|c| (c, c.distance_to_center(data, query))) - .for_each(|(c, d)| candidates.push((d_max(c, d), c))); - } + hits.items().map(|(d, i)| (i, d)).collect() } - - hits.items().map(|(d, i)| (i, d)).collect() } -/// Parallel K-Nearest Neighbors search using a Breadth First sieve. 
-pub fn par_search(data: &D, root: &C, query: &I, k: usize) -> Vec<(usize, U)> -where - I: Send + Sync, - U: Number, - D: ParDataset, - C: ParCluster, +impl, M: ParMetric, D: ParSearchable> + ParSearchAlgorithm for KnnBreadthFirst { - let mut candidates = Vec::new(); - let mut hits = SizedHeap::<(U, usize)>::new(Some(k)); - - let d = root.distance_to_center(data, query); - candidates.push((d_max(root, d), root)); - - while !candidates.is_empty() { - let [needed, maybe_needed, _] = split_candidates(&mut candidates, k); - - let (leaves, parents) = needed - .into_iter() - .chain(maybe_needed) - .partition::, _>(|(_, c)| c.is_leaf()); - - for (d, c) in leaves { - if c.is_singleton() { - c.indices().for_each(|i| hits.push((d, i))); - } else { - c.par_distances_to_query(data, query) - .into_iter() - .for_each(|(i, d)| hits.push((d, i))); - } + #[inline(never)] + fn par_search(&self, data: &D, metric: &M, root: &C, query: &I) -> Vec<(usize, T)> { + let mut candidates = Vec::new(); + let mut hits = SizedHeap::<(T, usize)>::new(Some(self.0)); + + let d = data.par_query_to_center(metric, query, root); + candidates.push((d_max(root, d), root)); + + while !candidates.is_empty() { + candidates = filter_candidates(candidates, self.0) + .into_iter() + .fold(Vec::new(), |mut acc, (d, c)| { + if (c.cardinality() < (self.0 - acc.len())) || c.is_leaf() { + if c.is_singleton() { + c.indices().into_iter().for_each(|i| hits.push((d, i))); + } else { + let distances = data.par_query_to_all(metric, query, c).collect::>(); + distances.into_iter().for_each(|(i, d)| hits.push((d, i))); + } + } else { + let distances = c + .children() + .into_par_iter() + .map(|c| (d_max(c, data.par_query_to_center(metric, query, c)), c)) + .collect::>(); + acc.extend(distances); + } + acc + }); } - candidates = Vec::new(); - let distances = parents - .into_par_iter() - .flat_map(|(_, p)| p.child_clusters().collect::>()) - .map(|c| (c, c.distance_to_center(data, query))) - .collect::>(); - distances - .into_iter() - .for_each(|(c, d)| candidates.push((d_max(c, d), c))); + hits.items().map(|(d, i)| (i, d)).collect() } - - hits.items().map(|(d, i)| (i, d)).collect() } /// Returns the theoretical maximum distance from the query to a point in the cluster. -fn d_max, C: Cluster>(c: &C, d: U) -> U { +fn d_max>(c: &C, d: T) -> T { c.radius() + d } -/// Splits the candidates three ways: those needed to get to k hits, those that -/// might be needed to get to k hits, and those that are not needed to get to k -/// hits. -fn split_candidates<'a, I, U, D, C>(candidates: &mut [(U, &'a C)], k: usize) -> [Vec<(U, &'a C)>; 3] -where - U: Number, - D: Dataset, - C: Cluster, -{ - let threshold_index = quick_partition(candidates, k); +/// Returns those candidates that are needed to guarantee the k-nearest +/// neighbors. 
+fn filter_candidates<T: Number, C: Cluster<T>>(mut candidates: Vec<(T, &C)>, k: usize) -> Vec<(T, &C)> {
+    let threshold_index = quick_partition(&mut candidates, k);
     let threshold = candidates[threshold_index].0;
 
-    let (needed, others) = candidates.iter().partition::<Vec<_>, _>(|(d, _)| *d < threshold);
-
-    let (not_needed, maybe_needed) = others
+    candidates
         .into_iter()
-        .map(|(d, c)| {
+        .filter_map(|(d, c)| {
             let diam = c.radius().double();
-            if d <= diam {
-                (d, U::ZERO, c)
+            let d_min = if d <= diam { T::ZERO } else { d - diam };
+            if d_min <= threshold {
+                Some((d, c))
             } else {
-                (d, d - diam, c)
+                None
             }
         })
-        .partition::<Vec<_>, _>(|(_, d, _)| *d > threshold);
-
-    let not_needed = not_needed.into_iter().map(|(d, _, c)| (d, c)).collect();
-    let maybe_needed = maybe_needed.into_iter().map(|(d, _, c)| (d, c)).collect();
-
-    [needed, maybe_needed, not_needed]
+        .collect()
 }
 
 /// The Quick Partition algorithm, a variant of the Quick Select algorithm that
 /// also reorders the list so that all elements to the left of the k-th
 /// smallest element are less than or equal to it, and all elements to the right
 /// of the k-th smallest element are greater than or equal to it.
-fn quick_partition<I, U, D, C>(items: &mut [(U, &C)], k: usize) -> usize
-where
-    U: Number,
-    D: Dataset<I, U>,
-    C: Cluster<I, U, D>,
-{
+fn quick_partition<T: Number, C: Cluster<T>>(items: &mut [(T, &C)], k: usize) -> usize {
     qps(items, k, 0, items.len() - 1)
 }
 
 /// The recursive helper function for the Quick Partition algorithm.
-fn qps<I, U, D, C>(items: &mut [(U, &C)], k: usize, l: usize, r: usize) -> usize
-where
-    U: Number,
-    D: Dataset<I, U>,
-    C: Cluster<I, U, D>,
-{
+fn qps<T: Number, C: Cluster<T>>(items: &mut [(T, &C)], k: usize, l: usize, r: usize) -> usize {
     if l >= r {
         min(l, r)
     } else {
@@ -198,11 +179,10 @@ where
 /// Moves pivot point and swaps elements around so that all elements to left
 /// of pivot are less than or equal to pivot and all elements to right of pivot
 /// are greater than pivot.
-fn find_pivot(items: &mut [(U, &C)], l: usize, r: usize, pivot: usize) -> usize +fn find_pivot(items: &mut [(T, &C)], l: usize, r: usize, pivot: usize) -> usize where - U: Number, - D: Dataset, - C: Cluster, + T: Number, + C: Cluster, { // Move pivot to the end items.swap(pivot, r); @@ -226,58 +206,3 @@ where a } - -#[cfg(test)] -mod tests { - use crate::{ - adapter::BallAdapter, - cakes::OffBall, - cluster::{Ball, Partition}, - Cluster, - }; - - use super::super::knn_depth_first::tests::check_knn; - use crate::cakes::tests::{gen_grid_data, gen_line_data}; - - #[test] - fn line() -> Result<(), String> { - let data = gen_line_data(10)?; - let query = &0; - - let criteria = |c: &Ball<_, _, _>| c.cardinality() > 1; - let seed = Some(42); - - let ball = Ball::new_tree(&data, &criteria, seed); - for k in [1, 4, 8] { - assert!(check_knn(&ball, &data, query, k)); - } - - let (off_ball, perm_data) = OffBall::from_ball_tree(ball, data); - for k in [1, 4, 8] { - assert!(check_knn(&off_ball, &perm_data, query, k)); - } - - Ok(()) - } - - #[test] - fn grid() -> Result<(), String> { - let data = gen_grid_data(10)?; - let query = &(0.0, 0.0); - - let criteria = |c: &Ball<_, _, _>| c.cardinality() > 1; - let seed = Some(42); - - let ball = Ball::new_tree(&data, &criteria, seed); - for k in [1, 4, 8] { - assert!(check_knn(&ball, &data, query, k)); - } - - let (off_ball, perm_data) = OffBall::from_ball_tree(ball, data); - for k in [1, 4, 8] { - assert!(check_knn(&off_ball, &perm_data, query, k)); - } - - Ok(()) - } -} diff --git a/crates/abd-clam/src/cakes/search/knn_depth_first.rs b/crates/abd-clam/src/cakes/search/knn_depth_first.rs index b10f42fe3..10e4db51a 100644 --- a/crates/abd-clam/src/cakes/search/knn_depth_first.rs +++ b/crates/abd-clam/src/cakes/search/knn_depth_first.rs @@ -3,75 +3,90 @@ use core::cmp::Reverse; use distances::Number; +use rayon::prelude::*; use crate::{ + cakes::{ParSearchable, Searchable}, cluster::ParCluster, - dataset::{ParDataset, SizedHeap}, - Cluster, Dataset, + dataset::SizedHeap, + metric::ParMetric, + Cluster, Metric, }; +use super::{ParSearchAlgorithm, SearchAlgorithm}; + /// K-Nearest Neighbors search using a Depth First sieve. -pub fn search(data: &D, root: &C, query: &I, k: usize) -> Vec<(usize, U)> -where - U: Number, - D: Dataset, - C: Cluster, +pub struct KnnDepthFirst(pub usize); + +impl, M: Metric, D: Searchable> SearchAlgorithm + for KnnDepthFirst { - let mut candidates = SizedHeap::<(Reverse, &C)>::new(None); - let mut hits = SizedHeap::<(U, usize)>::new(Some(k)); - - let d = root.distance_to_center(data, query); - candidates.push((Reverse(d_min(root, d)), root)); - - while !hits.is_full() // We do not have enough hits. - || (!candidates.is_empty() // We have candidates. 
- && hits // and - .peek() - .map_or_else(|| unreachable!("`hits` is non-empty."), |(d, _)| *d) // the farthest hit - >= candidates // is farther than - .peek() // the closest candidate - .map_or_else(|| unreachable!("`candidates` is non-empty."), |(d, _)| d.0)) - { - let (d, leaf) = pop_till_leaf(data, query, &mut candidates); - leaf_into_hits(data, query, &mut hits, d, leaf); + fn name(&self) -> &str { + "KnnDepthFirst" + } + + fn radius(&self) -> Option { + None + } + + fn k(&self) -> Option { + Some(self.0) + } + + fn search(&self, data: &D, metric: &M, root: &C, query: &I) -> Vec<(usize, T)> { + let mut candidates = SizedHeap::<(Reverse, &C)>::new(None); + let mut hits = SizedHeap::<(T, usize)>::new(Some(self.0)); + + let d = data.query_to_center(metric, query, root); + candidates.push((Reverse(d_min(root, d)), root)); + + while !hits.is_full() // We do not have enough hits. + || (!candidates.is_empty() // We have candidates. + && hits // and + .peek() + .map_or_else(|| unreachable!("`hits` is non-empty."), |(d, _)| *d) // the farthest hit + >= candidates // is farther than + .peek() // the closest candidate + .map_or_else(|| unreachable!("`candidates` is non-empty."), |(d, _)| d.0)) + { + let (d, leaf) = pop_till_leaf(data, metric, query, &mut candidates); + leaf_into_hits(data, metric, query, &mut hits, d, leaf); + } + hits.items().map(|(d, i)| (i, d)).collect() } - hits.items().map(|(d, i)| (i, d)).collect() } -/// Parallel K-Nearest Neighbors search using a Depth First sieve. -pub fn par_search(data: &D, root: &C, query: &I, k: usize) -> Vec<(usize, U)> -where - I: Send + Sync, - U: Number, - D: ParDataset, - C: ParCluster, +impl, M: ParMetric, D: ParSearchable> + ParSearchAlgorithm for KnnDepthFirst { - let mut candidates = SizedHeap::<(Reverse, &C)>::new(None); - let mut hits = SizedHeap::<(U, usize)>::new(Some(k)); - - let d = root.distance_to_center(data, query); - candidates.push((Reverse(d_min(root, d)), root)); - - while !hits.is_full() // We do not have enough hits. - || (!candidates.is_empty() // We have candidates. - && hits // and - .peek() - .map_or_else(|| unreachable!("`hits` is non-empty."), |(d, _)| *d) // the farthest hit - >= candidates // is farther than - .peek() // the closest candidate - .map_or_else(|| unreachable!("`candidates` is non-empty."), |(d, _)| d.0)) - { - par_pop_till_leaf(data, query, &mut candidates); - par_leaf_into_hits(data, query, &mut hits, &mut candidates); + fn par_search(&self, data: &D, metric: &M, root: &C, query: &I) -> Vec<(usize, T)> { + let mut candidates = SizedHeap::<(Reverse, &C)>::new(None); + let mut hits = SizedHeap::<(T, usize)>::new(Some(self.0)); + + let d = data.par_query_to_center(metric, query, root); + candidates.push((Reverse(d_min(root, d)), root)); + + while !hits.is_full() // We do not have enough hits. + || (!candidates.is_empty() // We have candidates. + && hits // and + .peek() + .map_or_else(|| unreachable!("`hits` is non-empty."), |(d, _)| *d) // the farthest hit + >= candidates // is farther than + .peek() // the closest candidate + .map_or_else(|| unreachable!("`candidates` is non-empty."), |(d, _)| d.0)) + { + par_pop_till_leaf(data, metric, query, &mut candidates); + par_leaf_into_hits(data, metric, query, &mut hits, &mut candidates); + } + hits.items().map(|(d, i)| (i, d)).collect() } - hits.items().map(|(d, i)| (i, d)).collect() } /// Calculates the theoretical best case distance for a point in a cluster, i.e., /// the closest a point in a given cluster could possibly be to the query. 
-pub fn d_min, C: Cluster>(c: &C, d: U) -> U { +pub fn d_min>(c: &C, d: T) -> T { if d < c.radius() { - U::ZERO + T::ZERO } else { d - c.radius() } @@ -79,51 +94,43 @@ pub fn d_min, C: Cluster>(c: &C, d: U) - /// Pops from the top of `candidates` until the top candidate is a leaf cluster. /// Then, pops and returns the leaf cluster. -fn pop_till_leaf<'a, I, U, D, C>(data: &D, query: &I, candidates: &mut SizedHeap<(Reverse, &'a C)>) -> (U, &'a C) +fn pop_till_leaf<'a, I, T, C, M, D>( + data: &D, + metric: &M, + query: &I, + candidates: &mut SizedHeap<(Reverse, &'a C)>, +) -> (T, &'a C) where - U: Number + 'a, - D: Dataset, - C: Cluster, + T: Number + 'a, + C: Cluster, + M: Metric, + D: Searchable, { while candidates - .peek() // The top candidate is a leaf + .peek() // The top candidate .map_or_else(|| unreachable!("`candidates` is non-empty"), |(_, c)| !c.is_leaf()) + // is not a leaf { let parent = candidates .pop() .map_or_else(|| unreachable!("`candidates` is non-empty"), |(_, c)| c); - for child in parent.child_clusters() { - candidates.push((Reverse(d_min(child, child.distance_to_center(data, query))), child)); - } + parent.children().into_iter().for_each(|child| { + candidates.push((Reverse(d_min(child, data.query_to_center(metric, query, child))), child)); + }); } candidates .pop() .map_or_else(|| unreachable!("`candidates` is non-empty"), |(Reverse(d), c)| (d, c)) } -/// Pops from the top of `candidates` and adds its points to `hits`. -fn leaf_into_hits(data: &D, query: &I, hits: &mut SizedHeap<(U, usize)>, d: U, leaf: &C) -where - U: Number, - D: Dataset, - C: Cluster, -{ - if leaf.is_singleton() { - leaf.indices().for_each(|i| hits.push((d, i))); - } else { - leaf.distances_to_query(data, query) - .into_iter() - .for_each(|(i, d)| hits.push((d, i))); - }; -} - -/// Parallel version of `pop_till_leaf`. -fn par_pop_till_leaf<'a, I, U, D, C>(data: &D, query: &I, candidates: &mut SizedHeap<(Reverse, &'a C)>) +/// Parallel version of [`pop_till_leaf`](crate::cakes::search::knn_depth_first::pop_till_leaf). +fn par_pop_till_leaf<'a, I, T, C, M, D>(data: &D, metric: &M, query: &I, candidates: &mut SizedHeap<(Reverse, &'a C)>) where I: Send + Sync, - U: Number + 'a, - D: ParDataset, - C: ParCluster, + T: Number + 'a, + C: ParCluster, + M: ParMetric, + D: ParSearchable, { while candidates .peek() // The top candidate @@ -133,116 +140,55 @@ where let parent = candidates .pop() .map_or_else(|| unreachable!("`candidates` is non-empty"), |(_, c)| c); - let children = parent.child_clusters().collect::>(); - let indices = children.iter().map(|c| c.arg_center()).collect::>(); - data.par_query_to_many(query, &indices) + parent + .children() + .into_par_iter() + .map(|child| (child, data.par_query_to_center(metric, query, child))) + .collect::>() .into_iter() - .zip(children) - .for_each(|((_, d), c)| candidates.push((Reverse(d_min(c, d)), c))); + .for_each(|(child, d)| candidates.push((Reverse(d_min(child, d)), child))); } } -/// Parallel version of `leaf_into_hits`. -fn par_leaf_into_hits( +/// Pops from the top of `candidates` and adds its points to `hits`. 
+fn leaf_into_hits(data: &D, metric: &M, query: &I, hits: &mut SizedHeap<(T, usize)>, d: T, leaf: &C) +where + T: Number, + C: Cluster, + M: Metric, + D: Searchable, +{ + if leaf.is_singleton() { + leaf.indices().into_iter().for_each(|i| hits.push((d, i))); + } else { + data.query_to_all(metric, query, leaf) + .for_each(|(i, d)| hits.push((d, i))); + }; +} + +/// Parallel version of [`leaf_into_hits`](crate::cakes::search::knn_depth_first::leaf_into_hits). +fn par_leaf_into_hits( data: &D, + metric: &M, query: &I, - hits: &mut SizedHeap<(U, usize)>, - candidates: &mut SizedHeap<(Reverse, &C)>, + hits: &mut SizedHeap<(T, usize)>, + candidates: &mut SizedHeap<(Reverse, &C)>, ) where I: Send + Sync, - U: Number, - D: ParDataset, - C: ParCluster, + T: Number, + C: ParCluster, + M: ParMetric, + D: ParSearchable, { let (d, leaf) = candidates .pop() .map_or_else(|| unreachable!("`candidates` is non-empty"), |(Reverse(d), c)| (d, c)); if leaf.is_singleton() { - leaf.indices().for_each(|i| hits.push((d, i))); + leaf.indices().into_iter().for_each(|i| hits.push((d, i))); } else { - leaf.par_distances_to_query(data, query) + data.query_to_all(metric, query, leaf) + .collect::>() .into_iter() .for_each(|(i, d)| hits.push((d, i))); }; } - -#[cfg(test)] -pub(crate) mod tests { - use core::fmt::Debug; - - use distances::Number; - - use crate::{ - adapter::BallAdapter, - cakes::OffBall, - cluster::{Ball, ParCluster, Partition}, - Cluster, Dataset, FlatVec, - }; - - use crate::cakes::tests::{check_search_by_distance, gen_grid_data, gen_line_data}; - - pub fn check_knn>>( - root: &C, - data: &FlatVec, - query: &I, - k: usize, - ) -> bool { - let true_hits = data.knn(query, k); - - let pred_hits = super::search(data, root, query, k); - assert_eq!(pred_hits.len(), true_hits.len(), "Knn search failed: {pred_hits:?}"); - check_search_by_distance(true_hits.clone(), pred_hits, "KnnClustered", data); - - let pred_hits = super::par_search(data, root, query, k); - assert_eq!( - pred_hits.len(), - true_hits.len(), - "Parallel Knn search failed: {pred_hits:?}" - ); - check_search_by_distance(true_hits.clone(), pred_hits, "Par KnnClustered", data); - - true - } - - #[test] - fn line() -> Result<(), String> { - let data = gen_line_data(10)?; - let query = &0; - - let criteria = |c: &Ball<_, _, _>| c.cardinality() > 1; - let seed = Some(42); - - let ball = Ball::new_tree(&data, &criteria, seed); - for k in [1, 4, 8] { - assert!(check_knn(&ball, &data, query, k)); - } - - let (off_ball, perm_data) = OffBall::from_ball_tree(ball, data); - for k in [1, 4, 8] { - assert!(check_knn(&off_ball, &perm_data, query, k)); - } - - Ok(()) - } - - #[test] - fn grid() -> Result<(), String> { - let data = gen_grid_data(10)?; - let query = &(0.0, 0.0); - - let criteria = |c: &Ball<_, _, _>| c.cardinality() > 1; - let seed = Some(42); - - let ball = Ball::new_tree(&data, &criteria, seed); - for k in [1, 4, 8] { - assert!(check_knn(&ball, &data, query, k)); - } - - let (off_ball, perm_data) = OffBall::from_ball_tree(ball, data); - for k in [1, 4, 8] { - assert!(check_knn(&off_ball, &perm_data, query, k)); - } - - Ok(()) - } -} diff --git a/crates/abd-clam/src/cakes/search/knn_hinted.rs b/crates/abd-clam/src/cakes/search/knn_hinted.rs new file mode 100644 index 000000000..8b04b22f6 --- /dev/null +++ b/crates/abd-clam/src/cakes/search/knn_hinted.rs @@ -0,0 +1,114 @@ +//! K-NN search using a dataset with search hints. 
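[Editor's note: illustrative sketch, not part of the patch. The `KnnHinted` algorithm in the new file below first descends greedily to a close hit, then converts that hit's hints (or, failing that, a local-fractal-dimension estimate) into a radius for a single `RnnClustered` pass. A toy of the LFD radius-scaling arithmetic, assuming the usual model in which the number of neighbors within radius `r` grows like `r^lfd`; `multiplier_for_k` here is a stand-in for illustration, not the crate's API:

    /// Going from `n_known` hits within some radius to `k` desired hits
    /// suggests scaling that radius by (k / n_known)^(1 / lfd).
    fn multiplier_for_k(lfd: f32, n_known: usize, k: usize) -> f32 {
        (k as f32 / n_known as f32).powf(1.0 / lfd)
    }

    fn main() {
        // 4 hits known, 16 wanted, lfd = 2: the radius should double,
        // since (16 / 4)^(1 / 2) = 2.
        let m = multiplier_for_k(2.0, 4, 16);
        assert!((m - 2.0).abs() < 1e-6);
    }

End of editor's note.]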
+ +use core::cmp::Reverse; + +use distances::Number; + +use crate::{ + cakes::{dataset::HintedDataset, ParHintedDataset, RnnClustered}, + cluster::ParCluster, + metric::ParMetric, + Cluster, Metric, SizedHeap, LFD, +}; + +use super::{knn_depth_first::d_min, ParSearchAlgorithm, SearchAlgorithm}; + +/// K-NN search using a dataset with search hints. +pub struct KnnHinted(pub usize); + +impl SearchAlgorithm for KnnHinted +where + T: Number, + C: Cluster, + M: Metric, + D: HintedDataset, +{ + fn name(&self) -> &str { + "KnnHinted" + } + + fn radius(&self) -> Option { + None + } + + fn k(&self) -> Option { + Some(self.0) + } + + fn search(&self, data: &D, metric: &M, root: &C, query: &I) -> Vec<(usize, T)> { + // The `candidates` heap contains triples of `(d_min, d, c)`. + let mut candidates = SizedHeap::<(Reverse, T, &C)>::new(None); + + let d = data.query_to_center(metric, query, root); + let mut hits = vec![(root.arg_center(), d)]; + let mut n_hits = 1; + candidates.push((Reverse(d_min(root, d)), d, root)); + + // While it is possible to have any hit among candidates that is closer + // than the current closest hit, we keep searching. + while candidates.peek().is_some_and(|&(Reverse(d), _, _)| d < hits[0].1) { + let (_, d, c) = candidates.pop().unwrap_or_else(|| unreachable!("We just peeked")); + + if d < hits[0].1 { + hits.push((c.arg_center(), d)); + hits.swap(0, n_hits); + n_hits += 1; + } + + for child in c.children() { + let d_c = data.query_to_center(metric, query, child); + candidates.push((Reverse(d_min(child, d_c)), d_c, child)); + } + } + let mut logs = vec![format!("hits: {hits:?}")]; + + let (i, r) = hits[0]; + logs.push(format!("hints: {:?}", data.hints_for(i))); + + let additive_radius = if data.hints_for(i).contains_key(&self.0) { + data.hints_for(i)[&self.0] + } else { + let hit_distances = hits.iter().map(|&(_, d)| d).collect::>(); + logs.push(format!("hit_distances: {hit_distances:?}")); + let (_, max_distance) = crate::utils::arg_max(&hit_distances).unwrap_or((0, T::ZERO)); + logs.push(format!("max_distance: {max_distance:?}")); + let max_distance = max_distance + T::EPSILON; + let lfd = LFD::from_radial_distances(&hit_distances, max_distance.half()); + logs.push(format!("lfd: {lfd:?}")); + let min_multiplier = (1 + hit_distances.len()).as_f32() / self.0.as_f32(); + let multiplier = LFD::multiplier_for_k(lfd, hit_distances.len(), self.0).max(min_multiplier + f32::EPSILON); + logs.push(format!("multiplier: {multiplier:?}")); + T::from(max_distance.as_f32() * multiplier) + }; + logs.push(format!("additive_radius: {additive_radius:?}")); + + let alg = RnnClustered(r + additive_radius); + let mut hits = alg.search(data, metric, root, query); + hits.sort_by(|(_, d1), (_, d2)| d1.total_cmp(d2)); + if hits.len() < self.0 { + for l in logs { + eprintln!("{l}"); + } + eprintln!( + "Expected at least {} hits, got {}, radius: {r}, additive: {additive_radius}", + self.0, + hits.len(), + ); + } + assert!(hits.len() >= self.0); + hits.into_iter().take(self.0).collect() + } +} + +impl ParSearchAlgorithm for KnnHinted +where + I: Send + Sync, + T: Number, + C: ParCluster, + M: ParMetric, + D: ParHintedDataset, +{ + fn par_search(&self, data: &D, metric: &M, root: &C, query: &I) -> Vec<(usize, T)> { + self.search(data, metric, root, query) + } +} diff --git a/crates/abd-clam/src/cakes/search/knn_linear.rs b/crates/abd-clam/src/cakes/search/knn_linear.rs new file mode 100644 index 000000000..b0b1f5ae0 --- /dev/null +++ b/crates/abd-clam/src/cakes/search/knn_linear.rs @@ -0,0 +1,68 @@ +//! 
 k-NN search using a linear scan of the dataset.
+
+use distances::Number;
+use rayon::prelude::*;
+
+use crate::{
+    cakes::{ParSearchable, Searchable},
+    cluster::ParCluster,
+    metric::ParMetric,
+    Cluster, Metric, SizedHeap,
+};
+
+use super::{ParSearchAlgorithm, SearchAlgorithm};
+
+/// k-NN search using a linear scan of the dataset.
+pub struct KnnLinear(pub usize);
+
+impl<I, T: Number, C: Cluster<T>, M: Metric<I, T>, D: Searchable<I, T, C, M>> SearchAlgorithm<I, T, C, M, D>
+    for KnnLinear
+{
+    fn name(&self) -> &str {
+        "KnnLinear"
+    }
+
+    fn radius(&self) -> Option<T> {
+        None
+    }
+
+    fn k(&self) -> Option<usize> {
+        Some(self.0)
+    }
+
+    fn search(&self, data: &D, metric: &M, root: &C, query: &I) -> Vec<(usize, T)> {
+        data.query_to_all(metric, query, root)
+            .fold(SizedHeap::new(Some(self.0)), |mut hits, (i, d)| {
+                hits.push((d, i));
+                hits
+            })
+            .items()
+            .map(|(d, i)| (i, d))
+            .collect()
+    }
+}
+
+impl<I: Send + Sync, T: Number, C: ParCluster<T>, M: ParMetric<I, T>, D: ParSearchable<I, T, C, M>>
+    ParSearchAlgorithm<I, T, C, M, D> for KnnLinear
+{
+    fn par_search(&self, data: &D, metric: &M, root: &C, query: &I) -> Vec<(usize, T)> {
+        data.par_query_to_all(metric, query, root)
+            .fold(
+                || SizedHeap::new(Some(self.0)),
+                |mut hits, (i, d)| {
+                    hits.push((d, i));
+                    hits
+                },
+            )
+            .reduce(
+                || SizedHeap::new(Some(self.0)),
+                |mut a, b| {
+                    a.merge(b);
+                    a
+                },
+            )
+            .par_items()
+            .map(|(d, i)| (i, d))
+            .collect()
+    }
+}
diff --git a/crates/abd-clam/src/cakes/search/knn_repeated_rnn.rs b/crates/abd-clam/src/cakes/search/knn_repeated_rnn.rs
index 6b681d3c5..9768aa076 100644
--- a/crates/abd-clam/src/cakes/search/knn_repeated_rnn.rs
+++ b/crates/abd-clam/src/cakes/search/knn_repeated_rnn.rs
@@ -1,193 +1,129 @@
-//! K-Nearest Neighbors search using repeated Clustered RNN search.
+//! k-NN search using repeated clustered RNN searches.
 
 use distances::{number::Multiplication, Number};
+use rayon::prelude::*;
 
 use crate::{
+    cakes::{ParSearchable, Searchable},
     cluster::ParCluster,
-    dataset::{ParDataset, SizedHeap},
-    Cluster, Dataset, LFD,
+    metric::ParMetric,
+    Cluster, Metric, SizedHeap,
 };
 
-use super::rnn_clustered::{leaf_search, par_leaf_search, par_tree_search, tree_search};
-
-/// K-Nearest Neighbors search using repeated Clustered RNN search.
-pub fn search<I, U, D, C>(data: &D, root: &C, query: &I, k: usize, max_multiplier: U) -> Vec<(usize, U)>
-where
-    U: Number,
-    D: Dataset<I, U>,
-    C: Cluster<I, U, D>,
-{
-    let max_multiplier = max_multiplier.as_f32();
-    let mut radius = root.radius().as_f32();
+use super::{
+    rnn_clustered::{par_tree_search, tree_search},
+    ParSearchAlgorithm, SearchAlgorithm,
+};
 
-    let mut multiplier = LFD::multiplier_for_k(root.lfd(), root.cardinality(), k).min(max_multiplier);
-    radius = radius.mul_add(multiplier, U::EPSILON.as_f32());
+/// k-NN search using repeated clustered RNN searches.
+pub struct KnnRepeatedRnn(pub usize, pub T); - let [mut confirmed, mut straddlers] = tree_search(data, root, query, U::from(radius)); +impl, M: Metric, D: Searchable> SearchAlgorithm + for KnnRepeatedRnn +{ + fn name(&self) -> &str { + "KnnRepeatedRnn" + } - let mut num_confirmed = count_hits(&confirmed); - while num_confirmed == 0 { - radius = radius.double(); - [confirmed, straddlers] = tree_search(data, root, query, U::from(radius)); - num_confirmed = count_hits(&confirmed); + fn radius(&self) -> Option { + None } - while num_confirmed < k { - let (lfd, car) = mean_lfd(&confirmed, &straddlers); - multiplier = LFD::multiplier_for_k(lfd, car, k) - .min(max_multiplier) - .max(f32::ONE + f32::EPSILON); - radius = radius.mul_add(multiplier, U::EPSILON.as_f32()); - [confirmed, straddlers] = tree_search(data, root, query, U::from(radius)); - num_confirmed = count_hits(&confirmed); + fn k(&self) -> Option { + Some(self.0) } - let mut knn = SizedHeap::new(Some(k)); - leaf_search(data, confirmed, straddlers, query, U::from(radius)) - .into_iter() - .for_each(|(i, d)| knn.push((d, i))); - knn.items().map(|(d, i)| (i, d)).collect() + fn search(&self, data: &D, metric: &M, root: &C, query: &I) -> Vec<(usize, T)> { + let radii_under_k = root.radii_for_k::(self.0); + let mut radius: f32 = crate::utils::mean(&radii_under_k); + + let [mut confirmed, mut straddlers] = tree_search(data, metric, root, query, T::from(radius)); + let mut num_confirmed = count_hits(&confirmed); + + while num_confirmed < self.0 { + let multiplier = if num_confirmed == 0 { + T::ONE.double().as_f32() + } else { + lfd_multiplier(&confirmed, &straddlers, self.0) + .min(self.1.as_f32()) + .max(f32::ONE + f32::EPSILON) + }; + radius = radius.mul_add(multiplier, T::EPSILON.as_f32()); + [confirmed, straddlers] = tree_search(data, metric, root, query, T::from(radius)); + num_confirmed = count_hits(&confirmed); + } + + let mut knn = SizedHeap::new(Some(self.0)); + for (leaf, d) in confirmed.into_iter().chain(straddlers) { + if knn.len() < self.0 // We don't have enough items yet. OR + // The current farthest hit is farther than the closest + // potential item in the leaf. + || d + leaf.radius() <= knn.peek().map_or(T::MAX, |&(d, _)| d) + { + if leaf.is_singleton() { + knn.extend(leaf.indices().into_iter().map(|i| (d, i))); + } else { + knn.extend(data.query_to_all(metric, query, leaf).map(|(i, d)| (d, i))); + } + } + } + knn.items().map(|(d, i)| (i, d)).collect() + } } -/// Parallel K-Nearest Neighbors search using repeated Clustered RNN search. 
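The loop above grows the radius by a data-driven factor: if `car` points have been captured so far, `k` are wanted, and the visited clusters have mean reciprocal LFD `mu`, then `lfd_multiplier` (defined later in this file) scales the radius by `(k / car)^mu`. A numeric sketch with made-up values:

```rust
/// Stand-in for the LFD-based radius multiplier `(k / car)^mu`.
fn multiplier_sketch(k: usize, car: usize, mu: f32) -> f32 {
    (k as f32 / car as f32).powf(mu)
}

fn main() {
    // With 4 of 16 requested neighbors found and an LFD of 2 (so mu = 0.5),
    // the radius grows by a factor of (16 / 4)^0.5 = 2.
    assert!((multiplier_sketch(16, 4, 0.5) - 2.0).abs() < 1e-6);
}
```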
-pub fn par_search(data: &D, root: &C, query: &I, k: usize, max_multiplier: U) -> Vec<(usize, U)> -where - I: Send + Sync, - U: Number, - D: ParDataset, - C: ParCluster, +impl, M: ParMetric, D: ParSearchable> + ParSearchAlgorithm for KnnRepeatedRnn { - let max_multiplier = max_multiplier.as_f32(); - let mut radius = root.radius().as_f32(); - - let mut multiplier = LFD::multiplier_for_k(root.lfd(), root.cardinality(), k).min(max_multiplier); - radius = radius.mul_add(multiplier, U::EPSILON.as_f32()); - - let [mut confirmed, mut straddlers] = par_tree_search(data, root, query, U::from(radius)); - - let mut num_confirmed = count_hits(&confirmed); - while num_confirmed == 0 { - radius = radius.double(); - [confirmed, straddlers] = par_tree_search(data, root, query, U::from(radius)); - num_confirmed = count_hits(&confirmed); - } + fn par_search(&self, data: &D, metric: &M, root: &C, query: &I) -> Vec<(usize, T)> { + let radii_under_k = root.radii_for_k::(self.0); + let mut radius: f32 = crate::utils::mean(&radii_under_k); + + let [mut confirmed, mut straddlers] = par_tree_search(data, metric, root, query, T::from(radius)); + let mut num_confirmed = count_hits(&confirmed); + + while num_confirmed < self.0 { + let multiplier = if num_confirmed == 0 { + T::ONE.double().as_f32() + } else { + lfd_multiplier(&confirmed, &straddlers, self.0) + .min(self.1.as_f32()) + .max(f32::ONE + f32::EPSILON) + }; + radius = radius.mul_add(multiplier, T::EPSILON.as_f32()); + [confirmed, straddlers] = par_tree_search(data, metric, root, query, T::from(radius)); + num_confirmed = count_hits(&confirmed); + } - while num_confirmed < k { - let (lfd, car) = mean_lfd(&confirmed, &straddlers); - multiplier = LFD::multiplier_for_k(lfd, car, k) - .min(max_multiplier) - .max(f32::ONE + f32::EPSILON); - radius = radius.mul_add(multiplier, U::EPSILON.as_f32()); - [confirmed, straddlers] = par_tree_search(data, root, query, U::from(radius)); - num_confirmed = count_hits(&confirmed); + let mut knn = SizedHeap::new(Some(self.0)); + for (leaf, d) in confirmed.into_iter().chain(straddlers) { + if knn.len() < self.0 // We don't have enough items yet. OR + // The current farthest hit is farther than the closest + // potential item in the leaf. + || d + leaf.radius() <= knn.peek().map_or(T::MAX, |&(d, _)| d) + { + if leaf.is_singleton() { + knn.extend(leaf.indices().into_iter().map(|i| (d, i))); + } else { + knn.par_extend(data.par_query_to_all(metric, query, leaf).map(|(i, d)| (d, i))); + } + } + } + knn.items().map(|(d, i)| (i, d)).collect() } - - let mut knn = SizedHeap::new(Some(k)); - par_leaf_search(data, confirmed, straddlers, query, U::from(radius)) - .into_iter() - .for_each(|(i, d)| knn.push((d, i))); - knn.items().map(|(d, i)| (i, d)).collect() } /// Count the total cardinality of the clusters. -fn count_hits, C: Cluster>(hits: &[(&C, U)]) -> usize { +fn count_hits>(hits: &[(&C, T)]) -> usize { hits.iter().map(|(c, _)| c.cardinality()).sum() } -/// Calculate the weighted mean of the LFDs of the clusters. -fn mean_lfd, C: Cluster>( - confirmed: &[(&C, U)], - straddlers: &[(&C, U)], -) -> (f32, usize) { - let (lfd, car) = confirmed +/// Calculate a multiplier for the radius using the LFDs of the clusters. 
+fn lfd_multiplier>(confirmed: &[(&C, T)], straddlers: &[(&C, T)], k: usize) -> f32 { + let (mu, car) = confirmed .iter() - .map(|&(c, _)| c) - .chain(straddlers.iter().map(|&(c, _)| c)) - .map(|c| (c.lfd(), c.cardinality())) - .fold((0.0, 0), |(lfd, car), (l, c)| (l.mul_add(c.as_f32(), lfd), car + c)); - (lfd / car.as_f32(), car) -} - -#[cfg(test)] -mod tests { - use core::fmt::Debug; - - use distances::Number; - - use crate::{ - adapter::BallAdapter, - cakes::OffBall, - cluster::{Ball, ParCluster, Partition}, - Cluster, Dataset, FlatVec, - }; - - use crate::cakes::tests::{check_search_by_distance, gen_grid_data, gen_line_data}; - - fn check_knn>>( - root: &C, - data: &FlatVec, - query: &I, - k: usize, - max_multiplier: U, - ) -> bool { - let true_hits = data.knn(query, k); - - let pred_hits = super::search(data, root, query, k, max_multiplier); - assert_eq!(pred_hits.len(), true_hits.len(), "Knn search failed: {pred_hits:?}"); - check_search_by_distance(true_hits.clone(), pred_hits, "KnnClustered", data); - - let pred_hits = super::par_search(data, root, query, k, max_multiplier); - assert_eq!( - pred_hits.len(), - true_hits.len(), - "Parallel Knn search failed: {pred_hits:?}" - ); - check_search_by_distance(true_hits, pred_hits, "Par KnnClustered", data); - - true - } - - #[test] - fn line() -> Result<(), String> { - let data = gen_line_data(10)?; - let query = &0; - let max_multiplier = 2; - - let criteria = |c: &Ball<_, _, _>| c.cardinality() > 1; - let seed = Some(42); - - let ball = Ball::new_tree(&data, &criteria, seed); - for k in [1, 4, 8] { - assert!(check_knn(&ball, &data, query, k, max_multiplier)); - } - - let (off_ball, perm_data) = OffBall::from_ball_tree(ball, data); - for k in [1, 4, 8] { - assert!(check_knn(&off_ball, &perm_data, query, k, max_multiplier)); - } - - Ok(()) - } - - #[test] - fn grid() -> Result<(), String> { - let data = gen_grid_data(10)?; - let query = &(0.0, 0.0); - let max_multiplier = 2.0; - - let criteria = |c: &Ball<_, _, _>| c.cardinality() > 1; - let seed = Some(42); - - let ball = Ball::new_tree(&data, &criteria, seed); - for k in [1, 4, 8, 16, 32] { - assert!(check_knn(&ball, &data, query, k, max_multiplier)); - } - - let (off_ball, perm_data) = OffBall::from_ball_tree(ball, data); - for k in [1, 4, 8, 16, 32] { - assert!(check_knn(&off_ball, &perm_data, query, k, max_multiplier)); - } - - Ok(()) - } + .chain(straddlers.iter()) + .map(|&(c, _)| (c.lfd().recip(), c.cardinality())) + .fold((0.0, 0), |(lfd, car), (l, c)| (lfd + l, car + c)); + let mu = mu / (confirmed.len() + straddlers.len()).as_f32(); + (k.as_f32() / car.as_f32()).powf(mu) } diff --git a/crates/abd-clam/src/cakes/search/mod.rs b/crates/abd-clam/src/cakes/search/mod.rs index 501176df8..c7f110526 100644 --- a/crates/abd-clam/src/cakes/search/mod.rs +++ b/crates/abd-clam/src/cakes/search/mod.rs @@ -1,242 +1,132 @@ -//! Entropy scaling search algorithms. +//! Entropy scaling search algorithms and supporting traits. 
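With the enum-based `Algorithm` dispatcher removed below, algorithms become ordinary values implementing the `SearchAlgorithm` trait that this module now defines, so they compose behind one interface. A hypothetical driver, written as if inside this module (it reuses the module's imports and mirrors the trait bounds); `run_all` is illustrative and not part of the patch:

```rust
/// Hypothetical: run several boxed algorithms against one tree and pair each
/// result set with the algorithm's name.
fn run_all<I, T, C, M, D>(
    algs: &[Box<dyn SearchAlgorithm<I, T, C, M, D>>],
    data: &D,
    metric: &M,
    root: &C,
    query: &I,
) -> Vec<(String, Vec<(usize, T)>)>
where
    T: Number,
    C: Cluster<T>,
    M: Metric<I, T>,
    D: Searchable<I, T, C, M>,
{
    algs.iter()
        .map(|alg| (alg.name().to_string(), alg.search(data, metric, root, query)))
        .collect()
}
```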
+ +use distances::Number; +use rayon::prelude::*; + +use crate::{cluster::ParCluster, metric::ParMetric, Cluster, Metric}; + +use super::{ParSearchable, Searchable}; mod knn_breadth_first; mod knn_depth_first; +mod knn_hinted; +mod knn_linear; mod knn_repeated_rnn; mod rnn_clustered; +mod rnn_linear; -use distances::Number; -use rayon::prelude::*; +pub use knn_breadth_first::KnnBreadthFirst; +pub use knn_depth_first::KnnDepthFirst; +pub use knn_hinted::KnnHinted; +pub use knn_linear::KnnLinear; +pub use knn_repeated_rnn::KnnRepeatedRnn; +pub use rnn_clustered::RnnClustered; +pub use rnn_linear::RnnLinear; + +/// Common trait for entropy scaling search algorithms. +#[allow(clippy::module_name_repetitions)] +pub trait SearchAlgorithm, M: Metric, D: Searchable> { + /// Return the name of the search algorithm. + fn name(&self) -> &str; -use crate::{cluster::ParCluster, dataset::ParDataset, Cluster, Dataset}; - -/// The different algorithms that can be used for search. -/// -/// - `RnnClustered` - Ranged Nearest Neighbors search using the tree. -/// - `KnnRepeatedRnn` - K-Nearest Neighbors search using repeated `RnnClustered` searches. -/// - `KnnBreadthFirst` - K-Nearest Neighbors search using a breadth-first sieve. -/// - `KnnDepthFirst` - K-Nearest Neighbors search using a depth-first sieve. -/// -/// See the `CAKES` paper for more information on these algorithms. -#[derive(Clone, Copy)] -#[non_exhaustive] -pub enum Algorithm { - /// Linear RNN search. - RnnLinear(U), - /// Linear KNN search. - KnnLinear(usize), - /// Ranged Nearest Neighbors search using the tree. + /// Get the radius if it is a ranged search algorithm. + fn radius(&self) -> Option; + + /// Get the value of k if it is a k-NN search algorithm. + fn k(&self) -> Option; + + /// Perform a search using the given parameters. /// - /// # Parameters + /// # Arguments /// - /// - `U` - The radius to search within. - RnnClustered(U), - /// K-Nearest Neighbors search using repeated `RnnClustered` searches. + /// * `data` - The dataset to search. + /// * `metric` - The metric to use for distance calculations. + /// * `root` - The root of the tree to search. + /// * `query` - The query to search around. /// - /// # Parameters + /// # Returns /// - /// - `usize` - The number of neighbors to search for. - /// - `U` - The maximum multiplier for the radius when repeating the search. - KnnRepeatedRnn(usize, U), - /// K-Nearest Neighbors search using a breadth-first sieve. - KnnBreadthFirst(usize), - /// K-Nearest Neighbors search using a depth-first sieve. - KnnDepthFirst(usize), + /// A vector of pairs, where each pair contains the index of an item in the + /// dataset and the distance from the query to that item. + fn search(&self, data: &D, metric: &M, root: &C, query: &I) -> Vec<(usize, T)>; + + /// Batched version of `Search::search`. + fn batch_search(&self, data: &D, metric: &M, root: &C, queries: &[I]) -> Vec> { + queries + .iter() + .map(|query| self.search(data, metric, root, query)) + .collect() + } } -impl Algorithm { - /// Perform the search using the algorithm. 
- pub fn search, C: Cluster>(&self, data: &D, root: &C, query: &I) -> Vec<(usize, U)> { - match self { - Self::RnnLinear(radius) => data.rnn(query, *radius), - Self::KnnLinear(k) => data.knn(query, *k), - Self::RnnClustered(radius) => rnn_clustered::search(data, root, query, *radius), - Self::KnnRepeatedRnn(k, max_multiplier) => knn_repeated_rnn::search(data, root, query, *k, *max_multiplier), - Self::KnnBreadthFirst(k) => knn_breadth_first::search(data, root, query, *k), - Self::KnnDepthFirst(k) => knn_depth_first::search(data, root, query, *k), - } +/// Parallel version of [`SearchAlgorithm`](crate::cakes::search::SearchAlgorithm). +pub trait ParSearchAlgorithm< + I: Send + Sync, + T: Number, + C: ParCluster, + M: ParMetric, + D: ParSearchable, +>: SearchAlgorithm + Send + Sync +{ + /// Parallel version of [`SearchAlgorithm::search`](crate::cakes::search::SearchAlgorithm::search). + fn par_search(&self, data: &D, metric: &M, root: &C, query: &I) -> Vec<(usize, T)>; + + /// Parallel version of [`SearchAlgorithm::batch_search`](crate::cakes::search::SearchAlgorithm::batch_search). + fn par_batch_search(&self, data: &D, metric: &M, root: &C, queries: &[I]) -> Vec> { + queries + .par_iter() + .map(|query| self.par_search(data, metric, root, query)) + .collect() } +} - /// Parallel version of the `search` method. - pub fn par_search, C: ParCluster>( - &self, - data: &D, - root: &C, - query: &I, - ) -> Vec<(usize, U)> { - match self { - Self::RnnLinear(radius) => data.par_rnn(query, *radius), - Self::KnnLinear(k) => data.par_knn(query, *k), - Self::RnnClustered(radius) => rnn_clustered::par_search(data, root, query, *radius), - Self::KnnRepeatedRnn(k, max_multiplier) => { - knn_repeated_rnn::par_search(data, root, query, *k, *max_multiplier) - } - Self::KnnBreadthFirst(k) => knn_breadth_first::par_search(data, root, query, *k), - Self::KnnDepthFirst(k) => knn_depth_first::par_search(data, root, query, *k), - } +/// A blanket implementation of `SearchAlgorithm` for `Box`. +impl, M: Metric, D: Searchable> SearchAlgorithm + for Box> +{ + fn name(&self) -> &str { + self.as_ref().name() } - /// Batched version of the `search` method. - pub fn batch_search, C: Cluster>( - &self, - data: &D, - root: &C, - queries: &[I], - ) -> Vec> { - match self { - Self::RnnLinear(radius) => queries.iter().map(|query| data.rnn(query, *radius)).collect(), - Self::KnnLinear(k) => queries.iter().map(|query| data.knn(query, *k)).collect(), - Self::RnnClustered(radius) => queries - .iter() - .map(|query| rnn_clustered::search(data, root, query, *radius)) - .collect(), - Self::KnnRepeatedRnn(k, max_multiplier) => queries - .iter() - .map(|query| knn_repeated_rnn::search(data, root, query, *k, *max_multiplier)) - .collect(), - Self::KnnBreadthFirst(k) => queries - .iter() - .map(|query| knn_breadth_first::search(data, root, query, *k)) - .collect(), - Self::KnnDepthFirst(k) => queries - .iter() - .map(|query| knn_depth_first::search(data, root, query, *k)) - .collect(), - } + fn radius(&self) -> Option { + self.as_ref().radius() } - /// Parallel version of the `batch_search` method. 
- pub fn par_batch_search, C: ParCluster>( - &self, - data: &D, - root: &C, - queries: &[I], - ) -> Vec> { - match self { - Self::RnnLinear(radius) => queries.par_iter().map(|query| data.rnn(query, *radius)).collect(), - Self::KnnLinear(k) => queries.par_iter().map(|query| data.knn(query, *k)).collect(), - Self::RnnClustered(radius) => queries - .par_iter() - .map(|query| rnn_clustered::search(data, root, query, *radius)) - .collect(), - Self::KnnRepeatedRnn(k, max_multiplier) => queries - .par_iter() - .map(|query| knn_repeated_rnn::search(data, root, query, *k, *max_multiplier)) - .collect(), - Self::KnnBreadthFirst(k) => queries - .par_iter() - .map(|query| knn_breadth_first::search(data, root, query, *k)) - .collect(), - Self::KnnDepthFirst(k) => queries - .par_iter() - .map(|query| knn_depth_first::search(data, root, query, *k)) - .collect(), - } + fn k(&self) -> Option { + self.as_ref().k() } - /// Batched version of the `par_search` method. - pub fn batch_par_search, C: ParCluster>( - &self, - data: &D, - root: &C, - queries: &[I], - ) -> Vec> { - match self { - Self::RnnLinear(radius) => queries.iter().map(|query| data.par_rnn(query, *radius)).collect(), - Self::KnnLinear(k) => queries.iter().map(|query| data.par_knn(query, *k)).collect(), - Self::RnnClustered(radius) => queries - .iter() - .map(|query| rnn_clustered::par_search(data, root, query, *radius)) - .collect(), - Self::KnnRepeatedRnn(k, max_multiplier) => queries - .iter() - .map(|query| knn_repeated_rnn::par_search(data, root, query, *k, *max_multiplier)) - .collect(), - Self::KnnBreadthFirst(k) => queries - .iter() - .map(|query| knn_breadth_first::par_search(data, root, query, *k)) - .collect(), - Self::KnnDepthFirst(k) => queries - .iter() - .map(|query| knn_depth_first::par_search(data, root, query, *k)) - .collect(), - } + fn search(&self, data: &D, metric: &M, root: &C, query: &I) -> Vec<(usize, T)> { + self.as_ref().search(data, metric, root, query) } +} - /// Parallel version of the `batch_par_search` method. - pub fn par_batch_par_search, C: ParCluster>( - &self, - data: &D, - root: &C, - queries: &[I], - ) -> Vec> { - match self { - Self::RnnLinear(radius) => queries.par_iter().map(|query| data.par_rnn(query, *radius)).collect(), - Self::KnnLinear(k) => queries.par_iter().map(|query| data.par_knn(query, *k)).collect(), - Self::RnnClustered(radius) => queries - .par_iter() - .map(|query| rnn_clustered::par_search(data, root, query, *radius)) - .collect(), - Self::KnnRepeatedRnn(k, max_multiplier) => queries - .par_iter() - .map(|query| knn_repeated_rnn::par_search(data, root, query, *k, *max_multiplier)) - .collect(), - Self::KnnBreadthFirst(k) => queries - .par_iter() - .map(|query| knn_breadth_first::par_search(data, root, query, *k)) - .collect(), - Self::KnnDepthFirst(k) => queries - .par_iter() - .map(|query| knn_depth_first::par_search(data, root, query, *k)) - .collect(), - } +/// A blanket implementation of `SearchAlgorithm` for `Box`. +impl, M: ParMetric, D: ParSearchable> + SearchAlgorithm for Box> +{ + fn name(&self) -> &str { + self.as_ref().name() } - /// Get the name of the algorithm. 
- pub fn name(&self) -> String { - match self { - Self::RnnLinear(r) => format!("RnnLinear({r})"), - Self::KnnLinear(k) => format!("KnnLinear({k})"), - Self::RnnClustered(r) => format!("RnnClustered({r})"), - Self::KnnRepeatedRnn(k, m) => format!("KnnRepeatedRnn({k}, {m})"), - Self::KnnBreadthFirst(k) => format!("KnnBreadthFirst({k})"), - Self::KnnDepthFirst(k) => format!("KnnDepthFirst({k})"), - } + fn radius(&self) -> Option { + self.as_ref().radius() } - /// Get the name of the variant of algorithm. - pub const fn variant_name(&self) -> &str { - match self { - Self::RnnLinear(_) => "RnnLinear", - Self::KnnLinear(_) => "KnnLinear", - Self::RnnClustered(_) => "RnnClustered", - Self::KnnRepeatedRnn(_, _) => "KnnRepeatedRnn", - Self::KnnBreadthFirst(_) => "KnnBreadthFirst", - Self::KnnDepthFirst(_) => "KnnDepthFirst", - } + fn k(&self) -> Option { + self.as_ref().k() } - /// Same variant of the algorithm with different parameters. - #[must_use] - pub const fn with_params(&self, radius: U, k: usize) -> Self { - match self { - Self::RnnLinear(_) => Self::RnnLinear(radius), - Self::KnnLinear(_) => Self::KnnLinear(k), - Self::RnnClustered(_) => Self::RnnClustered(radius), - Self::KnnRepeatedRnn(_, m) => Self::KnnRepeatedRnn(k, *m), - Self::KnnBreadthFirst(_) => Self::KnnBreadthFirst(k), - Self::KnnDepthFirst(_) => Self::KnnDepthFirst(k), - } + fn search(&self, data: &D, metric: &M, root: &C, query: &I) -> Vec<(usize, T)> { + self.as_ref().search(data, metric, root, query) } +} - /// Returns the linear-search variant of the algorithm. - #[must_use] - pub const fn linear_variant(&self) -> Self { - match self { - Self::RnnClustered(r) | Self::RnnLinear(r) => Self::RnnLinear(*r), - Self::KnnBreadthFirst(k) | Self::KnnDepthFirst(k) | Self::KnnRepeatedRnn(k, _) | Self::KnnLinear(k) => { - Self::KnnLinear(*k) - } - } +/// A blanket implementation of `ParSearchAlgorithm` for `Box`. +impl, M: ParMetric, D: ParSearchable> + ParSearchAlgorithm for Box> +{ + fn par_search(&self, data: &D, metric: &M, root: &C, query: &I) -> Vec<(usize, T)> { + self.as_ref().par_search(data, metric, root, query) } } diff --git a/crates/abd-clam/src/cakes/search/rnn_clustered.rs b/crates/abd-clam/src/cakes/search/rnn_clustered.rs index 0d436f8e7..3f50266fa 100644 --- a/crates/abd-clam/src/cakes/search/rnn_clustered.rs +++ b/crates/abd-clam/src/cakes/search/rnn_clustered.rs @@ -1,31 +1,49 @@ -//! Ranged Nearest Neighbors search using the tree. +//! Ranged Nearest Neighbors search using a tree, as described in the CHESS +//! paper. use distances::Number; use rayon::prelude::*; -use crate::{cluster::ParCluster, dataset::ParDataset, Cluster, Dataset}; +use crate::{ + cakes::{ParSearchable, Searchable}, + cluster::ParCluster, + metric::ParMetric, + Cluster, Metric, +}; -/// Clustered search for the ranged nearest neighbors of a query. -pub fn search(data: &D, root: &C, query: &I, radius: U) -> Vec<(usize, U)> -where - U: Number, - D: Dataset, - C: Cluster, +use super::{ParSearchAlgorithm, RnnLinear, SearchAlgorithm}; + +/// Ranged Nearest Neighbors search using a tree. 
+pub struct RnnClustered(pub T); + +impl, M: Metric, D: Searchable> SearchAlgorithm + for RnnClustered { - let [confirmed, straddlers] = tree_search(data, root, query, radius); - leaf_search(data, confirmed, straddlers, query, radius) + fn name(&self) -> &str { + "RnnClustered" + } + + fn radius(&self) -> Option { + Some(self.0) + } + + fn k(&self) -> Option { + None + } + + fn search(&self, data: &D, metric: &M, root: &C, query: &I) -> Vec<(usize, T)> { + let [confirmed, straddlers] = tree_search(data, metric, root, query, self.0); + leaf_search(data, metric, confirmed, straddlers, query, self.0) + } } -/// Parallel clustered search for the ranged nearest neighbors of a query. -pub fn par_search(data: &D, root: &C, query: &I, radius: U) -> Vec<(usize, U)> -where - I: Send + Sync, - U: Number, - D: ParDataset, - C: ParCluster, +impl, M: ParMetric, D: ParSearchable> + ParSearchAlgorithm for RnnClustered { - let [confirmed, straddlers] = par_tree_search(data, root, query, radius); - par_leaf_search(data, confirmed, straddlers, query, radius) + fn par_search(&self, data: &D, metric: &M, root: &C, query: &I) -> Vec<(usize, T)> { + let [confirmed, straddlers] = par_tree_search(data, metric, root, query, self.0); + par_leaf_search(data, metric, confirmed, straddlers, query, self.0) + } } /// Perform coarse-grained tree search. @@ -44,59 +62,82 @@ where /// query ball, and the second element is the straddlers, i.e. those that /// overlap the query ball. The 2-tuples are the clusters and the distance /// from the query to the cluster center. -pub fn tree_search<'a, I, U, D, C>(data: &D, root: &'a C, query: &I, radius: U) -> [Vec<(&'a C, U)>; 2] +#[inline(never)] +pub fn tree_search<'a, I, T, C, M, D>(data: &D, metric: &M, root: &'a C, query: &I, radius: T) -> [Vec<(&'a C, T)>; 2] where - U: Number + 'a, - D: Dataset, - C: Cluster, + T: Number + 'a, + C: Cluster, + M: Metric, + D: Searchable, { + let (overlap, distances) = root.overlaps_with(data, metric, query, radius); + if !overlap { + return [Vec::new(), Vec::new()]; + } + let mut confirmed = Vec::new(); let mut straddlers = Vec::new(); - let mut candidates = vec![root]; + let mut candidates = vec![(root, distances[0])]; - let (mut terminal, mut non_terminal): (Vec<_>, Vec<_>); while !candidates.is_empty() { - (terminal, non_terminal) = candidates - .into_iter() - .map(|c| (c, c.distance_to_center(data, query))) - .filter(|&(c, d)| d <= (c.radius() + radius)) - .partition(|&(c, d)| (c.radius() + d) <= radius); - confirmed.append(&mut terminal); - - (terminal, non_terminal) = non_terminal.into_iter().partition(|&(c, _)| c.is_leaf()); - straddlers.append(&mut terminal); - - candidates = non_terminal.into_iter().flat_map(|(c, _)| c.child_clusters()).collect(); + candidates = candidates.into_iter().fold(Vec::new(), |mut next_candidates, (c, d)| { + if (c.radius() + d) <= radius { + confirmed.push((c, d)); + } else if c.is_leaf() { + straddlers.push((c, d)); + } else { + next_candidates.extend( + c.overlapping_children(data, metric, query, radius) + .into_iter() + .map(|(c, ds)| (c, ds[0])), + ); + } + next_candidates + }); } [confirmed, straddlers] } -/// Parallelized version of the tree search. -pub fn par_tree_search<'a, I, U, D, C>(data: &D, root: &'a C, query: &I, radius: U) -> [Vec<(&'a C, U)>; 2] +/// Parallel version of [`tree_search`](crate::cakes::search::rnn_clustered::tree_search). 
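The triage in `tree_search` above is pure triangle-inequality reasoning: with `d` the query-to-center distance, `r` the cluster radius, and `q` the query radius, a cluster lies wholly inside the query ball when `d + r <= q`, and can contain hits at all only when `d <= q + r`. A self-contained restatement over plain floats (in the patch, a non-leaf straddler recurses into its overlapping children rather than being kept):

```rust
/// How a cluster is triaged during ranged tree search.
enum Triage {
    /// Every point of the cluster is within the query radius.
    Confirmed,
    /// The cluster overlaps the query ball, so its points need checking.
    Straddler,
    /// The cluster cannot contain any hit and is discarded.
    Pruned,
}

fn triage(d: f32, r: f32, q: f32) -> Triage {
    if d + r <= q {
        Triage::Confirmed
    } else if d <= q + r {
        Triage::Straddler
    } else {
        Triage::Pruned
    }
}

fn main() {
    assert!(matches!(triage(1.0, 1.0, 3.0), Triage::Confirmed));
    assert!(matches!(triage(5.0, 1.0, 3.0), Triage::Pruned));
}
```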
+pub fn par_tree_search<'a, I, T, C, M, D>( + data: &D, + metric: &M, + root: &'a C, + query: &I, + radius: T, +) -> [Vec<(&'a C, T)>; 2] where I: Send + Sync, - U: Number + 'a, - D: ParDataset, - C: ParCluster, + T: Number + 'a, + C: ParCluster, + M: ParMetric, + D: ParSearchable, { + let (overlap, distances) = root.par_overlaps_with(data, metric, query, radius); + if !overlap { + return [Vec::new(), Vec::new()]; + } + let mut confirmed = Vec::new(); let mut straddlers = Vec::new(); - let mut candidates = vec![root]; + let mut candidates = vec![(root, distances[0])]; - let (mut terminal, mut non_terminal): (Vec<_>, Vec<_>); while !candidates.is_empty() { - (terminal, non_terminal) = candidates - .into_par_iter() - .map(|c| (c, c.distance_to_center(data, query))) - .filter(|&(c, d)| d <= (c.radius() + radius)) - .partition(|&(c, d)| (c.radius() + d) < radius); - confirmed.append(&mut terminal); - - (terminal, non_terminal) = non_terminal.into_iter().partition(|&(c, _)| c.is_leaf()); - straddlers.append(&mut terminal); - - candidates = non_terminal.into_iter().flat_map(|(c, _)| c.child_clusters()).collect(); + candidates = candidates.into_iter().fold(Vec::new(), |mut next_candidates, (c, d)| { + if (c.radius() + d) <= radius { + confirmed.push((c, d)); + } else if c.is_leaf() { + straddlers.push((c, d)); + } else { + next_candidates.extend( + c.par_overlapping_children(data, metric, query, radius) + .into_iter() + .map(|(c, ds)| (c, ds[0])), + ); + } + next_candidates + }); } [confirmed, straddlers] @@ -118,140 +159,67 @@ where /// # Returns /// /// The `(index, distance)` pairs of the points within the query ball. -pub fn leaf_search( +#[inline(never)] +pub fn leaf_search( data: &D, - confirmed: Vec<(&C, U)>, - straddlers: Vec<(&C, U)>, + metric: &M, + confirmed: Vec<(&C, T)>, + straddlers: Vec<(&C, T)>, query: &I, - radius: U, -) -> Vec<(usize, U)> + radius: T, +) -> Vec<(usize, T)> where - U: Number, - D: Dataset, - C: Cluster, + T: Number, + C: Cluster, + M: Metric, + D: Searchable, { - let hits = confirmed.into_iter().flat_map(|(c, d)| { - if c.is_singleton() { - c.indices().map(|i| (i, d)).collect() - } else { - c.distances_to_query(data, query) - } - }); - - let distances = straddlers + confirmed .into_iter() - .flat_map(|(c, _)| c.distances_to_query(data, query)) - .filter(|&(_, d)| d <= radius); - - hits.chain(distances).collect() + .flat_map(|(c, d)| { + if c.is_singleton() { + c.indices().into_iter().map(|i| (i, d)).collect::>() + } else { + data.query_to_all(metric, query, c).collect() + } + }) + .chain( + straddlers + .into_iter() + .flat_map(|(c, _)| RnnLinear(radius).search(data, metric, c, query)), + ) + .collect() } -/// Parallelized version of the leaf search. -pub fn par_leaf_search( +/// Parallel version of [`leaf_search`](crate::cakes::search::rnn_clustered::leaf_search). 
+pub fn par_leaf_search( data: &D, - confirmed: Vec<(&C, U)>, - straddlers: Vec<(&C, U)>, + metric: &M, + confirmed: Vec<(&C, T)>, + straddlers: Vec<(&C, T)>, query: &I, - radius: U, -) -> Vec<(usize, U)> + radius: T, +) -> Vec<(usize, T)> where I: Send + Sync, - U: Number, - D: ParDataset, - C: ParCluster, + T: Number, + C: ParCluster, + M: ParMetric, + D: ParSearchable, { - let hits = confirmed.into_par_iter().flat_map(|(c, d)| { - if c.is_singleton() { - c.indices().map(|i| (i, d)).collect() - } else { - c.par_distances_to_query(data, query) - } - }); - - let distances = straddlers + confirmed .into_par_iter() - .flat_map(|(c, _)| c.par_distances_to_query(data, query)) - .filter(|&(_, d)| d <= radius); - - hits.chain(distances).collect() -} - -#[cfg(test)] -mod tests { - use core::fmt::Debug; - - use distances::Number; - - use crate::Dataset; - use crate::{ - adapter::BallAdapter, cakes::OffBall, cluster::ParCluster, partition::ParPartition, Ball, Cluster, FlatVec, - Partition, - }; - - use crate::cakes::tests::{check_search_by_index, gen_grid_data, gen_line_data}; - - pub fn check_rnn>>( - root: &C, - data: &FlatVec, - query: &I, - radius: U, - ) -> bool { - let true_hits = data.rnn(query, radius); - - let pred_hits = super::search(data, root, query, radius); - assert_eq!(pred_hits.len(), true_hits.len(), "Rnn search failed: {pred_hits:?}"); - check_search_by_index(true_hits.clone(), pred_hits, "RnnClustered", data); - - let pred_hits = super::par_search(data, root, query, radius); - assert_eq!( - pred_hits.len(), - true_hits.len(), - "Parallel Rnn search failed: {pred_hits:?}" - ); - check_search_by_index(true_hits, pred_hits, "Par RnnClustered", data); - - true - } - - #[test] - fn line() -> Result<(), String> { - let data = gen_line_data(10)?; - let query = &0; - - let criteria = |c: &Ball<_, _, _>| c.cardinality() > 1; - let seed = Some(42); - - let ball = Ball::new_tree(&data, &criteria, seed); - for radius in 0..=4 { - assert!(check_rnn(&ball, &data, &query, radius)); - } - - let (off_ball, perm_data) = OffBall::from_ball_tree(ball, data); - for radius in 0..=4 { - assert!(check_rnn(&off_ball, &perm_data, &query, radius)); - } - - Ok(()) - } - - #[test] - fn grid() -> Result<(), String> { - let data = gen_grid_data(10)?; - let query = &(0.0, 0.0); - - let criteria = |c: &Ball<_, _, _>| c.cardinality() > 1; - let seed = Some(42); - - let ball = Ball::par_new_tree(&data, &criteria, seed); - for radius in [1.0, 4.0, 8.0, 16.0, 32.0] { - assert!(check_rnn(&ball, &data, &query, radius)); - } - - let (off_ball, perm_data) = OffBall::from_ball_tree(ball, data); - for radius in [1.0, 4.0, 8.0, 16.0, 32.0] { - assert!(check_rnn(&off_ball, &perm_data, &query, radius)); - } - - Ok(()) - } + .flat_map(|(c, d)| { + if c.is_singleton() { + c.indices().into_iter().map(|i| (i, d)).collect::>() + } else { + data.par_query_to_all(metric, query, c).collect() + } + }) + .chain( + straddlers + .into_par_iter() + .flat_map(|(c, _)| RnnLinear(radius).par_search(data, metric, c, query)), + ) + .collect() } diff --git a/crates/abd-clam/src/cakes/search/rnn_linear.rs b/crates/abd-clam/src/cakes/search/rnn_linear.rs new file mode 100644 index 000000000..da2789052 --- /dev/null +++ b/crates/abd-clam/src/cakes/search/rnn_linear.rs @@ -0,0 +1,48 @@ +//! Ranged nearest neighbor search using a linear scan of the dataset. 
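The new file below is the ground-truth counterpart of `RnnClustered`: with the crate's traits stripped away, ranged search is nothing more than a filter over query-to-point distances, which is what makes it useful as a correctness baseline. A minimal restatement:

```rust
/// What a linear ranged search computes, given (index, distance) pairs.
fn rnn_linear_sketch<T: PartialOrd + Copy>(
    distances: impl Iterator<Item = (usize, T)>,
    radius: T,
) -> Vec<(usize, T)> {
    distances.filter(|&(_, d)| d <= radius).collect()
}

fn main() {
    let ds = vec![(0, 1.5_f32), (1, 0.2), (2, 3.0)];
    assert_eq!(rnn_linear_sketch(ds.into_iter(), 1.5), vec![(0, 1.5), (1, 0.2)]);
}
```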
+ +use distances::Number; +use rayon::prelude::*; + +use crate::{ + cakes::{ParSearchable, Searchable}, + cluster::ParCluster, + metric::ParMetric, + Cluster, Metric, +}; + +use super::{ParSearchAlgorithm, SearchAlgorithm}; + +/// Ranged nearest neighbor search using a linear scan of the dataset. +pub struct RnnLinear(pub T); + +impl, M: Metric, D: Searchable> SearchAlgorithm + for RnnLinear +{ + fn name(&self) -> &str { + "RnnLinear" + } + + fn radius(&self) -> Option { + Some(self.0) + } + + fn k(&self) -> Option { + None + } + + fn search(&self, data: &D, metric: &M, root: &C, query: &I) -> Vec<(usize, T)> { + data.query_to_all(metric, query, root) + .filter(|&(_, d)| d <= self.0) + .collect() + } +} + +impl, M: ParMetric, D: ParSearchable> + ParSearchAlgorithm for RnnLinear +{ + fn par_search(&self, data: &D, metric: &M, root: &C, query: &I) -> Vec<(usize, T)> { + data.par_query_to_all(metric, query, root) + .filter(|&(_, d)| d <= self.0) + .collect() + } +} diff --git a/crates/abd-clam/src/chaoda/cluster/vertex.rs b/crates/abd-clam/src/chaoda/cluster/vertex.rs index 24c61919d..a585cfe7c 100644 --- a/crates/abd-clam/src/chaoda/cluster/vertex.rs +++ b/crates/abd-clam/src/chaoda/cluster/vertex.rs @@ -1,18 +1,18 @@ //! The `Vertex` is a `Cluster` adapter used to represent `Cluster`s in the //! `Graph`. -use core::marker::PhantomData; - use distances::Number; use rayon::prelude::*; -use serde::{Deserialize, Serialize}; use crate::{ - adapter::{Adapter, BallAdapter, ParAdapter, ParBallAdapter, ParParams, Params}, chaoda::NUM_RATIOS, - cluster::ParCluster, + cluster::{ + adapter::{Adapter, BallAdapter, ParAdapter, ParBallAdapter, ParParams, Params}, + ParCluster, + }, dataset::ParDataset, - Ball, Cluster, Dataset, + metric::ParMetric, + Ball, Cluster, Dataset, Metric, }; /// The `Vertex` is a `Cluster` adapter used to represent `Cluster`s in the @@ -20,105 +20,81 @@ use crate::{ /// /// # Type Parameters /// -/// - `I`: The type on instances in the dataset. -/// - `U`: The type of the distance values. -/// - `D`: The type of the dataset. +/// - `T`: The type of the distance values. /// - `S`: The type of the `Cluster` that was adapted into the `Vertex`. -#[derive(Serialize, Deserialize)] -pub struct Vertex, S: Cluster> { +#[derive(Clone)] +#[cfg_attr( + feature = "disk-io", + derive(bitcode::Encode, bitcode::Decode, serde::Serialize, serde::Deserialize) +)] +#[cfg_attr(feature = "disk-io", bitcode(recursive))] +pub struct Vertex> { /// The `Cluster` that was adapted into the `Vertex`. - source: S, + pub(crate) source: S, /// The children of the `Vertex`. - children: Vec<(usize, U, Box)>, + children: Vec>, /// The anomaly detection properties of the `Vertex`. params: Ratios, - /// Phantom data to satisfy the compiler. - _id: PhantomData<(I, D)>, + /// Ghosts in the machine. 
+ phantom: core::marker::PhantomData, } -impl, S: Cluster + std::fmt::Debug> std::fmt::Debug for Vertex { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { +impl + core::fmt::Debug> core::fmt::Debug for Vertex { + fn fmt(&self, f: &mut core::fmt::Formatter) -> std::fmt::Result { f.debug_struct("Vertex") .field("source", &self.source) - .field("children", &self.children) .field("ratios", &self.params.ratios) .field("ema_ratios", &self.params.ema_ratios) .field("accumulated_cp_car_ratio", &self.params.accumulated_cp_car_ratio) + .field("children", &!self.children.is_empty()) .finish() } } -impl, S: Cluster> Vertex { - /// Returns the anomaly detection properties of the `Vertex` and their - /// exponential moving averages. - #[must_use] - pub const fn ratios(&self) -> [f32; NUM_RATIOS] { - let [c, r, l] = self.params.ratios; - let [c_, r_, l_] = self.params.ema_ratios; - [c, r, l, c_, r_, l_] - } - - /// Returns the accumulated child-parent cardinality ratio. - #[must_use] - pub const fn accumulated_cp_car_ratio(&self) -> f32 { - self.params.accumulated_cp_car_ratio +impl> PartialEq for Vertex { + fn eq(&self, other: &Self) -> bool { + self.source == other.source } } -impl> BallAdapter for Vertex> { - /// Creates a new `OffsetBall` tree from a `Ball` tree. - fn from_ball_tree(ball: Ball, data: D) -> (Self, D) { - let root = Self::adapt_tree(ball, None); - (root, data) - } -} +impl> Eq for Vertex {} -impl> ParBallAdapter - for Vertex> -{ - /// Creates a new `OffsetBall` tree from a `Ball` tree. - fn par_from_ball_tree(ball: Ball, data: D) -> (Self, D) { - let root = Self::par_adapt_tree(ball, None); - (root, data) +impl> PartialOrd for Vertex { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) } } -impl, S: Cluster> Adapter for Vertex { - fn new_adapted(source: S, children: Vec<(usize, U, Box)>, params: Ratios) -> Self { - Self { - source, - children, - params, - _id: PhantomData, - } - } - - fn post_traversal(&mut self) {} - - fn source(&self) -> &S { - &self.source +impl> Ord for Vertex { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.source.cmp(&other.source) } +} - fn source_mut(&mut self) -> &mut S { - &mut self.source +impl> core::hash::Hash for Vertex { + fn hash(&self, state: &mut H) { + self.source.hash(state); } +} - fn take_source(self) -> S { - self.source +impl> Vertex { + /// Returns the anomaly detection properties of the `Vertex` and their + /// exponential moving averages. + #[must_use] + pub const fn ratios(&self) -> [f32; NUM_RATIOS] { + let [c, r, l] = self.params.ratios; + let [c_, r_, l_] = self.params.ema_ratios; + [c, r, l, c_, r_, l_] } - fn params(&self) -> &Ratios { - &self.params + /// Returns the accumulated child-parent cardinality ratio. 
+ #[must_use] + pub const fn accumulated_cp_car_ratio(&self) -> f32 { + self.params.accumulated_cp_car_ratio } } -impl, S: ParCluster> ParAdapter - for Vertex -{ - fn par_post_traversal(&mut self) {} -} - -impl, S: Cluster> Cluster for Vertex { +impl> Cluster for Vertex { fn depth(&self) -> usize { self.source.depth() } @@ -135,7 +111,7 @@ impl, S: Cluster> Cluster for V self.source.set_arg_center(arg_center); } - fn radius(&self) -> U { + fn radius(&self) -> T { self.source.radius() } @@ -151,75 +127,69 @@ impl, S: Cluster> Cluster for V self.source.lfd() } - fn indices(&self) -> impl Iterator + '_ { - self.source.indices() + fn contains(&self, index: usize) -> bool { + self.source.contains(index) } - fn set_indices(&mut self, indices: Vec) { - self.source.set_indices(indices); + fn indices(&self) -> Vec { + self.source.indices() } - fn children(&self) -> &[(usize, U, Box)] { - &self.children + fn set_indices(&mut self, indices: &[usize]) { + self.source.set_indices(indices); } - fn children_mut(&mut self) -> &mut [(usize, U, Box)] { - &mut self.children + fn extents(&self) -> &[(usize, T)] { + self.source.extents() } - fn set_children(&mut self, children: Vec<(usize, U, Box)>) { - self.children = children; + fn extents_mut(&mut self) -> &mut [(usize, T)] { + self.source.extents_mut() } - fn take_children(&mut self) -> Vec<(usize, U, Box)> { - core::mem::take(&mut self.children) + fn add_extent(&mut self, idx: usize, extent: T) { + self.source.add_extent(idx, extent); } - fn distances_to_query(&self, data: &D, query: &I) -> Vec<(usize, U)> { - self.source.distances_to_query(data, query) + fn take_extents(&mut self) -> Vec<(usize, T)> { + self.source.take_extents() } - fn is_descendant_of(&self, other: &Self) -> bool { - self.source.is_descendant_of(&other.source) + fn children(&self) -> Vec<&Self> { + self.children.iter().map(AsRef::as_ref).collect() } -} -impl, S: ParCluster> ParCluster for Vertex { - fn par_distances_to_query(&self, data: &D, query: &I) -> Vec<(usize, U)> { - self.source.par_distances_to_query(data, query) + fn children_mut(&mut self) -> Vec<&mut Self> { + self.children.iter_mut().map(AsMut::as_mut).collect() } -} -impl, S: Cluster> PartialEq for Vertex { - fn eq(&self, other: &Self) -> bool { - self.source == other.source + fn set_children(&mut self, children: Vec>) { + self.children = children; } -} -impl, S: Cluster> Eq for Vertex {} - -impl, S: Cluster> PartialOrd for Vertex { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) + fn take_children(&mut self) -> Vec> { + core::mem::take(&mut self.children) } -} -impl, S: Cluster> Ord for Vertex { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - self.source.cmp(&other.source) + fn is_descendant_of(&self, other: &Self) -> bool { + self.source.is_descendant_of(&other.source) } } -impl, S: Cluster> std::hash::Hash for Vertex { - fn hash(&self, state: &mut H) { - self.source.hash(state); +impl> ParCluster for Vertex { + fn par_indices(&self) -> impl ParallelIterator { + self.source.par_indices() } } /// The anomaly detection properties of the `Vertex`, their exponential moving /// averages, and the accumulated child-parent cardinality ratio. #[allow(clippy::struct_field_names)] -#[derive(Serialize, Deserialize)] +#[derive(Clone)] +#[cfg_attr( + feature = "disk-io", + derive(bitcode::Encode, bitcode::Decode, serde::Serialize, serde::Deserialize) +)] pub struct Ratios { /// The anomaly detection properties of the `Vertex`. 
ratios: [f32; 3], @@ -239,14 +209,14 @@ impl Default for Ratios { } } -impl, S: Cluster> Params for Ratios { - fn child_params(&self, children: &[S]) -> Vec { +impl, S: Cluster> Params for Ratios { + fn child_params>(&self, children: &[S], _: &D, _: &M) -> Vec { children.iter().map(|child| child_params(self, child)).collect() } } -impl, S: ParCluster> ParParams for Ratios { - fn par_child_params(&self, children: &[S]) -> Vec { +impl, S: ParCluster> ParParams for Ratios { + fn par_child_params>(&self, children: &[S], _: &D, _: &M) -> Vec { children.par_iter().map(|child| child_params(self, child)).collect() } } @@ -254,7 +224,7 @@ impl, S: ParCluster> Par /// Computes the anomaly detection properties of a child `Cluster` given the /// anomaly detection properties of the parent `Cluster`. #[allow(clippy::similar_names)] -fn child_params, C: Cluster>(parent: &Ratios, child: &C) -> Ratios { +fn child_params>(parent: &Ratios, child: &C) -> Ratios { let [pc, pr, pl] = parent.ratios; let c = child.cardinality().as_f32() / pc; let r = child.radius().as_f32() / pr; @@ -274,3 +244,60 @@ fn child_params, C: Cluster>(parent: &Ra accumulated_cp_car_ratio, } } + +impl> BallAdapter for Vertex> { + /// Creates a new `OffsetBall` tree from a `Ball` tree. + fn from_ball_tree>(ball: Ball, data: D, metric: &M) -> (Self, D) { + let root = Self::adapt_tree(ball, None, &data, metric); + (root, data) + } +} + +impl> ParBallAdapter for Vertex> { + /// Creates a new `OffsetBall` tree from a `Ball` tree. + fn par_from_ball_tree>(ball: Ball, data: D, metric: &M) -> (Self, D) { + let root = Self::par_adapt_tree(ball, None, &data, metric); + (root, data) + } +} + +impl, S: Cluster> Adapter for Vertex { + fn new_adapted>(source: S, children: Vec>, params: Ratios, _: &D, _: &M) -> Self { + Self { + source, + params, + children, + phantom: core::marker::PhantomData, + } + } + + fn post_traversal(&mut self) {} + + fn source(&self) -> &S { + &self.source + } + + fn source_mut(&mut self) -> &mut S { + &mut self.source + } + + fn take_source(self) -> S { + self.source + } + + fn params(&self) -> &Ratios { + &self.params + } +} + +impl, S: ParCluster> ParAdapter for Vertex { + fn par_new_adapted>( + source: S, + children: Vec>, + params: Ratios, + data: &D, + metric: &M, + ) -> Self { + Self::new_adapted(source, children, params, data, metric) + } +} diff --git a/crates/abd-clam/src/chaoda/graph/adjacency_list.rs b/crates/abd-clam/src/chaoda/graph/adjacency_list.rs index 400dbf1df..878862ae2 100644 --- a/crates/abd-clam/src/chaoda/graph/adjacency_list.rs +++ b/crates/abd-clam/src/chaoda/graph/adjacency_list.rs @@ -6,25 +6,20 @@ use std::collections::{HashMap, HashSet}; use distances::Number; use rayon::prelude::*; -use crate::{cluster::ParCluster, dataset::ParDataset, Cluster, Dataset}; +use crate::{cluster::ParCluster, dataset::ParDataset, metric::ParMetric, Cluster, Dataset, Metric}; /// An `AdjacencyList` is a map from a `Cluster` to a map from each neighbor to /// the distance between them. -pub struct AdjacencyList<'a, I, U: Number, D: Dataset, C: Cluster> { - /// The inner `HashMap` of the `AdjacencyList`. - inner: HashMap<&'a C, HashMap<&'a C, U>>, - /// Phantom data to keep track of the types. 
- _phantom: std::marker::PhantomData<(I, D)>, -} +pub struct AdjacencyList<'a, T: Number, C: Cluster>(HashMap<&'a C, HashMap<&'a C, T>>); -impl<'a, I, U: Number, D: Dataset, C: Cluster> AdjacencyList<'a, I, U, D, C> { +impl<'a, T: Number, C: Cluster> AdjacencyList<'a, T, C> { /// Create new `AdjacencyList`s for each `Component` in a `Graph`. /// /// # Arguments /// /// * data: The `Dataset` that the `Cluster`s are based on. /// * clusters: The `Cluster`s to create the `AdjacencyList` from. - pub fn new(clusters: &[&'a C], data: &D) -> Vec { + pub fn new, M: Metric>(clusters: &[&'a C], data: &D, metric: &M) -> Vec { // TODO: This is a naive implementation of creating an adjacency list. // We can improve this by using the search functionality from CAKES. let inner = clusters @@ -37,7 +32,7 @@ impl<'a, I, U: Number, D: Dataset, C: Cluster> AdjacencyList<'a, .filter(|&(j, _)| i != j) .filter_map(|(_, &v)| { let (ru, rv) = (u.radius(), v.radius()); - let d = u.distance_to_other(data, v); + let d = data.one_to_one(u.arg_center(), v.arg_center(), metric); if d <= ru + rv { Some((v, d)) } else { @@ -49,14 +44,10 @@ impl<'a, I, U: Number, D: Dataset, C: Cluster> AdjacencyList<'a, }) .collect(); - let c = Self { - inner, - _phantom: std::marker::PhantomData, - }; - - let [mut c, mut other] = c.partition(); + let adjacency_list = Self(inner); + let [mut c, mut other] = adjacency_list.partition(); let mut components = vec![c]; - while !other.inner.is_empty() { + while !other.0.is_empty() { [c, other] = other.partition(); components.push(c); } @@ -76,7 +67,7 @@ impl<'a, I, U: Number, D: Dataset, C: Cluster> AdjacencyList<'a, let mut visited = HashSet::new(); let start = self - .inner + .0 .keys() .next() .copied() @@ -94,37 +85,33 @@ impl<'a, I, U: Number, D: Dataset, C: Cluster> AdjacencyList<'a, // Add the neighbors of the `Cluster` to the stack. stack.extend( - self.inner[&u] + self.0[&u] .iter() .filter(|(&v, _)| !visited.contains(v)) .map(|(&v, _)| v), ); } - let (inner, other): (HashMap<_, _>, HashMap<_, _>) = - self.inner.into_iter().partition(|(u, _)| visited.contains(u)); + let (inner, other): (HashMap<_, _>, HashMap<_, _>) = self.0.into_iter().partition(|(u, _)| visited.contains(u)); - self.inner = inner; - let other = Self { - inner: other, - _phantom: std::marker::PhantomData, - }; + self.0 = inner; + let other = Self(other); [self, other] } /// Get the inner `HashMap`. - pub const fn inner(&self) -> &HashMap<&C, HashMap<&C, U>> { - &self.inner + pub const fn inner(&self) -> &HashMap<&C, HashMap<&C, T>> { + &self.0 } /// Compute the transition probability matrix of the `AdjacencyList`. pub fn transition_matrix(&self) -> ndarray::Array2 { - let n = self.inner.len(); + let n = self.0.len(); // Compute the (flattened) transition matrix. let mut transition_matrix = vec![0_f32; n * n]; - for (i, (_, neighbors)) in self.inner.iter().enumerate() { + for (i, (_, neighbors)) in self.0.iter().enumerate() { for (j, (_, d)) in neighbors.iter().enumerate() { transition_matrix[i * n + j] = d.as_f32().recip(); } @@ -145,20 +132,24 @@ impl<'a, I, U: Number, D: Dataset, C: Cluster> AdjacencyList<'a, /// Iterate over the `Cluster`s in the `AdjacencyList`. pub fn clusters(&self) -> Vec<&'a C> { - self.inner.keys().copied().collect() + self.0.keys().copied().collect() } /// Iterate over the edges in the `AdjacencyList`. 
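The edge rule in `AdjacencyList::new` above is purely geometric: two clusters are neighbors exactly when their balls overlap, and the resulting distances are later inverted to weight the transition matrix. Restated over plain numbers:

```rust
/// Whether clusters with radii `ru` and `rv`, whose centers are distance `d`
/// apart, share an edge in the graph: their balls must overlap.
fn has_edge(d: f32, ru: f32, rv: f32) -> bool {
    d <= ru + rv
}

fn main() {
    assert!(has_edge(3.0, 2.0, 1.5)); // 3.0 <= 3.5, the balls overlap
    assert!(!has_edge(4.0, 2.0, 1.5)); // 4.0 > 3.5, no edge
}
```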
- pub fn iter_edges(&self) -> impl Iterator + '_ { - self.inner + pub fn iter_edges(&self) -> impl Iterator + '_ { + self.0 .iter() .flat_map(|(&u, neighbors)| neighbors.iter().map(move |(&v, d)| (u, v, *d))) } } -impl<'a, I: Send + Sync, U: Number, D: ParDataset, C: ParCluster> AdjacencyList<'a, I, U, D, C> { - /// Parallel version of `new`. - pub fn par_new(clusters: &[&'a C], data: &D) -> Vec { +impl<'a, T: Number, C: ParCluster> AdjacencyList<'a, T, C> { + /// Parallel version of [`AdjacencyList::new`](crate::chaoda::graph::adjacency_list::AdjacencyList::new). + pub fn par_new, M: ParMetric>( + clusters: &[&'a C], + data: &D, + metric: &M, + ) -> Vec { // TODO: This is a naive implementation of creating an adjacency list. // We can improve this by using the search functionality from CAKES. let inner = clusters @@ -171,7 +162,7 @@ impl<'a, I: Send + Sync, U: Number, D: ParDataset, C: ParCluster> .filter(|&(j, _)| i != j) .filter_map(|(_, &v)| { let (ru, rv) = (u.radius(), v.radius()); - let d = u.distance_to_other(data, v); + let d = data.par_one_to_one(u.arg_center(), v.arg_center(), metric); if d <= ru + rv { Some((v, d)) } else { @@ -183,14 +174,10 @@ impl<'a, I: Send + Sync, U: Number, D: ParDataset, C: ParCluster> }) .collect(); - let c = Self { - inner, - _phantom: std::marker::PhantomData, - }; - - let [mut c, mut other] = c.partition(); + let adjacency_list = Self(inner); + let [mut c, mut other] = adjacency_list.partition(); let mut components = vec![c]; - while !other.inner.is_empty() { + while !other.0.is_empty() { [c, other] = other.partition(); components.push(c); } diff --git a/crates/abd-clam/src/chaoda/graph/component.rs b/crates/abd-clam/src/chaoda/graph/component.rs index 6e379db59..a45aa8b3b 100644 --- a/crates/abd-clam/src/chaoda/graph/component.rs +++ b/crates/abd-clam/src/chaoda/graph/component.rs @@ -5,21 +5,19 @@ use std::collections::HashMap; use distances::Number; use rayon::prelude::*; -use crate::{cluster::ParCluster, dataset::ParDataset, Cluster, Dataset}; +use crate::{cluster::ParCluster, Cluster}; -use super::adjacency_list::AdjacencyList; -use super::node::Node; -use super::Vertex; +use super::{adjacency_list::AdjacencyList, node::Node, Vertex}; /// A `Component` is a collection of `Node`s that are connected by edges in a /// `Graph`. Every `Node` in a `Component` is reachable from every other `Node` /// via a path along edges. -pub struct Component<'a, I, U: Number, D: Dataset, S: Cluster> { +pub struct Component<'a, T: Number, S: Cluster> { /// A map from each `Vertex` to the `Node` that represents it. #[allow(clippy::type_complexity)] - node_map: HashMap<&'a Vertex, Node<'a, I, U, D, S>>, + node_map: HashMap<&'a Vertex, Node<'a, T, S>>, /// The `AdjacencyList` of the `Component`. - adjacency_list: AdjacencyList<'a, I, U, D, Vertex>, + adjacency_list: AdjacencyList<'a, T, Vertex>, // adjacency_list: HashMap, HashMap<&'a Node<'a, I, U, D, S>, U>>, /// Diameter of the `Component`, i.e. the maximum eccentricity of its `Node`s. diameter: usize, @@ -27,14 +25,14 @@ pub struct Component<'a, I, U: Number, D: Dataset, S: Cluster> { population: usize, } -impl<'a, I, U: Number, D: Dataset, S: Cluster> Component<'a, I, U, D, S> { +impl<'a, T: Number, S: Cluster> Component<'a, T, S> { /// Create a new `Component` from a `Vec` of `Vertex`es and the `AdjacencyList` /// of the `Graph`. /// /// # Arguments /// /// * `adjacency_list`: The `AdjacencyList` of the `Graph`. 
- pub fn new(adjacency_list: AdjacencyList<'a, I, U, D, Vertex>) -> Self { + pub fn new(adjacency_list: AdjacencyList<'a, T, Vertex>) -> Self { let node_map = adjacency_list .clusters() .into_iter() @@ -74,17 +72,17 @@ impl<'a, I, U: Number, D: Dataset, S: Cluster> Component<'a, I, U } /// Iterate over the `Vertex`es in the `Component`. - pub fn iter_vertices(&self) -> impl Iterator> + '_ { + pub fn iter_vertices(&self) -> impl Iterator> + '_ { self.node_map.keys().copied() } /// Iterate over the edges in the `Component`. - pub fn iter_edges(&self) -> impl Iterator, &Vertex, U)> + '_ { + pub fn iter_edges(&self) -> impl Iterator, &Vertex, T)> + '_ { self.adjacency_list.iter_edges() } /// Iterate over the lists of neighbors of the `Node`s in the `Component`. - pub fn iter_neighbors(&self) -> impl Iterator, U>> + '_ { + pub fn iter_neighbors(&self) -> impl Iterator, T>> + '_ { self.adjacency_list.inner().values() } @@ -139,9 +137,9 @@ impl<'a, I, U: Number, D: Dataset, S: Cluster> Component<'a, I, U } } -impl<'a, I: Send + Sync, U: Number, D: ParDataset, S: ParCluster> Component<'a, I, U, D, S> { - /// Parallel version of `Component::new`. - pub fn par_new(adjacency_list: AdjacencyList<'a, I, U, D, Vertex>) -> Self { +impl<'a, T: Number, S: ParCluster> Component<'a, T, S> { + /// Parallel version of [`Component::new`](crate::chaoda::graph::component::Component::new). + pub fn par_new(adjacency_list: AdjacencyList<'a, T, Vertex>) -> Self { let node_map = adjacency_list .clusters() .into_par_iter() diff --git a/crates/abd-clam/src/chaoda/graph/mod.rs b/crates/abd-clam/src/chaoda/graph/mod.rs index d615588da..db0140479 100644 --- a/crates/abd-clam/src/chaoda/graph/mod.rs +++ b/crates/abd-clam/src/chaoda/graph/mod.rs @@ -3,13 +3,12 @@ use core::cmp::Reverse; -use std::collections::{BinaryHeap, HashMap}; +use std::collections::HashMap; use distances::Number; -use ordered_float::OrderedFloat; use rayon::prelude::*; -use crate::{cluster::ParCluster, dataset::ParDataset, Cluster, Dataset}; +use crate::{cluster::ParCluster, dataset::ParDataset, metric::ParMetric, Cluster, Dataset, Metric, SizedHeap}; use super::Vertex; @@ -20,9 +19,9 @@ mod node; pub use component::Component; /// A `Graph` is a collection of `Vertex`es. -pub struct Graph<'a, I, U: Number, D: Dataset, S: Cluster> { +pub struct Graph<'a, T: Number, S: Cluster> { /// The collection of `Component`s in the `Graph`. - components: Vec>, + components: Vec>, /// The total number of points in the `Graph`. population: usize, /// The number of vertices in the `Graph`. @@ -31,7 +30,40 @@ pub struct Graph<'a, I, U: Number, D: Dataset, S: Cluster> { diameter: usize, } -impl<'a, I, U: Number, D: Dataset, S: Cluster> Graph<'a, I, U, D, S> { +impl<'a, T: Number, S: Cluster> Graph<'a, T, S> { + /// Create a new `Graph` from a root `Vertex` in a tree using a uniform + /// depth from the tree. + /// + /// # Arguments + /// + /// * `root`: The root `Vertex` of the tree from which to create the `Graph`. + /// * `depth`: The uniform depth at which to consider a `Vertex`. + /// * `min_depth`: The minimum depth at which to consider a `Vertex`. 
+ pub fn from_root_uniform_depth, M: Metric>( + root: &'a Vertex, + data: &D, + metric: &M, + depth: usize, + min_depth: usize, + ) -> Self + where + T: 'a, + { + let cluster_scorer = |clusters: &[&'a Vertex]| { + clusters + .iter() + .map(|c| { + if c.depth() == depth || (c.is_leaf() && c.depth() < depth) { + 1.0 + } else { + 0.0 + } + }) + .collect::>() + }; + Self::from_root(root, data, metric, cluster_scorer, min_depth) + } + /// Create a new `Graph` from a root `Vertex` in a tree. /// /// # Arguments @@ -39,14 +71,15 @@ impl<'a, I, U: Number, D: Dataset, S: Cluster> Graph<'a, I, U, D, /// * `root`: The root `Vertex` of the tree from which to create the `Graph`. /// * `cluster_scorer`: A function that scores `Vertex`es. /// * `min_depth`: The minimum depth at which to consider a `Vertex`. - pub fn from_root( - root: &'a Vertex, + pub fn from_root, M: Metric>( + root: &'a Vertex, data: &D, - cluster_scorer: impl Fn(&[&'a Vertex]) -> Vec, + metric: &M, + cluster_scorer: impl Fn(&[&'a Vertex]) -> Vec, min_depth: usize, ) -> Self where - U: 'a, + T: 'a, { let clusters = root.subtree(); let scores = cluster_scorer(&clusters); @@ -56,10 +89,10 @@ impl<'a, I, U: Number, D: Dataset, S: Cluster> Graph<'a, I, U, D, // `Vertex`es are selected by highest score and then by shallowest depth. let mut candidates = clusters .into_iter() - .zip(scores.into_iter().map(OrderedFloat)) + .zip(scores) .filter(|(c, _)| c.is_leaf() || c.depth() >= min_depth) .map(|(c, s)| (s, Reverse(c))) - .collect::>(); + .collect::>(); let mut clusters = vec![]; while let Some((_, Reverse(v))) = candidates.pop() { @@ -69,12 +102,16 @@ impl<'a, I, U: Number, D: Dataset, S: Cluster> Graph<'a, I, U, D, candidates.retain(|&(_, Reverse(other))| !(v.is_descendant_of(other) || other.is_descendant_of(v))); } - Self::from_vertices(&clusters, data) + Self::from_vertices(&clusters, data, metric) } /// Create a new `Graph` from a collection of `Vertex`es. - pub fn from_vertices(vertices: &[&'a Vertex], data: &D) -> Self { - let components = adjacency_list::AdjacencyList::new(vertices, data) + pub fn from_vertices, M: Metric>( + vertices: &[&'a Vertex], + data: &D, + metric: &M, + ) -> Self { + let components = adjacency_list::AdjacencyList::new(vertices, data, metric) .into_iter() .map(Component::new) .collect::>(); @@ -82,6 +119,7 @@ impl<'a, I, U: Number, D: Dataset, S: Cluster> Graph<'a, I, U, D, let population = vertices.iter().map(|v| v.cardinality()).sum(); let cardinality = components.iter().map(Component::cardinality).sum(); let diameter = components.iter().map(Component::diameter).max().unwrap_or_default(); + Self { components, population, @@ -109,17 +147,17 @@ impl<'a, I, U: Number, D: Dataset, S: Cluster> Graph<'a, I, U, D, } /// Iterate over the `Vertex`es in the `Graph`. - pub fn iter_clusters(&self) -> impl Iterator> { + pub fn iter_clusters(&self) -> impl Iterator> { self.components.iter().flat_map(Component::iter_vertices) } /// Iterate over the edges in the `Graph`. - pub fn iter_edges(&self) -> impl Iterator, &Vertex, U)> + '_ { + pub fn iter_edges(&self) -> impl Iterator, &Vertex, T)> + '_ { self.components.iter().flat_map(Component::iter_edges) } /// Iterate over the lists of neighbors of the `Vertex`es in the `Graph`. 
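The new `from_root_uniform_depth` constructor selects clusters at a fixed depth, falling back to leaves that end shallower. A hedged usage sketch, assuming a tree root, a dataset, and any `Metric` implementation with the bounds shown in that signature:

```rust
// Sketch only; `root: Vertex<T, S>`, `data`, and `metric` are assumed.
let depth = 5; // select clusters at depth 5, or leaves above it
let min_depth = 4; // never select clusters shallower than depth 4
let graph = Graph::from_root_uniform_depth(&root, &data, &metric, depth, min_depth);
// The same graph could instead be built from hand-picked vertices:
// let graph = Graph::from_vertices(&vertices, &data, &metric);
```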
- pub fn iter_neighbors(&self) -> impl Iterator, U>> + '_ { + pub fn iter_neighbors(&self) -> impl Iterator, T>> + '_ { self.components.iter().flat_map(Component::iter_neighbors) } @@ -134,7 +172,7 @@ impl<'a, I, U: Number, D: Dataset, S: Cluster> Graph<'a, I, U, D, } /// Iterate over the `Component`s in the `Graph`. - pub fn iter_components(&self) -> impl Iterator> { + pub fn iter_components(&self) -> impl Iterator> { self.components.iter() } @@ -155,16 +193,42 @@ impl<'a, I, U: Number, D: Dataset, S: Cluster> Graph<'a, I, U, D, } } -impl<'a, I: Send + Sync, U: Number, D: ParDataset, S: ParCluster> Graph<'a, I, U, D, S> { - /// Parallel version of `Graph::from_root`. - pub fn par_from_root( - root: &'a Vertex, +impl<'a, T: Number, S: ParCluster> Graph<'a, T, S> { + /// Parallel version of [`Graph::from_root_uniform_depth`](crate::chaoda::graph::Graph::from_root_uniform_depth). + pub fn par_from_root_uniform_depth, M: ParMetric>( + root: &'a Vertex, + data: &D, + metric: &M, + depth: usize, + min_depth: usize, + ) -> Self + where + T: 'a, + { + let cluster_scorer = |clusters: &[&'a Vertex]| { + clusters + .iter() + .map(|c| { + if c.depth() == depth || (c.is_leaf() && c.depth() < depth) { + 1.0 + } else { + 0.0 + } + }) + .collect::>() + }; + Self::par_from_root(root, data, metric, cluster_scorer, min_depth) + } + /// Parallel version of [`Graph::from_root`](crate::chaoda::graph::Graph::from_root). + pub fn par_from_root, M: ParMetric>( + root: &'a Vertex, data: &D, - cluster_scorer: impl Fn(&[&'a Vertex]) -> Vec, + metric: &M, + cluster_scorer: impl Fn(&[&'a Vertex]) -> Vec, min_depth: usize, ) -> Self where - U: 'a, + T: 'a, { let clusters = root.subtree(); let scores = cluster_scorer(&clusters); @@ -174,10 +238,10 @@ impl<'a, I: Send + Sync, U: Number, D: ParDataset, S: ParCluster> // `Vertex`es are selected by highest score and then by shallowest depth. let mut candidates = clusters .into_iter() - .zip(scores.into_iter().map(OrderedFloat)) + .zip(scores) .filter(|(c, _)| c.is_leaf() || c.depth() >= min_depth) .map(|(c, s)| (s, Reverse(c))) - .collect::>(); + .collect::>(); let mut clusters = vec![]; while let Some((_, Reverse(v))) = candidates.pop() { @@ -187,12 +251,16 @@ impl<'a, I: Send + Sync, U: Number, D: ParDataset, S: ParCluster> candidates.retain(|&(_, Reverse(other))| !(v.is_descendant_of(other) || other.is_descendant_of(v))); } - Self::par_from_vertices(&clusters, data) + Self::par_from_vertices(&clusters, data, metric) } - /// Create a new `Graph` from a collection of `Vertex`es. - pub fn par_from_vertices(vertices: &[&'a Vertex], data: &D) -> Self { - let components = adjacency_list::AdjacencyList::par_new(vertices, data) + /// Parallel version of [`Graph::from_vertices`](crate::chaoda::graph::Graph::from_vertices). + pub fn par_from_vertices, M: ParMetric>( + vertices: &[&'a Vertex], + data: &D, + metric: &M, + ) -> Self { + let components = adjacency_list::AdjacencyList::par_new(vertices, data, metric) .into_par_iter() .map(Component::new) .collect::>(); @@ -200,6 +268,7 @@ impl<'a, I: Send + Sync, U: Number, D: ParDataset, S: ParCluster> let population = vertices.iter().map(|v| v.cardinality()).sum(); let cardinality = components.iter().map(Component::cardinality).sum(); let diameter = components.iter().map(Component::diameter).max().unwrap_or_default(); + Self { components, population, @@ -208,7 +277,7 @@ impl<'a, I: Send + Sync, U: Number, D: ParDataset, S: ParCluster> } } - /// Compute the stationary probability of each `Vertex` in the `Graph`. 
+    /// Parallel version of [`Graph::compute_stationary_probabilities`](crate::chaoda::graph::Graph::compute_stationary_probabilities).
     #[must_use]
     pub fn par_compute_stationary_probabilities(&self, log2_num_steps: usize) -> Vec<f32> {
         self.components
diff --git a/crates/abd-clam/src/chaoda/graph/node.rs b/crates/abd-clam/src/chaoda/graph/node.rs
index ebad9a953..5e1a4fce4 100644
--- a/crates/abd-clam/src/chaoda/graph/node.rs
+++ b/crates/abd-clam/src/chaoda/graph/node.rs
@@ -2,13 +2,11 @@
 //! some additional information for the `Graph`-based anomaly detection
 //! algorithms.
 
-use core::hash::{Hash, Hasher};
-
 use std::collections::HashSet;
 
 use distances::Number;
 
-use crate::{chaoda::Vertex, Cluster, Dataset};
+use crate::{chaoda::Vertex, Cluster};
 
 use super::adjacency_list::AdjacencyList;
 
@@ -16,15 +14,15 @@ use super::adjacency_list::AdjacencyList;
 /// some additional information for the `Graph`-based anomaly detection
 /// algorithms.
 #[derive(Clone)]
-pub struct Node<'a, I, U: Number, D: Dataset<I, U>, S: Cluster<I, U, D>> {
+pub struct Node<'a, T: Number, S: Cluster<T>> {
     /// The `Vertex` that the `Node` represents.
-    vertex: &'a Vertex<I, U, D, S>,
+    vertex: &'a Vertex<T, S>,
     /// The cumulative size of the `Graph` neighborhood of the `Node` at each
     /// step of a breadth-first traversal.
     neighborhood_sizes: Vec<usize>,
 }
 
-impl<'a, I, U: Number, D: Dataset<I, U>, S: Cluster<I, U, D>> Node<'a, I, U, D, S> {
+impl<'a, T: Number, S: Cluster<T>> Node<'a, T, S> {
     /// Create a new `Node` from a `Vertex` and an `AdjacencyList`.
     ///
     /// # Arguments
@@ -32,7 +30,7 @@ impl<'a, I, U: Number, D: Dataset<I, U>, S: Cluster<I, U, D>> Node<'a, I, U, D,
     /// * `vertex`: The `Vertex` that the `Node` represents.
     /// * `adjacency_list`: The `AdjacencyList` of the `Component` that the
     ///   `Vertex` belongs to.
-    pub fn new(vertex: &'a Vertex<I, U, D, S>, adjacency_list: &AdjacencyList<I, U, D, Vertex<I, U, D, S>>) -> Self {
+    pub fn new(vertex: &'a Vertex<T, S>, adjacency_list: &AdjacencyList<T, Vertex<T, S>>) -> Self {
         let neighborhood_sizes = Self::compute_neighborhood_sizes(vertex, adjacency_list);
 
         Self {
@@ -44,8 +42,8 @@ impl<'a, I, U: Number, D: Dataset<I, U>, S: Cluster<I, U, D>> Node<'a, I, U, D,
     /// Get the cumulative size of the `Graph` neighborhood of the `Node` at
     /// each step of a breadth-first traversal.
     fn compute_neighborhood_sizes(
-        vertex: &'a Vertex<I, U, D, S>,
-        adjacency_list: &AdjacencyList<I, U, D, Vertex<I, U, D, S>>,
+        vertex: &'a Vertex<T, S>,
+        adjacency_list: &AdjacencyList<T, Vertex<T, S>>,
     ) -> Vec<usize> {
         let mut frontier_sizes = vec![];
         let mut visited = HashSet::new();
@@ -100,16 +98,16 @@ impl<'a, I, U: Number, D: Dataset<I, U>, S: Cluster<I, U, D>> Node<'a, I, U, D,
     }
 }
 
-impl<'a, I, U: Number, D: Dataset<I, U>, S: Cluster<I, U, D>> Eq for Node<'a, I, U, D, S> {}
+impl<T: Number, S: Cluster<T>> Eq for Node<'_, T, S> {}
 
-impl<'a, I, U: Number, D: Dataset<I, U>, S: Cluster<I, U, D>> PartialEq for Node<'a, I, U, D, S> {
+impl<T: Number, S: Cluster<T>> PartialEq for Node<'_, T, S> {
     fn eq(&self, other: &Self) -> bool {
         self.vertex == other.vertex
     }
 }
 
-impl<'a, I, U: Number, D: Dataset<I, U>, S: Cluster<I, U, D>> Hash for Node<'a, I, U, D, S> {
-    fn hash<H: Hasher>(&self, state: &mut H) {
+impl<T: Number, S: Cluster<T>> core::hash::Hash for Node<'_, T, S> {
+    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
         self.vertex.hash(state);
     }
 }
diff --git a/crates/abd-clam/src/chaoda/inference/combination.rs b/crates/abd-clam/src/chaoda/inference/combination.rs
index 7099fdb9b..511455014 100644
--- a/crates/abd-clam/src/chaoda/inference/combination.rs
+++ b/crates/abd-clam/src/chaoda/inference/combination.rs
@@ -1,19 +1,19 @@
 //! Utilities for handling a pair of `MetaMLModel` and `GraphAlgorithm`.
 use distances::Number;
-use serde::{Deserialize, Serialize};
 
 use crate::{
     chaoda::{training::GraphEvaluator, Graph, GraphAlgorithm, Vertex},
     cluster::ParCluster,
     dataset::ParDataset,
-    Cluster, Dataset,
+    metric::ParMetric,
+    Cluster, Dataset, Metric,
 };
 
 use super::TrainedMetaMlModel;
 
 /// A combination of `TrainedMetaMLModel` and `GraphAlgorithm`.
-#[derive(Serialize, Deserialize)]
+#[cfg_attr(feature = "disk-io", derive(serde::Serialize, serde::Deserialize))]
 #[allow(clippy::module_name_repetitions)]
 pub struct TrainedCombination {
     /// The `MetaMLModel` to use.
@@ -72,11 +72,10 @@ impl TrainedCombination {
     }
 
     /// Get the meta-ML scorer function in a callable for any number of `Vertex`es.
-    pub fn meta_ml_scorer<I, U, D, S>(&self) -> impl Fn(&[&Vertex<I, U, D, S>]) -> Vec<f32> + '_
+    pub fn meta_ml_scorer<T, S>(&self) -> impl Fn(&[&Vertex<T, S>]) -> Vec<f32> + '_
     where
-        U: Number,
-        D: Dataset<I, U>,
-        S: Cluster<I, U, D>,
+        T: Number,
+        S: Cluster<T>,
     {
         move |clusters| {
             let props = clusters.iter().flat_map(|c| c.ratios()).collect::<Vec<_>>();
@@ -86,19 +85,21 @@ impl TrainedCombination {
 
     /// Create a `Graph` from the `root` with the given `data` and `min_depth`
     /// using the `TrainedMetaMLModel`.
-    pub fn create_graph<'a, I, U, D, S>(
+    pub fn create_graph<'a, I, T, D, M, S>(
         &self,
-        root: &'a Vertex<I, U, D, S>,
+        root: &'a Vertex<T, S>,
         data: &D,
+        metric: &M,
         min_depth: usize,
-    ) -> Graph<'a, I, U, D, S>
+    ) -> Graph<'a, T, S>
     where
-        U: Number,
-        D: Dataset<I, U>,
-        S: Cluster<I, U, D>,
+        T: Number,
+        D: Dataset<I>,
+        M: Metric<I, T>,
+        S: Cluster<T>,
     {
         let cluster_scorer = self.meta_ml_scorer();
-        Graph::from_root(root, data, cluster_scorer, min_depth)
+        Graph::from_root(root, data, metric, cluster_scorer, min_depth)
     }
 
     /// Predict the anomaly scores of the points in the `data`.
@@ -114,20 +115,22 @@ impl TrainedCombination {
     /// A tuple of:
     /// * The `Graph` constructed from the `root`.
     /// * The anomaly scores of the points in the `data`.
-    pub fn predict<'a, I, U, D, S>(
+    pub fn predict<'a, I, T, D, M, S>(
         &self,
-        root: &'a Vertex<I, U, D, S>,
+        root: &'a Vertex<T, S>,
         data: &D,
+        metric: &M,
         min_depth: usize,
-    ) -> (Graph<'a, I, U, D, S>, Vec<f32>)
+    ) -> (Graph<'a, T, S>, Vec<f32>)
     where
-        U: Number,
-        D: Dataset<I, U>,
-        S: Cluster<I, U, D>,
+        T: Number,
+        D: Dataset<I>,
+        M: Metric<I, T>,
+        S: Cluster<T>,
     {
         ftlog::debug!("Predicting with {}...", self.name());
-        let graph = self.create_graph(root, data, min_depth);
+        let graph = self.create_graph(root, data, metric, min_depth);
         let scores = self.graph_algorithm.evaluate_points(&graph);
 
         let scores = if self.invert_scores() {
@@ -139,39 +142,43 @@ impl TrainedCombination {
         (graph, scores)
     }
 
-    /// Parallel version of `create_graph`.
-    pub fn par_create_graph<'a, I, U, D, S>(
+    /// Parallel version of [`TrainedCombination::create_graph`](crate::chaoda::inference::combination::TrainedCombination::create_graph).
+    pub fn par_create_graph<'a, I, T, D, M, S>(
         &self,
-        root: &'a Vertex<I, U, D, S>,
+        root: &'a Vertex<T, S>,
         data: &D,
+        metric: &M,
         min_depth: usize,
-    ) -> Graph<'a, I, U, D, S>
+    ) -> Graph<'a, T, S>
     where
         I: Send + Sync,
-        U: Number,
-        D: ParDataset<I, U>,
-        S: ParCluster<I, U, D>,
+        T: Number,
+        D: ParDataset<I>,
+        M: ParMetric<I, T>,
+        S: ParCluster<T>,
     {
         let cluster_scorer = self.meta_ml_scorer();
-        Graph::par_from_root(root, data, cluster_scorer, min_depth)
+        Graph::par_from_root(root, data, metric, cluster_scorer, min_depth)
     }
 
-    /// Parallel version of `predict`.
-    pub fn par_predict<'a, I, U, D, S>(
+    /// Parallel version of [`TrainedCombination::predict`](crate::chaoda::inference::combination::TrainedCombination::predict).
+ pub fn par_predict<'a, I, T, D, M, S>( &self, - root: &'a Vertex, + root: &'a Vertex, data: &D, + metric: &M, min_depth: usize, - ) -> (Graph<'a, I, U, D, S>, Vec) + ) -> (Graph<'a, T, S>, Vec) where I: Send + Sync, - U: Number, - D: ParDataset, - S: ParCluster, + T: Number, + D: ParDataset, + M: ParMetric, + S: ParCluster, { ftlog::debug!("Predicting with {}...", self.name()); - let graph = self.par_create_graph(root, data, min_depth); + let graph = self.par_create_graph(root, data, metric, min_depth); let scores = self.graph_algorithm.evaluate_points(&graph); let scores = if self.invert_scores() { diff --git a/crates/abd-clam/src/chaoda/inference/meta_ml.rs b/crates/abd-clam/src/chaoda/inference/meta_ml.rs index 0ffffa31c..14846090a 100644 --- a/crates/abd-clam/src/chaoda/inference/meta_ml.rs +++ b/crates/abd-clam/src/chaoda/inference/meta_ml.rs @@ -1,6 +1,5 @@ //! Inferring with meta-ml models. -use serde::{Deserialize, Serialize}; use smartcore::{ linalg::basic::matrix::DenseMatrix, linear::linear_regression::LinearRegression, tree::decision_tree_regressor::DecisionTreeRegressor, @@ -9,7 +8,7 @@ use smartcore::{ use crate::chaoda::NUM_RATIOS; /// A trained meta-ml model. -#[derive(Serialize, Deserialize)] +#[cfg_attr(feature = "disk-io", derive(serde::Serialize, serde::Deserialize))] pub enum TrainedMetaMlModel { /// A linear regression model. LinearRegression(LinearRegression, Vec>), diff --git a/crates/abd-clam/src/chaoda/inference/mod.rs b/crates/abd-clam/src/chaoda/inference/mod.rs index 3f4d1668d..6a2ef4281 100644 --- a/crates/abd-clam/src/chaoda/inference/mod.rs +++ b/crates/abd-clam/src/chaoda/inference/mod.rs @@ -1,67 +1,66 @@ //! Utilities for running inference with pre-trained Chaoda models. -mod combination; -mod meta_ml; - use distances::Number; use ndarray::prelude::*; use rayon::prelude::*; -use serde::{Deserialize, Serialize}; use crate::{ - adapter::{Adapter, ParAdapter}, - cluster::ParCluster, + cluster::{ + adapter::{Adapter, ParAdapter}, + ParCluster, ParPartition, Partition, + }, dataset::ParDataset, - partition::ParPartition, - Dataset, Metric, Partition, + metric::ParMetric, + Dataset, }; +use super::{roc_auc_score, Vertex}; + +mod combination; +mod meta_ml; +mod trained_smc; + pub use combination::TrainedCombination; pub use meta_ml::TrainedMetaMlModel; - -use super::{roc_auc_score, Vertex}; +pub use trained_smc::TrainedSmc; /// A pre-trained Chaoda model. -#[derive(Serialize, Deserialize)] -pub struct Chaoda { +pub struct Chaoda { /// The distance metrics to train with. - #[serde(with = "serde_arrays")] - metrics: [Metric; M], + metrics: [Box>; M], /// The trained models. - #[serde(with = "serde_arrays")] combinations: [Vec; M], } -impl Chaoda { +impl Chaoda { /// Create a new Chaoda model with the given metrics and trained combinations. #[must_use] - pub const fn new(metrics: [Metric; M], combinations: [Vec; M]) -> Self { + pub const fn new(metrics: [Box>; M], combinations: [Vec; M]) -> Self { Self { metrics, combinations } } /// Get the distance metrics used by the model. #[must_use] - pub const fn metrics(&self) -> &[Metric; M] { + pub const fn metrics(&self) -> &[Box>; M] { &self.metrics } /// Set the distance metrics to be used by the model. - pub fn set_metrics(&mut self, metrics: [Metric; M]) { + pub fn set_metrics(&mut self, metrics: [Box>; M]) { self.metrics = metrics; } /// Create trees to use for inference, one for each metric. 
- pub fn create_trees, S: Partition, C: Fn(&S) -> bool>( + pub fn create_trees, S: Partition, C: Fn(&S) -> bool>( &self, - data: &mut D, + data: &D, criteria: &[C; M], seed: Option, - ) -> [Vertex; M] { + ) -> [Vertex; M] { let mut trees = Vec::new(); for (metric, criteria) in self.metrics.iter().zip(criteria.iter()) { - data.set_metric(metric.clone()); - let source = S::new_tree(data, criteria, seed); - let tree = Vertex::adapt_tree(source, None); + let source = S::new_tree(data, metric, criteria, seed); + let tree = Vertex::adapt_tree(source, None, data, metric); trees.push(tree); } trees @@ -70,10 +69,10 @@ impl Chaoda { } /// Run inference on the given data. - pub fn predict_from_trees, S: Partition>( + pub fn predict_from_trees, S: Partition>( &self, - data: &mut D, - trees: &[Vertex; M], + data: &D, + trees: &[Vertex; M], min_depth: usize, ) -> Vec { // TODO: Make this a parameter. @@ -82,11 +81,10 @@ impl Chaoda { let mut num_discerning = 0; let mut scores = Vec::new(); for ((metric, root), combinations) in self.metrics.iter().zip(trees.iter()).zip(self.combinations.iter()) { - data.set_metric(metric.clone()); for c in combinations { if c.discerns(tol) { num_discerning += 1; - let (_, row) = c.predict(root, data, min_depth); + let (_, row) = c.predict(root, data, metric, min_depth); scores.extend_from_slice(&row); } } @@ -107,9 +105,9 @@ impl Chaoda { } /// Run inference on the given data. - pub fn predict, S: Partition, C: Fn(&S) -> bool>( + pub fn predict, S: Partition, C: Fn(&S) -> bool>( &self, - data: &mut D, + data: &D, criteria: &[C; M], seed: Option, min_depth: usize, @@ -119,9 +117,9 @@ impl Chaoda { } /// Evaluate the model on the given data. - pub fn evaluate, S: Partition, C: Fn(&S) -> bool>( + pub fn evaluate, S: Partition, C: Fn(&S) -> bool>( &self, - data: &mut D, + data: &D, criteria: &[C; M], labels: &[bool], seed: Option, @@ -133,19 +131,18 @@ impl Chaoda { } } -impl Chaoda { - /// Parallel version of `create_trees`. - pub fn par_create_trees, S: ParPartition, C: (Fn(&S) -> bool) + Send + Sync>( +impl Chaoda { + /// Parallel version of [`Chaoda::create_trees`](crate::chaoda::Chaoda::create_trees). + pub fn par_create_trees, S: ParPartition, C: (Fn(&S) -> bool) + Send + Sync>( &self, - data: &mut D, + data: &D, criteria: &[C; M], seed: Option, - ) -> [Vertex; M] { + ) -> [Vertex; M] { let mut trees = Vec::new(); for (metric, criteria) in self.metrics.iter().zip(criteria.iter()) { - data.set_metric(metric.clone()); - let source = S::par_new_tree(data, criteria, seed); - let tree = Vertex::par_adapt_tree(source, None); + let source = S::par_new_tree(data, metric, criteria, seed); + let tree = Vertex::par_adapt_tree(source, None, data, metric); trees.push(tree); } trees @@ -153,11 +150,11 @@ impl Chaoda { .unwrap_or_else(|_| unreachable!("Could not convert Vec> to [Vertex; {M}]")) } - /// Run inference on the given data. - pub fn par_predict_from_trees, S: ParCluster>( + /// Parallel version of [`Chaoda::predict_from_trees`](crate::chaoda::Chaoda::predict_from_trees). + pub fn par_predict_from_trees, S: ParCluster>( &self, - data: &mut D, - trees: &[Vertex; M], + data: &D, + trees: &[Vertex; M], min_depth: usize, ) -> Vec { // TODO: Make this a parameter. 
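Since the metrics are now stored in the model rather than swapped into the dataset with `set_metric`, the whole inference flow borrows the data immutably. A sketch of the intended call sequence, with `metrics`, `combinations`, `data`, and per-metric `criteria` assumed to exist:

```rust
// Sketch only; all inputs are assumed. `M` is the const number of metrics.
let model = Chaoda::new(metrics, combinations);
// One tree per metric, each built by partitioning `data` with that metric.
let trees = model.create_trees(&data, &criteria, Some(42));
// Scores are averaged over all discerning combinations across all metrics.
let scores = model.predict_from_trees(&data, &trees, 4);
assert_eq!(scores.len(), data.cardinality());
```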
@@ -167,12 +164,11 @@ impl Chaoda { let mut scores = Vec::new(); for ((metric, root), combinations) in self.metrics.iter().zip(trees.iter()).zip(self.combinations.iter()) { - data.set_metric(metric.clone()); let new_scores = combinations .par_iter() .filter_map(|c| { if c.discerns(tol) { - let (_, row) = c.predict(root, data, min_depth); + let (_, row) = c.predict(root, data, metric, min_depth); Some(row) } else { None @@ -198,10 +194,10 @@ impl Chaoda { .to_vec() } - /// Parallel version of `predict`. - pub fn par_predict, S: ParPartition, C: (Fn(&S) -> bool) + Send + Sync>( + /// Parallel version of [`Chaoda::predict`](crate::chaoda::Chaoda::predict). + pub fn par_predict, S: ParPartition, C: (Fn(&S) -> bool) + Send + Sync>( &self, - data: &mut D, + data: &D, criteria: &[C; M], seed: Option, min_depth: usize, @@ -210,8 +206,8 @@ impl Chaoda { self.par_predict_from_trees(data, &trees, min_depth) } - /// Parallel version of `evaluate`. - pub fn par_evaluate, S: ParPartition, C: (Fn(&S) -> bool) + Send + Sync>( + /// Parallel version of [`Chaoda::evaluate`](crate::chaoda::Chaoda::evaluate). + pub fn par_evaluate, S: ParPartition, C: (Fn(&S) -> bool) + Send + Sync>( &self, data: &mut D, criteria: &[C; M], diff --git a/crates/abd-clam/src/chaoda/inference/trained_smc.rs b/crates/abd-clam/src/chaoda/inference/trained_smc.rs new file mode 100644 index 000000000..eda882760 --- /dev/null +++ b/crates/abd-clam/src/chaoda/inference/trained_smc.rs @@ -0,0 +1,369 @@ +//! A trained Single-Metric-CHAODA ensemble. + +use distances::Number; +use ndarray::prelude::*; +use rayon::prelude::*; + +use crate::{ + chaoda::{roc_auc_score, Vertex}, + cluster::{ + adapter::{Adapter, ParAdapter}, + ParCluster, ParPartition, Partition, + }, + dataset::ParDataset, + metric::ParMetric, + Cluster, Dataset, Metric, +}; + +use super::TrainedCombination; + +/// A trained Single-Metric-CHAODA ensemble. +#[cfg_attr(feature = "disk-io", derive(serde::Serialize, serde::Deserialize))] +pub struct TrainedSmc(Vec); + +impl TrainedSmc { + /// Create a new trained Single-Metric-CHAODA ensemble. + /// + /// # Arguments + /// + /// - `combinations`: The trained combinations to use. + /// + /// # Returns + /// + /// The trained Single-Metric-CHAODA ensemble. + #[must_use] + pub const fn new(combinations: Vec) -> Self { + Self(combinations) + } + + /// Get the trained combinations. + #[must_use] + pub fn combinations(&self) -> &[TrainedCombination] { + &self.0 + } + + /// Create a tree for inference. + /// + /// # Arguments + /// + /// - `data`: The dataset to use. + /// - `metric`: The metric to use. + /// - `criteria`: The criteria to use for partitioning. + /// - `seed`: The seed to use for random number generation. + /// + /// # Type Parameters + /// + /// - `I`: The type of the dataset items. + /// - `T`: The type of the metric values. + /// - `D`: The type of the dataset. + /// - `M`: The type of the metric. + /// - `S`: The type of the cluster that will be adapted to `Vertex`. + /// - `C`: The type of the criteria function. + fn create_tree(data: &D, metric: &M, criteria: &C, seed: Option) -> Vertex + where + T: Number, + D: Dataset, + M: Metric, + S: Partition, + C: Fn(&S) -> bool, + { + let source = S::new_tree(data, metric, criteria, seed); + Vertex::adapt_tree(source, None, data, metric) + } + + /// Run inference on the given data using the pre-built tree. + /// + /// # Arguments + /// + /// - `data`: The dataset to use. + /// - `metric`: The metric to use. + /// - `root`: The root of the tree to use. 
+ /// - `min_depth`: The minimum depth to consider for selecting clusters to + /// create graphs. + /// - `tol`: The tolerance to use for discerning meta-ml models. This should + /// be a small positive number, ideally less than `0.1`. + /// + /// # Returns + /// + /// The predicted anomaly scores for each item in the dataset. + /// + /// # Type Parameters + /// + /// - `I`: The type of the dataset items. + /// - `T`: The type of the metric values. + /// - `D`: The type of the dataset. + /// - `M`: The type of the metric. + /// - `S`: The type of the cluster that was adapted to `Vertex`. + pub fn predict_from_tree( + &self, + data: &D, + metric: &M, + root: &Vertex, + min_depth: usize, + tol: f32, + ) -> Vec + where + T: Number, + D: Dataset, + M: Metric, + S: Cluster, + { + let (num_discerning, scores) = self + .0 + .iter() + .enumerate() + .filter(|(_, combination)| combination.discerns(tol)) + .fold((0, Vec::new()), |(num_discerning, mut scores), (i, combination)| { + ftlog::info!( + "Predicting with combination {}/{} {}", + i + 1, + self.0.len(), + combination.name() + ); + let (_, mut row) = combination.predict(root, data, metric, min_depth); + scores.append(&mut row); + (num_discerning + 1, scores) + }); + + if num_discerning == 0 { + ftlog::warn!("No discerning combinations found. Returning all scores as `0.5`."); + return vec![0.5; data.cardinality()]; + }; + + ftlog::info!("Averaging scores from {num_discerning} discerning combinations."); + let shape = (data.cardinality(), num_discerning); + let scores_len = scores.len(); + let scores = Array2::from_shape_vec(shape, scores).unwrap_or_else(|e| { + unreachable!( + "Could not convert Vec of len {scores_len} to Array2 of shape {:?}: {e}", + shape + ) + }); + + scores + .mean_axis(Axis(1)) + .unwrap_or_else(|| unreachable!("Could not compute mean of Array2 along axis 1")) + .to_vec() + } + + /// Run inference on the given data. + /// + /// # Arguments + /// + /// - `data`: The dataset to use. + /// - `metric`: The metric to use. + /// - `criteria`: The criteria to use for partitioning. + /// - `seed`: The seed to use for random number generation. + /// - `min_depth`: The minimum depth to consider for selecting clusters to + /// create graphs. + /// - `tol`: The tolerance to use for discerning meta-ml models. This should + /// be a small positive number, ideally less than `0.1`. + /// + /// # Returns + /// + /// The predicted anomaly scores for each item in the dataset. + /// + /// # Type Parameters + /// + /// - `I`: The type of the dataset items. + /// - `T`: The type of the metric values. + /// - `D`: The type of the dataset. + /// - `M`: The type of the metric. + /// - `S`: The type of the cluster that will be adapted to `Vertex`. + /// - `C`: The type of the criteria function. + pub fn predict( + &self, + data: &D, + metric: &M, + criteria: &C, + seed: Option, + min_depth: usize, + tol: f32, + ) -> Vec + where + T: Number, + D: Dataset, + M: Metric, + S: Partition, + C: Fn(&S) -> bool, + { + let root = Self::create_tree(data, metric, criteria, seed); + self.predict_from_tree(data, metric, &root, min_depth, tol) + } + + /// Evaluate the model on the given data. + /// + /// # Arguments + /// + /// - `data`: The dataset to use. + /// - `labels`: The labels to use for evaluation. + /// - `metric`: The metric to use. + /// - `criteria`: The criteria to use for partitioning. + /// - `seed`: The seed to use for random number generation. + /// - `min_depth`: The minimum depth to consider for selecting clusters to + /// create graphs. 
+ /// - `tol`: The tolerance to use for discerning meta-ml models. This should + /// be a small positive number, ideally less than `0.1`. + /// + /// # Returns + /// + /// The predicted anomaly scores for each item in the dataset. + /// + /// # Type Parameters + /// + /// - `I`: The type of the dataset items. + /// - `T`: The type of the metric values. + /// - `D`: The type of the dataset. + /// - `M`: The type of the metric. + /// - `S`: The type of the cluster that will be adapted to `Vertex`. + /// - `C`: The type of the criteria function. + #[allow(clippy::too_many_arguments)] + pub fn evaluate( + &self, + data: &D, + labels: &[bool], + metric: &M, + criteria: &C, + seed: Option, + min_depth: usize, + tol: f32, + ) -> f32 + where + T: Number, + D: Dataset, + M: Metric, + S: Partition, + C: Fn(&S) -> bool, + { + let root = Self::create_tree(data, metric, criteria, seed); + let scores = self.predict_from_tree(data, metric, &root, min_depth, tol); + roc_auc_score(labels, &scores) + .unwrap_or_else(|e| unreachable!("Could not compute ROC-AUC score for dataset {}: {e}", data.name())) + } + + /// Parallel version of [`TrainedSmc::create_tree`](crate::chaoda::inference::TrainedSmc::create_tree). + fn par_create_tree(data: &D, metric: &M, criteria: &C, seed: Option) -> Vertex + where + I: Send + Sync, + T: Number, + D: ParDataset, + M: ParMetric, + S: ParPartition, + C: (Fn(&S) -> bool) + Send + Sync, + { + let source = S::par_new_tree(data, metric, criteria, seed); + Vertex::par_adapt_tree(source, None, data, metric) + } + + /// Parallel version of [`TrainedSmc::predict_from_tree`](crate::chaoda::inference::TrainedSmc::predict_from_tree). + pub fn par_predict_from_tree( + &self, + data: &D, + metric: &M, + root: &Vertex, + min_depth: usize, + tol: f32, + ) -> Vec + where + I: Send + Sync, + T: Number, + D: ParDataset, + M: ParMetric, + S: ParCluster, + { + let (num_discerning, scores) = self + .0 + .par_iter() + .enumerate() + .filter(|(_, combination)| combination.discerns(tol)) + .fold( + || (0, Vec::new()), + |(num_discerning, mut scores), (i, combination)| { + ftlog::info!( + "Predicting with combination {}/{} {}", + i + 1, + self.0.len(), + combination.name() + ); + let (_, mut row) = combination.par_predict(root, data, metric, min_depth); + scores.append(&mut row); + (num_discerning + 1, scores) + }, + ) + .reduce( + || (0, Vec::new()), + |(num_discerning, mut scores), (n, s)| { + scores.extend(s); + (num_discerning + n, scores) + }, + ); + + if num_discerning == 0 { + ftlog::warn!("No discerning combinations found. Returning all scores as `0.5`."); + return vec![0.5; data.cardinality()]; + }; + + ftlog::info!("Averaging scores from {num_discerning} discerning combinations."); + let shape = (data.cardinality(), num_discerning); + let scores_len = scores.len(); + let scores = Array2::from_shape_vec(shape, scores).unwrap_or_else(|e| { + unreachable!( + "Could not convert Vec of len {scores_len} to Array2 of shape {:?}: {e}", + shape + ) + }); + + scores + .mean_axis(Axis(1)) + .unwrap_or_else(|| unreachable!("Could not compute mean of Array2 along axis 1")) + .to_vec() + } + + /// Parallel version of [`TrainedSmc::predict`](crate::chaoda::inference::TrainedSmc::predict). 
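Taken together, `predict` and `evaluate` give the serial end-to-end flow for a trained single-metric ensemble. A sketch under stated assumptions: `smc` is a loaded `TrainedSmc`, `data` and `labels` exist, and `Ball` stands in for any `Partition` implementor:

```rust
// Sketch only; `smc: TrainedSmc`, `data`, `metric`, and `labels` are assumed.
let criteria = |c: &Ball<_>| c.cardinality() > 1; // partition non-singleton clusters
let seed = Some(42);
// Anomaly score for every item in `data`; `tol` below the suggested 0.1 ceiling.
let scores = smc.predict(&data, &metric, &criteria, seed, 4, 0.05);
// ROC-AUC of those scores against ground-truth labels.
let auc = smc.evaluate(&data, &labels, &metric, &criteria, seed, 4, 0.05);
```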
+ pub fn par_predict( + &self, + data: &D, + metric: &M, + criteria: &C, + seed: Option, + min_depth: usize, + tol: f32, + ) -> Vec + where + I: Send + Sync, + T: Number, + D: ParDataset, + M: ParMetric, + S: ParPartition, + C: (Fn(&S) -> bool) + Send + Sync, + { + let root = Self::par_create_tree(data, metric, criteria, seed); + self.par_predict_from_tree(data, metric, &root, min_depth, tol) + } + + /// Parallel version of [`TrainedSmc::evaluate`](crate::chaoda::inference::TrainedSmc::evaluate). + #[allow(clippy::too_many_arguments)] + pub fn par_evaluate( + &self, + data: &D, + labels: &[bool], + metric: &M, + criteria: &C, + seed: Option, + min_depth: usize, + tol: f32, + ) -> f32 + where + I: Send + Sync, + T: Number, + D: ParDataset, + M: ParMetric, + S: ParPartition, + C: (Fn(&S) -> bool) + Send + Sync, + { + let root = Self::par_create_tree(data, metric, criteria, seed); + let scores = self.par_predict_from_tree(data, metric, &root, min_depth, tol); + roc_auc_score(labels, &scores) + .unwrap_or_else(|e| unreachable!("Could not compute ROC-AUC score for dataset {}: {e}", data.name())) + } +} diff --git a/crates/abd-clam/src/chaoda/mod.rs b/crates/abd-clam/src/chaoda/mod.rs index 031db055d..c034045c7 100644 --- a/crates/abd-clam/src/chaoda/mod.rs +++ b/crates/abd-clam/src/chaoda/mod.rs @@ -8,9 +8,9 @@ mod training; pub use cluster::{Ratios, Vertex}; use distances::Number; pub use graph::Graph; -pub use inference::{Chaoda, TrainedMetaMlModel}; +pub use inference::{Chaoda, TrainedMetaMlModel, TrainedSmc}; #[allow(clippy::module_name_repetitions)] -pub use training::{ChaodaTrainer, GraphAlgorithm, TrainableMetaMlModel}; +pub use training::{ChaodaTrainer, GraphAlgorithm, TrainableMetaMlModel, TrainableSmc}; /// The number of anomaly ratios we use in CHAODA const NUM_RATIOS: usize = 6; @@ -42,21 +42,3 @@ pub fn roc_auc_score(y_true: &[bool], y_pred: &[f32]) -> Result { let y_pred = y_pred.to_vec(); Ok(smartcore::metrics::roc_auc_score(&y_true, &y_pred).as_f32()) } - -#[cfg(test)] -mod tests { - use distances::Number; - - #[test] - fn test_roc_auc_score() { - let y_score = (0..100).step_by(10).map(|s| s.as_f32() / 100.0).collect::>(); - - let y_true = y_score.iter().map(|&s| s > 0.5).collect::>(); - let auc = super::roc_auc_score(&y_true, &y_score).unwrap(); - assert_eq!(auc, 1.0); - - let y_true = y_true.into_iter().map(|t| !t).collect::>(); - let auc = super::roc_auc_score(&y_true, &y_score).unwrap(); - assert_eq!(auc, 0.0); - } -} diff --git a/crates/abd-clam/src/chaoda/training/algorithms/cc.rs b/crates/abd-clam/src/chaoda/training/algorithms/cc.rs index c7643dc8a..2f015b9d5 100644 --- a/crates/abd-clam/src/chaoda/training/algorithms/cc.rs +++ b/crates/abd-clam/src/chaoda/training/algorithms/cc.rs @@ -1,22 +1,22 @@ //! Cluster Cardinality algorithm. use distances::Number; -use serde::{Deserialize, Serialize}; -use crate::{chaoda::Graph, Cluster, Dataset}; +use crate::{chaoda::Graph, Cluster}; use super::GraphEvaluator; /// `Cluster`s with relatively few points are more likely to be anomalous. 
-#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone)] +#[cfg_attr(feature = "disk-io", derive(serde::Serialize, serde::Deserialize))] pub struct ClusterCardinality; -impl, S: Cluster> GraphEvaluator for ClusterCardinality { +impl> GraphEvaluator for ClusterCardinality { fn name(&self) -> &str { "cc" } - fn evaluate_clusters(&self, g: &Graph) -> Vec { + fn evaluate_clusters(&self, g: &Graph) -> Vec { g.iter_clusters().map(|c| -c.cardinality().as_f32()).collect() } diff --git a/crates/abd-clam/src/chaoda/training/algorithms/gn.rs b/crates/abd-clam/src/chaoda/training/algorithms/gn.rs index 178582bc5..234268c95 100644 --- a/crates/abd-clam/src/chaoda/training/algorithms/gn.rs +++ b/crates/abd-clam/src/chaoda/training/algorithms/gn.rs @@ -1,14 +1,14 @@ //! Graph Neighborhood Algorithm use distances::Number; -use serde::{Deserialize, Serialize}; -use crate::{chaoda::Graph, Cluster, Dataset}; +use crate::{chaoda::Graph, Cluster}; use super::GraphEvaluator; /// `Cluster`s in an isolated neighborhood are more likely to be anomalous. -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone)] +#[cfg_attr(feature = "disk-io", derive(serde::Serialize, serde::Deserialize))] pub struct GraphNeighborhood { /// The fraction of graph diameter to use as the neighborhood radius. diameter_fraction: f32, @@ -29,12 +29,12 @@ impl GraphNeighborhood { } } -impl, S: Cluster> GraphEvaluator for GraphNeighborhood { +impl> GraphEvaluator for GraphNeighborhood { fn name(&self) -> &str { "gn" } - fn evaluate_clusters(&self, g: &Graph) -> Vec { + fn evaluate_clusters(&self, g: &Graph) -> Vec { let diameter = g.diameter(); #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)] let k = (self.diameter_fraction * diameter.as_f32()).round() as usize; diff --git a/crates/abd-clam/src/chaoda/training/algorithms/mod.rs b/crates/abd-clam/src/chaoda/training/algorithms/mod.rs index 423ea232a..b72a296c2 100644 --- a/crates/abd-clam/src/chaoda/training/algorithms/mod.rs +++ b/crates/abd-clam/src/chaoda/training/algorithms/mod.rs @@ -8,12 +8,12 @@ mod sp; mod vd; use distances::Number; -use serde::{Deserialize, Serialize}; -use crate::{chaoda::Graph, utils, Cluster, Dataset}; +use crate::{chaoda::Graph, utils, Cluster}; /// The algorithms that make up the CHAODA ensemble. -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone)] +#[cfg_attr(feature = "disk-io", derive(serde::Serialize, serde::Deserialize))] pub enum GraphAlgorithm { /// The Cluster Cardinality algorithm. 
CC(cc::ClusterCardinality), @@ -73,25 +73,25 @@ impl GraphAlgorithm { Self::GN(_) => "GN", Self::PC(_) => "PC", Self::SC(_) => "SC", - Self::SP(_) => "SQ", + Self::SP(_) => "SP", Self::VD(_) => "VD", } } } -impl, S: Cluster> GraphEvaluator for GraphAlgorithm { +impl> GraphEvaluator for GraphAlgorithm { fn name(&self) -> &str { match self { - Self::CC(a) => >::name(a), - Self::GN(a) => >::name(a), - Self::PC(a) => >::name(a), - Self::SC(a) => >::name(a), - Self::SP(a) => >::name(a), - Self::VD(a) => >::name(a), + Self::CC(a) => >::name(a), + Self::GN(a) => >::name(a), + Self::PC(a) => >::name(a), + Self::SC(a) => >::name(a), + Self::SP(a) => >::name(a), + Self::VD(a) => >::name(a), } } - fn evaluate_clusters(&self, g: &Graph) -> Vec { + fn evaluate_clusters(&self, g: &Graph) -> Vec { match self { Self::CC(a) => a.evaluate_clusters(g), Self::GN(a) => a.evaluate_clusters(g), @@ -104,18 +104,18 @@ impl, S: Cluster> GraphEvaluator bool { match self { - Self::CC(a) => >::normalize_by_cluster(a), - Self::GN(a) => >::normalize_by_cluster(a), - Self::PC(a) => >::normalize_by_cluster(a), - Self::SC(a) => >::normalize_by_cluster(a), - Self::SP(a) => >::normalize_by_cluster(a), - Self::VD(a) => >::normalize_by_cluster(a), + Self::CC(a) => >::normalize_by_cluster(a), + Self::GN(a) => >::normalize_by_cluster(a), + Self::PC(a) => >::normalize_by_cluster(a), + Self::SC(a) => >::normalize_by_cluster(a), + Self::SP(a) => >::normalize_by_cluster(a), + Self::VD(a) => >::normalize_by_cluster(a), } } } /// A trait for how a `Graph` should be evaluated into anomaly scores. -pub trait GraphEvaluator, S: Cluster> { +pub trait GraphEvaluator> { /// Get the name of the algorithm. fn name(&self) -> &str; @@ -125,13 +125,13 @@ pub trait GraphEvaluator, S: Cluster> { /// The output vector must be the same length as the number of `OddBall`s in /// the `Graph`, and the order of the scores must correspond to the order of the /// `OddBall`s in the `Graph`. - fn evaluate_clusters(&self, g: &Graph) -> Vec; + fn evaluate_clusters(&self, g: &Graph) -> Vec; /// Whether to normalize anomaly scores by cluster or by point. fn normalize_by_cluster(&self) -> bool; /// Have points inherit scores from `OddBall`s. - fn inherit_scores(&self, g: &Graph, scores: &[f32]) -> Vec { + fn inherit_scores(&self, g: &Graph, scores: &[f32]) -> Vec { let mut points_scores = vec![0.0; g.population()]; for (c, &s) in g.iter_clusters().zip(scores.iter()) { for i in c.indices() { @@ -151,7 +151,7 @@ pub trait GraphEvaluator, S: Cluster> { /// # Returns /// /// * A vector of anomaly scores for each point in the `Graph`. - fn evaluate_points(&self, g: &Graph) -> Vec { + fn evaluate_points(&self, g: &Graph) -> Vec { let cluster_scores = { let scores = self.evaluate_clusters(g); if self.normalize_by_cluster() { diff --git a/crates/abd-clam/src/chaoda/training/algorithms/pc.rs b/crates/abd-clam/src/chaoda/training/algorithms/pc.rs index b4f84e2d6..b582bca30 100644 --- a/crates/abd-clam/src/chaoda/training/algorithms/pc.rs +++ b/crates/abd-clam/src/chaoda/training/algorithms/pc.rs @@ -1,22 +1,23 @@ //! Relative Parent Cardinality algorithm. use distances::Number; -use serde::{Deserialize, Serialize}; -use crate::{chaoda::Graph, Cluster, Dataset}; +use crate::{chaoda::Graph, Cluster}; use super::GraphEvaluator; -/// `Cluster`s with a smaller fraction of points from their parent `Cluster` are more anomalous. -#[derive(Clone, Serialize, Deserialize)] +/// `Cluster`s with a smaller fraction of points from their parent `Cluster` are +/// more anomalous. 
+#[derive(Clone)] +#[cfg_attr(feature = "disk-io", derive(serde::Serialize, serde::Deserialize))] pub struct ParentCardinality; -impl, S: Cluster> GraphEvaluator for ParentCardinality { +impl> GraphEvaluator for ParentCardinality { fn name(&self) -> &str { "pc" } - fn evaluate_clusters(&self, g: &Graph) -> Vec { + fn evaluate_clusters(&self, g: &Graph) -> Vec { g.iter_accumulated_cp_car_ratios().collect() } diff --git a/crates/abd-clam/src/chaoda/training/algorithms/sc.rs b/crates/abd-clam/src/chaoda/training/algorithms/sc.rs index d66ff9388..636209bf7 100644 --- a/crates/abd-clam/src/chaoda/training/algorithms/sc.rs +++ b/crates/abd-clam/src/chaoda/training/algorithms/sc.rs @@ -1,22 +1,23 @@ //! Subgraph Cardinality algorithm. use distances::Number; -use serde::{Deserialize, Serialize}; -use crate::{chaoda::Graph, Cluster, Dataset}; +use crate::{chaoda::Graph, Cluster}; use super::GraphEvaluator; -/// `Cluster`s in subgraphs with relatively small population are more likely to be anomalous. -#[derive(Clone, Serialize, Deserialize)] +/// `Cluster`s in subgraphs with relatively small population are more likely to +/// be anomalous. +#[derive(Clone)] +#[cfg_attr(feature = "disk-io", derive(serde::Serialize, serde::Deserialize))] pub struct SubgraphCardinality; -impl, S: Cluster> GraphEvaluator for SubgraphCardinality { +impl> GraphEvaluator for SubgraphCardinality { fn name(&self) -> &str { "sc" } - fn evaluate_clusters(&self, g: &Graph) -> Vec { + fn evaluate_clusters(&self, g: &Graph) -> Vec { g.iter_components() .flat_map(|sg| { let p = -sg.population().as_f32(); diff --git a/crates/abd-clam/src/chaoda/training/algorithms/sp.rs b/crates/abd-clam/src/chaoda/training/algorithms/sp.rs index a56e734e7..26eee6b1b 100644 --- a/crates/abd-clam/src/chaoda/training/algorithms/sp.rs +++ b/crates/abd-clam/src/chaoda/training/algorithms/sp.rs @@ -1,14 +1,14 @@ //! Stationary Probabilities Algorithm. use distances::Number; -use serde::{Deserialize, Serialize}; -use crate::{chaoda::Graph, Cluster, Dataset}; +use crate::{chaoda::Graph, Cluster}; use super::GraphEvaluator; /// Clusters with smaller stationary probabilities are more anomalous. -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone)] +#[cfg_attr(feature = "disk-io", derive(serde::Serialize, serde::Deserialize))] pub struct StationaryProbability { /// The Random Walk will be simulated for 2^`num_steps` steps. num_steps: usize, @@ -25,12 +25,12 @@ impl StationaryProbability { } } -impl, S: Cluster> GraphEvaluator for StationaryProbability { +impl> GraphEvaluator for StationaryProbability { fn name(&self) -> &str { "sp" } - fn evaluate_clusters(&self, g: &Graph) -> Vec { + fn evaluate_clusters(&self, g: &Graph) -> Vec { g.compute_stationary_probabilities(self.num_steps) .into_iter() .map(|x| 1.0 - x) diff --git a/crates/abd-clam/src/chaoda/training/algorithms/vd.rs b/crates/abd-clam/src/chaoda/training/algorithms/vd.rs index 69d7669cd..437a7df75 100644 --- a/crates/abd-clam/src/chaoda/training/algorithms/vd.rs +++ b/crates/abd-clam/src/chaoda/training/algorithms/vd.rs @@ -1,22 +1,22 @@ //! Vertex Degree Algorithm use distances::Number; -use serde::{Deserialize, Serialize}; -use crate::{chaoda::Graph, Cluster, Dataset}; +use crate::{chaoda::Graph, Cluster}; use super::GraphEvaluator; /// `Cluster`s with relatively few neighbors are more likely to be anomalous. 
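Each of these unit structs plugs into the `GraphEvaluator` trait the same way. A sketch with `VertexDegree`, assuming a `graph` built as in the constructors above and that the struct is exported alongside `GraphAlgorithm`:

```rust
// Sketch only; `graph: Graph<'_, T, S>` is assumed to be in scope.
let vd = VertexDegree;
// Raw per-cluster scores: fewer neighbors means more negative, hence more anomalous.
let cluster_scores = vd.evaluate_clusters(&graph);
// Per-point scores, inherited from the clusters that contain each point.
let point_scores = vd.evaluate_points(&graph);
assert_eq!(point_scores.len(), graph.population());
```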
-#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone)] +#[cfg_attr(feature = "disk-io", derive(serde::Serialize, serde::Deserialize))] pub struct VertexDegree; -impl, S: Cluster> GraphEvaluator for VertexDegree { +impl> GraphEvaluator for VertexDegree { fn name(&self) -> &str { "vd" } - fn evaluate_clusters(&self, g: &Graph) -> Vec { + fn evaluate_clusters(&self, g: &Graph) -> Vec { g.iter_neighbors().map(|n| -n.len().as_f32()).collect() } diff --git a/crates/abd-clam/src/chaoda/training/combination.rs b/crates/abd-clam/src/chaoda/training/combination.rs index a87e5a1bb..c1849190b 100644 --- a/crates/abd-clam/src/chaoda/training/combination.rs +++ b/crates/abd-clam/src/chaoda/training/combination.rs @@ -9,7 +9,7 @@ use crate::{ chaoda::{ inference, roc_auc_score, training::GraphEvaluator, Graph, GraphAlgorithm, TrainableMetaMlModel, NUM_RATIOS, }, - Cluster, Dataset, + Cluster, }; /// A trainable combination of a `MetaMLModel` and a `GraphAlgorithm`. @@ -69,15 +69,10 @@ impl TrainableCombination { /// # Errors /// /// - If any roc-auc score calculation fails. - pub fn data_from_graph( - &self, - graph: &Graph, - labels: &[bool], - ) -> Result<([Vec; 2], f32), String> + pub fn data_from_graph(&self, graph: &Graph, labels: &[bool]) -> Result<([Vec; 2], f32), String> where - U: Number, - D: Dataset, - S: Cluster, + T: Number, + S: Cluster, { let props = graph.iter_anomaly_properties().flatten().collect::>(); let predictions = self.graph_algorithm.evaluate_points(graph); @@ -86,15 +81,16 @@ impl TrainableCombination { .iter_clusters() .map(|c| { // Get the labels and predictions and append a dummy true and false value to avoid empty classes for roc_auc_score - let y_true = c - .indices() - .map(|i| labels[i]) + let indices = c.indices(); + let y_true = indices + .iter() + .map(|&i| labels[i]) .chain(std::iter::once(true)) .chain(std::iter::once(false)) .collect::>(); - let y_pred = c - .indices() - .map(|i| predictions[i]) + let y_pred = indices + .iter() + .map(|&i| predictions[i]) .chain(std::iter::once(1.0)) .chain(std::iter::once(0.0)) .collect::>(); diff --git a/crates/abd-clam/src/chaoda/training/mod.rs b/crates/abd-clam/src/chaoda/training/mod.rs index e8ec3282e..c8e3cbdba 100644 --- a/crates/abd-clam/src/chaoda/training/mod.rs +++ b/crates/abd-clam/src/chaoda/training/mod.rs @@ -4,23 +4,29 @@ use distances::Number; use rayon::prelude::*; use crate::{ - adapter::{Adapter, ParAdapter}, - chaoda::inference::TrainedCombination, - cluster::ParCluster, + cluster::{ + adapter::{Adapter, ParAdapter}, + ParCluster, ParPartition, Partition, + }, dataset::ParDataset, - partition::ParPartition, - Cluster, Dataset, Metric, Partition, + metric::ParMetric, + Cluster, Dataset, +}; + +use super::{ + inference::{Chaoda, TrainedCombination}, + Graph, Vertex, }; mod algorithms; mod combination; mod meta_ml; +mod trainable_smc; pub use algorithms::{GraphAlgorithm, GraphEvaluator}; pub use combination::TrainableCombination; pub use meta_ml::TrainableMetaMlModel; - -use super::{inference::Chaoda, Graph, Vertex}; +pub use trainable_smc::TrainableSmc; /// A trainer for Chaoda models. /// @@ -29,14 +35,14 @@ use super::{inference::Chaoda, Graph, Vertex}; /// - `I`: The type of the input data. /// - `U`: The type of the distance values. /// - `M`: The number of metrics to train with. -pub struct ChaodaTrainer { +pub struct ChaodaTrainer { /// The distance metrics to train with. - metrics: [Metric; M], + metrics: [Box>; M], /// The combinations of `MetaMLModel`s and `GraphAlgorithm`s to train with. 
combinations: [Vec; M], } -impl ChaodaTrainer { +impl ChaodaTrainer { /// Create a new `ChaodaTrainer` with the given metrics and all pairs of /// `MetaMLModel`s and `GraphAlgorithm`s. /// @@ -48,7 +54,7 @@ impl ChaodaTrainer { #[must_use] #[allow(clippy::needless_pass_by_value)] pub fn new_all_pairs( - metrics: [Metric; M], + metrics: [Box>; M], meta_ml_models: Vec, graph_algorithms: Vec, ) -> Self { @@ -79,7 +85,7 @@ impl ChaodaTrainer { /// - `combinations`: The combinations of `MetaMLModel`s and `GraphAlgorithm`s to train with. #[must_use] #[allow(clippy::needless_pass_by_value)] - pub fn new(metrics: [Metric; M], combinations: Vec) -> Self { + pub fn new(metrics: [Box>; M], combinations: Vec) -> Self { let combinations = metrics .iter() .map(|_| combinations.clone()) @@ -109,19 +115,18 @@ impl ChaodaTrainer { /// - `D`: The type of the datasets. /// - `S`: The type of the `Cluster` to use for creating the `Vertex` trees. /// - `C`: The type of the criteria to use for creating the trees. - pub fn create_trees, S: Partition, C: Fn(&S) -> bool>( + pub fn create_trees, S: Partition, C: Fn(&S) -> bool>( &self, - datasets: &mut [D; N], + datasets: &[D; N], criteria: &[[C; M]; N], seed: Option, - ) -> [[Vertex; M]; N] { + ) -> [[Vertex; M]; N] { let mut trees = Vec::new(); - for (data, criteria) in datasets.iter_mut().zip(criteria) { + for (data, criteria) in datasets.iter().zip(criteria) { let mut metric_trees = Vec::new(); for (metric, criteria) in self.metrics.iter().zip(criteria) { - data.set_metric(metric.clone()); - let root = S::new_tree(data, criteria, seed); - metric_trees.push(Vertex::adapt_tree(root, None)); + let root = S::new_tree(data, metric, criteria, seed); + metric_trees.push(Vertex::adapt_tree(root, None, data, metric)); } let metric_trees = metric_trees.try_into().unwrap_or_else(|v: Vec<_>| { unreachable!("Could not convert Vec to [Vertex; {M}]. Len was {}", v.len()) @@ -138,21 +143,20 @@ impl ChaodaTrainer { /// Create graphs for use in training the first epoch. #[allow(clippy::type_complexity)] - fn create_flat_graphs<'a, const N: usize, D: Dataset, S: Cluster>( + fn create_flat_graphs<'a, const N: usize, D: Dataset, S: Cluster>( &self, - datasets: &mut [D; N], - trees: &'a [[Vertex; M]; N], + datasets: &[D; N], + trees: &'a [[Vertex; M]; N], depths: &[usize], - ) -> [[Vec>; M]; N] { + ) -> [[Vec>; M]; N] { let mut graphs_vmn = Vec::new(); ftlog::info!("Creating flat graphs..."); - for (i, (data, roots)) in datasets.iter_mut().zip(trees.iter()).enumerate() { + for (i, (data, roots)) in datasets.iter().zip(trees.iter()).enumerate() { let mut graphs_vm = Vec::new(); ftlog::info!("Creating flat graphs for dataset {}/{N}...", i + 1); for (j, (metric, root)) in self.metrics.iter().zip(roots.iter()).enumerate() { - data.set_metric(metric.clone()); ftlog::info!( "Creating flat graphs for dataset {}/{N}, metric {}/{M}...", i + 1, @@ -167,7 +171,7 @@ impl ChaodaTrainer { j + 1, depths.len() ); - let cluster_scorer = |clusters: &[&Vertex]| { + let cluster_scorer = |clusters: &[&Vertex]| { clusters .iter() .map(|c| { @@ -179,7 +183,7 @@ impl ChaodaTrainer { }) .collect::>() }; - graphs_v.push(Graph::from_root(root, data, cluster_scorer, depth)); + graphs_v.push(Graph::from_root(root, data, metric, cluster_scorer, depth)); ftlog::info!( "Finished flat graphs for dataset {}/{N}, metric {}/{M}, depth {depth}/{}", i + 1, @@ -232,17 +236,17 @@ impl ChaodaTrainer { /// combination of `Dataset`, `Metric`, and pair of `MetaMLModel` and /// `GraphAlgorithm`. 
#[allow(clippy::type_complexity)] - fn create_graphs<'a, const N: usize, D: Dataset, S: Cluster>( + fn create_graphs<'a, const N: usize, D: Dataset, S: Cluster>( &self, - datasets: &mut [D; N], - trees: &'a [[Vertex; M]; N], + datasets: &[D; N], + trees: &'a [[Vertex; M]; N], trained_models: &[Vec; M], min_depth: usize, - ) -> [[Vec>; M]; N] { + ) -> [[Vec>; M]; N] { let mut graphs_vmn = Vec::new(); ftlog::info!("Creating graphs..."); - for (i, (data, trees)) in datasets.iter_mut().zip(trees.iter()).enumerate() { + for (i, (data, trees)) in datasets.iter().zip(trees.iter()).enumerate() { let mut graphs_vm = Vec::new(); ftlog::info!("Creating graphs for dataset {}/{N}...", i + 1); @@ -253,7 +257,6 @@ impl ChaodaTrainer { .zip(trained_models.iter()) .enumerate() { - data.set_metric(metric.clone()); ftlog::info!("Creating graphs for dataset {}/{N}, metric {}/{M}...", i + 1, j + 1); graphs_vm.push( trained_models @@ -267,7 +270,7 @@ impl ChaodaTrainer { k + 1 ); }) - .map(|(k, combination)| (k, combination.create_graph(root, data, min_depth))) + .map(|(k, combination)| (k, combination.create_graph(root, data, metric, min_depth))) .inspect(|(k, _)| { ftlog::info!( "Finished graph for dataset {}/{N}, metric {}/{M}, model {}/{M}...", @@ -314,9 +317,9 @@ impl ChaodaTrainer { /// /// The trained combinations and the mean roc-auc score. #[allow(clippy::type_complexity)] - fn train_epoch, S: Cluster>( + fn train_epoch>( &mut self, - graphs: &[[Vec>; M]; N], + graphs: &[[Vec>; M]; N], labels: &[Vec; N], ) -> Result<([Vec; M], f32), String> { let mut x = Vec::new(); @@ -384,14 +387,14 @@ impl ChaodaTrainer { /// - `N`: The number of datasets to train with. /// - `D`: The type of the datasets. /// - `S`: The type of the `Cluster` that were used to create the `Vertex` trees. - pub fn train, S: Cluster>( + pub fn train, S: Cluster>( &mut self, - datasets: &mut [D; N], - trees: &[[Vertex; M]; N], + datasets: &[D; N], + trees: &[[Vertex; M]; N], labels: &[Vec; N], min_depth: usize, num_epochs: usize, - ) -> Result, String> { + ) -> Result, String> { if min_depth == 0 { return Err("Minimum depth must be greater than 0.".to_string()); } @@ -427,30 +430,25 @@ impl ChaodaTrainer { ); } - Ok(Chaoda::new(self.metrics.clone(), trained_combinations)) + todo!() + // Ok(Chaoda::new(self.metrics.clone(), trained_combinations)) } } -impl ChaodaTrainer { - /// Parallel version of `create_trees`. - pub fn par_create_trees< - const N: usize, - D: ParDataset, - S: ParPartition, - C: (Fn(&S) -> bool) + Send + Sync, - >( +impl ChaodaTrainer { + /// Parallel version of [`ChaodaTrainer::create_trees`](crate::chaoda::training::ChaodaTrainer::create_trees). + pub fn par_create_trees, S: ParPartition, C: (Fn(&S) -> bool) + Send + Sync>( &self, - datasets: &mut [D; N], + datasets: &[D; N], criteria: &[[C; M]; N], seed: Option, - ) -> [[Vertex; M]; N] { + ) -> [[Vertex; M]; N] { let mut trees = Vec::new(); - for (data, criteria) in datasets.iter_mut().zip(criteria) { + for (data, criteria) in datasets.iter().zip(criteria) { let mut metric_trees = Vec::new(); for (metric, criteria) in self.metrics.iter().zip(criteria) { - data.set_metric(metric.clone()); - let root = S::par_new_tree(data, criteria, seed); - metric_trees.push(Vertex::par_adapt_tree(root, None)); + let root = S::par_new_tree(data, metric, criteria, seed); + metric_trees.push(Vertex::par_adapt_tree(root, None, data, metric)); } let metric_trees = metric_trees.try_into().unwrap_or_else(|v: Vec<_>| { unreachable!("Could not convert Vec to [Vertex; {M}]. 
Len was {}", v.len()) @@ -465,23 +463,22 @@ impl ChaodaTrainer { }) } - /// Parallel version of `create_flat_graphs`. + /// Parallel version of [`ChaodaTrainer::create_flat_graphs`](crate::chaoda::training::ChaodaTrainer::create_flat_graphs). #[allow(clippy::type_complexity)] - fn par_create_flat_graphs<'a, const N: usize, D: ParDataset, S: ParCluster>( + fn par_create_flat_graphs<'a, const N: usize, D: ParDataset, S: ParCluster>( &self, - datasets: &mut [D; N], - trees: &'a [[Vertex; M]; N], + datasets: &[D; N], + trees: &'a [[Vertex; M]; N], depths: &[usize], - ) -> [[Vec>; M]; N] { + ) -> [[Vec>; M]; N] { let mut graphs_vmn = Vec::new(); ftlog::info!("Creating flat graphs..."); - for (i, (data, roots)) in datasets.iter_mut().zip(trees.iter()).enumerate() { + for (i, (data, roots)) in datasets.iter().zip(trees.iter()).enumerate() { let mut graphs_vm = Vec::new(); ftlog::info!("Creating flat graphs for dataset {}/{N}...", i + 1); for (j, (metric, root)) in self.metrics.iter().zip(roots.iter()).enumerate() { - data.set_metric(metric.clone()); ftlog::info!( "Creating flat graphs for dataset {}/{N}, metric {}/{M}...", i + 1, @@ -500,7 +497,7 @@ impl ChaodaTrainer { ); }) .map(|&depth| { - let cluster_scorer = |clusters: &[&Vertex]| { + let cluster_scorer = |clusters: &[&Vertex]| { clusters .iter() .map(|c| { @@ -512,7 +509,7 @@ impl ChaodaTrainer { }) .collect::>() }; - let graph = Graph::par_from_root(root, data, cluster_scorer, depth); + let graph = Graph::par_from_root(root, data, metric, cluster_scorer, depth); (depth, graph) }) .inspect(|&(depth, _)| { @@ -555,19 +552,19 @@ impl ChaodaTrainer { graphs_vmn } - /// Parallel version of `create_graphs`. + /// Parallel version of [`ChaodaTrainer::create_graphs`](crate::chaoda::training::ChaodaTrainer::create_graphs). #[allow(clippy::type_complexity)] - fn par_create_graphs<'a, const N: usize, D: ParDataset, S: ParCluster>( + fn par_create_graphs<'a, const N: usize, D: ParDataset, S: ParCluster>( &self, - datasets: &mut [D; N], - trees: &'a [[Vertex; M]; N], + datasets: &[D; N], + trees: &'a [[Vertex; M]; N], trained_models: &[Vec; M], min_depth: usize, - ) -> [[Vec>; M]; N] { + ) -> [[Vec>; M]; N] { let mut graphs_vmn = Vec::new(); ftlog::info!("Creating graphs..."); - for (i, (data, trees)) in datasets.iter_mut().zip(trees.iter()).enumerate() { + for (i, (data, trees)) in datasets.iter().zip(trees.iter()).enumerate() { let mut graphs_vm = Vec::new(); ftlog::info!("Creating graphs for dataset {}/{N}...", i + 1); @@ -578,7 +575,6 @@ impl ChaodaTrainer { .zip(trained_models.iter()) .enumerate() { - data.set_metric(metric.clone()); ftlog::info!("Creating graphs for dataset {}/{N}, metric {}/{M}...", i + 1, j + 1); graphs_vm.push( trained_models @@ -592,7 +588,7 @@ impl ChaodaTrainer { k + 1 ); }) - .map(|(k, combination)| (k, combination.par_create_graph(root, data, min_depth))) + .map(|(k, combination)| (k, combination.par_create_graph(root, data, metric, min_depth))) .inspect(|(k, _)| { ftlog::info!( "Finished graph for dataset {}/{N}, metric {}/{M}, model {}/{M}...", @@ -628,11 +624,11 @@ impl ChaodaTrainer { graphs_vmn } - /// Parallel version of `train_epoch`. + /// Parallel version of [`ChaodaTrainer::train_epoch`](crate::chaoda::training::ChaodaTrainer::train_epoch). 
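For reference while reading the parallel variants, the serial training flow now threads metrics explicitly and borrows the datasets immutably. A sketch, assuming a `ChaodaTrainer`, `N` labeled datasets, and per-dataset, per-metric criteria; note that in this work-in-progress commit `train` still ends in a `todo!()`:

```rust
// Sketch only; all inputs are assumed.
let mut trainer = ChaodaTrainer::new_all_pairs(metrics, meta_ml_models, graph_algorithms);
// One tree per (dataset, metric) pair.
let trees = trainer.create_trees(&datasets, &criteria, Some(42));
// Returns Result<Chaoda<I, M>, String> once the trailing `todo!()` is
// replaced with the commented-out `Chaoda::new` call.
let model = trainer.train(&datasets, &trees, &labels, 4, 16);
```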
#[allow(clippy::type_complexity)] - fn par_train_epoch, S: ParCluster>( + fn par_train_epoch>( &mut self, - graphs: &[[Vec>; M]; N], + graphs: &[[Vec>; M]; N], labels: &[Vec; N], ) -> Result<([Vec; M], f32), String> { let mut x = Vec::new(); @@ -690,19 +686,19 @@ impl ChaodaTrainer { Ok((trained_combinations, roc_score)) } - /// Parallel version of `train`. + /// Parallel version of [`ChaodaTrainer::train`](crate::chaoda::training::ChaodaTrainer::train). /// /// # Errors /// - /// See `ChaodaTrainer::train`. - pub fn par_train, S: ParCluster>( + /// See [`ChaodaTrainer::train`](crate::chaoda::training::ChaodaTrainer::train). + pub fn par_train, S: ParCluster>( &mut self, - datasets: &mut [D; N], - trees: &[[Vertex; M]; N], + datasets: &[D; N], + trees: &[[Vertex; M]; N], labels: &[Vec; N], min_depth: usize, num_epochs: usize, - ) -> Result, String> { + ) -> Result, String> { if min_depth == 0 { return Err("Minimum depth must be greater than 0.".to_string()); } @@ -740,6 +736,7 @@ impl ChaodaTrainer { ); } - Ok(Chaoda::new(self.metrics.clone(), trained_combinations)) + todo!() + // Ok(Chaoda::new(self.metrics.clone(), trained_combinations)) } } diff --git a/crates/abd-clam/src/chaoda/training/trainable_smc.rs b/crates/abd-clam/src/chaoda/training/trainable_smc.rs new file mode 100644 index 000000000..94bf5a8d9 --- /dev/null +++ b/crates/abd-clam/src/chaoda/training/trainable_smc.rs @@ -0,0 +1,611 @@ +//! A trainable Single-Metric-CHAODA ensemble. + +use distances::Number; +use rayon::prelude::*; + +use crate::{ + chaoda::{inference::TrainedCombination, Graph, TrainedSmc, Vertex}, + cluster::{ + adapter::{Adapter, ParAdapter}, + ParCluster, ParPartition, Partition, + }, + dataset::ParDataset, + metric::ParMetric, + utils, Cluster, Dataset, Metric, +}; + +use super::{GraphAlgorithm, TrainableCombination, TrainableMetaMlModel}; + +/// A trainable Single-Metric-CHAODA ensemble. +pub struct TrainableSmc(Vec); + +impl TrainableSmc { + /// Create a new trainable Single-Metric-CHAODA ensemble. + /// + /// # Arguments + /// + /// - `meta_ml_models`: The `MetaMLModel`s to train with. + /// - `graph_algorithms`: The `GraphAlgorithm`s to train with. + /// + /// # Returns + /// + /// The trainable Single-Metric-CHAODA ensemble using all pairs of the given + /// `MetaMLModel`s and `GraphAlgorithm`s. + #[must_use] + pub fn new(meta_ml_models: &[TrainableMetaMlModel], graph_algorithms: &[GraphAlgorithm]) -> Self { + Self( + meta_ml_models + .iter() + .flat_map(|meta_ml_model| { + graph_algorithms.iter().map(move |graph_algorithm| { + TrainableCombination::new(meta_ml_model.clone(), graph_algorithm.clone()) + }) + }) + .collect(), + ) + } + + /// Get the combinations of `MetaMLModel`s and `GraphAlgorithm`s to train + /// with. + #[must_use] + pub fn combinations(&self) -> &[TrainableCombination] { + &self.0 + } + + /// Create trees for use in training. + /// + /// # Arguments + /// + /// - `datasets`: The datasets to train with. + /// - `criteria`: The criteria to use for building a tree. + /// - `metric`: The metric to use for all datasets. + /// - `seed`: The seed to use for random number generation. + /// + /// # Returns + /// + /// An array of root vertices for the trees created for each dataset. + /// + /// # Type Parameters + /// + /// - `N`: The number of datasets, criteria, and root vertices. + /// - `I`: The type of the items in the datasets. + /// - `T`: The type of the distance values. + /// - `D`: The type of the datasets. + /// - `M`: The type of the metric. 
+ /// - `S`: The type of the source cluster which will be added to `Vertex`. + /// - `C`: The type of the criteria function for building a tree. + pub fn create_trees( + &self, + datasets: &[D; N], + criteria: &[C; N], + metric: &M, + seed: Option, + ) -> [Vertex; N] + where + T: Number, + D: Dataset, + M: Metric, + S: Partition, + C: Fn(&S) -> bool, + { + let trees = datasets + .iter() + .enumerate() + .zip(criteria.iter()) + .inspect(|((i, _), _)| ftlog::info!("Creating tree for dataset {}/{N}", i + 1)) + .map(|((_, data), c)| { + let root = S::new_tree(data, metric, c, seed); + Vertex::adapt_tree(root, None, data, metric) + }) + .collect::>(); + trees + .try_into() + .unwrap_or_else(|e: Vec<_>| unreachable!("Expected {N} trees. Got {} instead", e.len())) + } + + /// Create graphs to use in the first epoch of training. + /// + /// These will be made of vertices with uniform depth from the trees. + /// + /// # Arguments + /// + /// - `datasets`: The datasets to train with. + /// - `metric`: The metric to use for all datasets. + /// - `roots`: The root vertices of the trees. + /// - `depths`: The depths of the vertices to create. + /// + /// # Returns + /// + /// An array of graphs (one at each depth) created for each dataset. + fn create_flat_graphs<'a, const N: usize, I, T, D, M, S>( + datasets: &[D; N], + metric: &M, + roots: &'a [Vertex; N], + depths: &[usize], + min_depth: usize, + ) -> [Vec>; N] + where + T: Number, + D: Dataset, + M: Metric, + S: Cluster, + { + let graphs = datasets + .iter() + .enumerate() + .zip(roots.iter()) + .inspect(|((i, _), _)| ftlog::info!("Creating flat graphs for dataset {}/{N}", i + 1)) + .map(|((_, data), root)| { + depths + .iter() + .enumerate() + .inspect(|(j, d)| ftlog::info!("Creating flat graph at depth={d} {}/{}", j + 1, depths.len())) + .map(|(_, &depth)| Graph::from_root_uniform_depth(root, data, metric, depth, min_depth)) + .collect::>() + }) + .collect::>(); + + graphs + .try_into() + .unwrap_or_else(|e: Vec<_>| unreachable!("Expected {N} graphs. Got {} instead", e.len())) + } + + /// Create graphs for use in training after the first epoch. + /// + /// # Arguments + /// + /// - `datasets`: The datasets to train with. + /// - `metric`: The metric to use for all datasets. + /// - `roots`: The root vertices of the trees. + /// - `trained_models`: The trained model combinations to use for selecting + /// the best clusters. + /// - `min_depth`: The minimum depth to create graphs for. + fn create_graphs<'a, const N: usize, I, T, D, M, S>( + datasets: &[D; N], + metric: &M, + roots: &'a [Vertex; N], + trained_combinations: &[TrainedCombination], + min_depth: usize, + ) -> [Vec>; N] + where + T: Number, + D: Dataset, + M: Metric, + S: Cluster, + { + let graphs = datasets + .iter() + .enumerate() + .zip(roots.iter()) + .inspect(|((i, _), _)| ftlog::info!("Creating meta-ml graphs for dataset {}/{N}", i + 1)) + .map(|((_, data), root)| { + trained_combinations + .iter() + .enumerate() + .inspect(|(j, combination)| { + ftlog::info!( + "Creating meta-ml graph {}/{} with {}", + j + 1, + trained_combinations.len(), + combination.name() + ); + }) + .map(|(_, combination)| combination.create_graph(root, data, metric, min_depth)) + .collect::>() + }) + .collect::>(); + + graphs + .try_into() + .unwrap_or_else(|e: Vec<_>| unreachable!("Expected {N} graphs. Got {} instead", e.len())) + } + + /// Train the ensemble for a single epoch. + /// + /// # Arguments + /// + /// - `graphs`: The graphs to train with. + /// - `labels`: The labels for the graphs. 
+ /// + /// # Returns + /// + /// The trained model combinations and the mean roc-auc score. + /// + /// # Type Parameters + /// + /// - `N`: The number of datasets. + /// - `T`: The type of the distance values. + /// - `S`: The type of the source cluster which were adapted into `Vertex`. + /// + /// # Errors + /// + /// - If the meta-ml model fails to train. + fn train_epoch>( + &mut self, + graphs: &[Vec>; N], + labels: &[Vec; N], + ) -> Result<(Vec, f64), String> { + let num_combinations = self.0.len(); + + let (trained_combinations, roc_scores) = self + .0 + .iter_mut() + .enumerate() + .map(|(i, combination)| { + ftlog::info!( + "Training combination {}/{num_combinations} {}", + i + 1, + combination.name() + ); + + let (x, y, roc_scores_inner) = graphs.iter().enumerate().zip(labels.iter()).fold( + (Vec::new(), Vec::new(), Vec::new()), + |(mut x, mut y, mut roc_scores_inner), ((j, graphs_inner), labels_inner)| { + ftlog::info!("Training dataset {}/{N} with combination {}", j + 1, combination.name()); + + graphs_inner.iter().enumerate().for_each(|(k, graph)| { + ftlog::info!( + "Training graph {}/{} with combination {}", + k + 1, + graphs_inner.len(), + combination.name() + ); + let ([x_, y_], r) = combination + .data_from_graph(graph, labels_inner) + .unwrap_or_else(|e| unreachable!("{e}")); + x.extend(x_); + y.extend(y_); + roc_scores_inner.push(r); + }); + + (x, y, roc_scores_inner) + }, + ); + + ftlog::info!( + "Appending data for combination {}/{num_combinations} {}", + i + 1, + combination.name() + ); + let roc_score = utils::mean(&roc_scores_inner); + combination + .append_data(&x, &y, Some(roc_score)) + .unwrap_or_else(|e| unreachable!("{e}")); + + ftlog::info!( + "Taking training step for combination {}/{num_combinations} {}", + i + 1, + combination.name() + ); + + combination.train_step().map(|trained| (trained, roc_score)) + }) + .collect::, String>>()? + .into_iter() + .unzip::<_, _, Vec<_>, Vec<_>>(); + + let roc_score = utils::mean(&roc_scores); + Ok((trained_combinations, roc_score)) + } + + /// Train the ensemble. + /// + /// # Arguments + /// + /// - `datasets`: The datasets to train with. + /// - `metric`: The metric to use for all datasets. + /// - `roots`: The root vertices of the trees to use for training. + /// - `labels`: The labels for the datasets. + /// - `min_depth`: The minimum depth of clusters to consider for graphs. + /// - `depths`: The depths to create graphs for the first epoch. + /// - `num_epochs`: The number of epochs to train for. + /// + /// # Returns + /// + /// The trained ensemble. + /// + /// # Type Parameters + /// + /// - `N`: The number of datasets. + /// - `I`: The type of the items in the datasets. + /// - `T`: The type of the distance values. + /// - `D`: The type of the datasets. + /// - `M`: The type of the metric. + /// - `S`: The type of the source cluster which will be added to `Vertex`. + /// - `C`: The type of the criteria function for building a tree. + /// + /// # Errors + /// + /// - If the minimum depth is less than 1. + /// - If the number of number if labels is not equal to the cardinality of + /// the corresponding dataset. + /// - If and depth in `depths` is less than `min_depth`. + /// - If any meta-ml model fails to train. 
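
// NOTE (editorial sketch, not part of this patch): the `train` method
// defined just below follows a two-phase epoch schedule: the first epoch
// fits on graphs built from vertices at fixed depths ("flat" graphs), and
// every later epoch rebuilds its graphs with the combinations trained so
// far, so the cluster-selection policy is itself the output of the
// previous epoch. A runnable toy version of that control flow, with
// hypothetical `Model`, `build_flat`, `build_selected`, and `fit`
// stand-ins:

#[derive(Clone, Copy, Debug)]
struct Model(f32);

fn build_flat(depths: &[usize]) -> Vec<usize> {
    depths.to_vec() // one "graph" per requested depth
}

fn build_selected(model: Model, min_depth: usize) -> Vec<usize> {
    vec![min_depth + model.0 as usize] // graphs chosen by the trained model
}

fn fit(graphs: &[usize]) -> Result<Model, String> {
    if graphs.is_empty() {
        return Err("no graphs to train on".to_string());
    }
    Ok(Model(graphs.iter().sum::<usize>() as f32 / graphs.len() as f32))
}

fn train_toy(depths: &[usize], min_depth: usize, num_epochs: usize) -> Result<Model, String> {
    // Epoch 1: flat graphs at the requested depths.
    let mut model = fit(&build_flat(depths))?;
    // Later epochs: graphs selected by the model from the previous epoch.
    for _ in 1..num_epochs {
        model = fit(&build_selected(model, min_depth))?;
    }
    Ok(model)
}
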
+ #[allow(clippy::too_many_arguments)] + pub fn train( + &mut self, + datasets: &[D; N], + metric: &M, + roots: &[Vertex; N], + labels: &[Vec; N], + min_depth: usize, + depths: &[usize], + num_epochs: usize, + ) -> Result + where + T: Number, + D: Dataset, + M: Metric, + S: Cluster, + { + if min_depth == 0 { + return Err("Minimum depth must be greater than 0".to_string()); + } + + for (data, labels) in datasets.iter().zip(labels.iter()) { + if data.cardinality() != labels.len() { + return Err("Number of data points and labels must be equal".to_string()); + } + } + + for &d in depths { + if d < min_depth { + return Err("Depths must be no smaller than the minimum depth".to_string()); + } + } + + let flat_graphs = Self::create_flat_graphs(datasets, metric, roots, depths, min_depth); + let (mut trained_combinations, _) = self.train_epoch(&flat_graphs, labels)?; + + for _ in 1..num_epochs { + let graphs = Self::create_graphs(datasets, metric, roots, &trained_combinations, min_depth); + (trained_combinations, _) = self.train_epoch(&graphs, labels)?; + } + + Ok(TrainedSmc::new(trained_combinations)) + } + + /// Parallel version of [`TrainableSmc::create_trees`](crate::chaoda::training::TrainableSmc::create_trees). + pub fn par_create_trees( + &self, + datasets: &[D; N], + criteria: &[C; N], + metric: &M, + seed: Option, + ) -> [Vertex; N] + where + I: Send + Sync, + T: Number, + D: ParDataset, + M: ParMetric, + S: ParPartition, + C: (Fn(&S) -> bool) + Send + Sync, + { + let trees = datasets + .par_iter() + .enumerate() + .zip(criteria.par_iter()) + .inspect(|((i, _), _)| ftlog::info!("Creating tree for dataset {}/{N}", i + 1)) + .map(|((_, data), c)| { + let root = S::par_new_tree(data, metric, c, seed); + Vertex::par_adapt_tree(root, None, data, metric) + }) + .collect::>(); + trees + .try_into() + .unwrap_or_else(|e: Vec<_>| unreachable!("Expected {N} trees. Got {} instead", e.len())) + } + + /// Parallel version of [`TrainableSmc::create_flat_graphs`](crate::chaoda::training::TrainableSmc::create_flat_graphs). + fn par_create_flat_graphs<'a, const N: usize, I, T, D, M, S>( + datasets: &[D; N], + metric: &M, + roots: &'a [Vertex; N], + depths: &[usize], + min_depth: usize, + ) -> [Vec>; N] + where + I: Send + Sync, + T: Number, + D: ParDataset, + M: ParMetric, + S: ParCluster, + { + let graphs = datasets + .par_iter() + .enumerate() + .zip(roots.par_iter()) + .inspect(|((i, _), _)| ftlog::info!("Creating flat graphs for dataset {}/{N}", i + 1)) + .map(|((_, data), root)| { + depths + .par_iter() + .enumerate() + .inspect(|(j, d)| ftlog::info!("Creating flat graph at depth={d} {}/{}", j + 1, depths.len())) + .map(|(_, &depth)| Graph::par_from_root_uniform_depth(root, data, metric, depth, min_depth)) + .collect::>() + }) + .collect::>(); + + graphs + .try_into() + .unwrap_or_else(|e: Vec<_>| unreachable!("Expected {N} graphs. Got {} instead", e.len())) + } + + /// Parallel version of [`TrainableSmc::create_graphs`](crate::chaoda::training::TrainableSmc::create_graphs). 
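
// NOTE (editorial sketch, not part of this patch): `train` above and
// `par_train` later in this file guard the same three preconditions before
// doing any work: a positive minimum depth, one label per item in each
// dataset, and no requested graph depth below the minimum. Factored out as
// a free function (hypothetical; the crate inlines these checks), they are:

fn validate_train_inputs(
    cardinalities: &[usize],
    label_counts: &[usize],
    min_depth: usize,
    depths: &[usize],
) -> Result<(), String> {
    if min_depth == 0 {
        return Err("Minimum depth must be greater than 0".to_string());
    }
    if cardinalities.iter().zip(label_counts).any(|(c, l)| c != l) {
        return Err("Number of data points and labels must be equal".to_string());
    }
    if depths.iter().any(|&d| d < min_depth) {
        return Err("Depths must be no smaller than the minimum depth".to_string());
    }
    Ok(())
}
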
+ fn par_create_graphs<'a, const N: usize, I, T, D, M, S>( + datasets: &[D; N], + metric: &M, + roots: &'a [Vertex; N], + trained_combinations: &[TrainedCombination], + min_depth: usize, + ) -> [Vec>; N] + where + I: Send + Sync, + T: Number, + D: ParDataset, + M: ParMetric, + S: ParCluster, + { + let graphs = datasets + .par_iter() + .enumerate() + .zip(roots.par_iter()) + .inspect(|((i, _), _)| ftlog::info!("Creating meta-ml graphs for dataset {}/{N}", i + 1)) + .map(|((_, data), root)| { + trained_combinations + .par_iter() + .enumerate() + .inspect(|(j, combination)| { + ftlog::info!( + "Creating meta-ml graph {}/{} with {}", + j + 1, + trained_combinations.len(), + combination.name() + ); + }) + .map(|(_, combination)| combination.par_create_graph(root, data, metric, min_depth)) + .collect::>() + }) + .collect::>(); + + graphs + .try_into() + .unwrap_or_else(|e: Vec<_>| unreachable!("Expected {N} graphs. Got {} instead", e.len())) + } + + /// Parallel version of [`TrainableSmc::train_epoch`](crate::chaoda::training::TrainableSmc::train_epoch). + /// + /// # Errors + /// + /// - See [`TrainableSmc::train_epoch`](crate::chaoda::training::TrainableSmc::train_epoch). + fn par_train_epoch>( + &mut self, + graphs: &[Vec>; N], + labels: &[Vec; N], + ) -> Result<(Vec, f64), String> { + let num_combinations = self.0.len(); + + let (trained_combinations, roc_scores) = self + .0 + .par_iter_mut() + .enumerate() + .map(|(i, combination)| { + ftlog::info!( + "Training combination {}/{num_combinations} {}", + i + 1, + combination.name() + ); + + let (x, y, roc_scores_inner) = graphs + .par_iter() + .enumerate() + .zip(labels.par_iter()) + .fold( + || (Vec::new(), Vec::new(), Vec::new()), + |(mut x, mut y, mut roc_scores_inner), ((j, graphs_inner), labels_inner)| { + ftlog::info!("Training dataset {}/{N} with combination {}", j + 1, combination.name()); + + graphs_inner.iter().enumerate().for_each(|(k, graph)| { + ftlog::info!( + "Training graph {}/{} with combination {}", + k + 1, + graphs_inner.len(), + combination.name() + ); + let ([x_, y_], r) = combination + .data_from_graph(graph, labels_inner) + .unwrap_or_else(|e| unreachable!("{e}")); + x.extend(x_); + y.extend(y_); + roc_scores_inner.push(r); + }); + + (x, y, roc_scores_inner) + }, + ) + .reduce( + || (Vec::new(), Vec::new(), Vec::new()), + |(mut x1, mut y1, mut roc_scores_inner1), (x2, y2, roc_scores_inner2)| { + x1.extend(x2); + y1.extend(y2); + roc_scores_inner1.extend(roc_scores_inner2); + (x1, y1, roc_scores_inner1) + }, + ); + + ftlog::info!( + "Appending data for combination {}/{num_combinations} {}", + i + 1, + combination.name() + ); + let roc_score = utils::mean(&roc_scores_inner); + combination + .append_data(&x, &y, Some(roc_score)) + .unwrap_or_else(|e| unreachable!("{e}")); + + ftlog::info!( + "Taking training step for combination {}/{num_combinations} {}", + i + 1, + combination.name() + ); + + combination.train_step().map(|trained| (trained, roc_score)) + }) + .collect::, String>>()? + .into_iter() + .unzip::<_, _, Vec<_>, Vec<_>>(); + + let roc_score = utils::mean(&roc_scores); + Ok((trained_combinations, roc_score)) + } + + /// Parallel version of [`TrainableSmc::train`](crate::chaoda::training::TrainableSmc::train). + /// + /// # Errors + /// + /// - See [`TrainableSmc::train`](crate::chaoda::training::TrainableSmc::train). 
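
// NOTE (editorial sketch, not part of this patch): `par_train_epoch` above
// accumulates its (features, targets, roc-scores) triple with rayon's
// fold/reduce pair: `fold` builds one partial accumulator per worker
// without any locking, and `reduce` merges the partials. The same idiom in
// isolation, on a hypothetical `gather` helper:

use rayon::prelude::*;

fn gather(parts: &[Vec<f32>]) -> (Vec<f32>, usize) {
    parts
        .par_iter()
        .fold(
            || (Vec::new(), 0_usize), // per-worker identity
            |(mut xs, n), part| {
                xs.extend_from_slice(part); // cheap, thread-local extend
                (xs, n + 1)
            },
        )
        .reduce(
            || (Vec::new(), 0_usize), // identity for merging partials
            |(mut xs1, n1), (xs2, n2)| {
                xs1.extend(xs2); // merge the per-worker results
                (xs1, n1 + n2)
            },
        )
}
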
+ #[allow(clippy::too_many_arguments)] + pub fn par_train( + &mut self, + datasets: &[D; N], + metric: &M, + roots: &[Vertex; N], + labels: &[Vec; N], + min_depth: usize, + depths: &[usize], + num_epochs: usize, + ) -> Result + where + I: Send + Sync, + T: Number, + D: ParDataset, + M: ParMetric, + S: ParCluster, + { + if min_depth == 0 { + return Err("Minimum depth must be greater than 0".to_string()); + } + + for (data, labels) in datasets.iter().zip(labels.iter()) { + if data.cardinality() != labels.len() { + return Err("Number of data points and labels must be equal".to_string()); + } + } + + for &d in depths { + if d < min_depth { + return Err("Depths must be no smaller than the minimum depth".to_string()); + } + } + + let flat_graphs = Self::par_create_flat_graphs(datasets, metric, roots, depths, min_depth); + let (mut trained_combinations, _) = self.par_train_epoch(&flat_graphs, labels)?; + + for _ in 1..num_epochs { + let graphs = Self::par_create_graphs(datasets, metric, roots, &trained_combinations, min_depth); + (trained_combinations, _) = self.par_train_epoch(&graphs, labels)?; + } + + Ok(TrainedSmc::new(trained_combinations)) + } +} diff --git a/crates/abd-clam/src/core/cluster/adapter.rs b/crates/abd-clam/src/core/cluster/adapter.rs index 84f74cf12..da6373862 100644 --- a/crates/abd-clam/src/core/cluster/adapter.rs +++ b/crates/abd-clam/src/core/cluster/adapter.rs @@ -3,47 +3,84 @@ use distances::Number; use rayon::prelude::*; -use crate::{dataset::ParDataset, Dataset}; +use crate::{dataset::ParDataset, metric::ParMetric, utils, Dataset, Metric}; use super::{Ball, Cluster, ParCluster}; +/// Used for adapting a `Ball` into another `Cluster`. +/// +/// # Type Parameters: +/// +/// - I: The items. +/// - T: The distance values. +/// - D: The `Dataset` that the tree was originally built on. +/// - S: The `Cluster` that the tree was originally built on. +/// +/// # Examples +/// +/// See: +/// +/// - [`Offset`](crate::cakes::Offset) +/// - [`SquishCosts`](crate::pancakes::SquishCosts) +pub trait Params, S: Cluster>: Default { + /// Given the `S` that was adapted into a `Cluster`, returns parameters + /// to use for adapting the children of `S`. + #[must_use] + fn child_params>(&self, children: &[S], data: &D, metric: &M) -> Vec; +} + +/// Parallel version of [`Params`](crate::core::cluster::adapter::Params). +/// +/// # Examples +/// +/// See: +/// +/// - [`Offset`](crate::cakes::Offset) +/// - [`SquishCosts`](crate::pancakes::SquishCosts) +pub trait ParParams, S: ParCluster>: + Params + Send + Sync +{ + /// Parallel version of [`Params::child_params`](crate::core::cluster::adapter::Params::child_params). + #[must_use] + fn par_child_params>(&self, children: &[S], data: &D, metric: &M) -> Vec; +} + /// A trait for adapting a `Ball` into another `Cluster`. +/// +/// # Examples +/// +/// See: +/// +/// - [`PermutedBall`](crate::cakes::PermutedBall) +/// - [`SquishyBall`](crate::pancakes::SquishyBall) #[allow(clippy::module_name_repetitions)] -pub trait BallAdapter, Dout: Dataset, P: Params>>: - Cluster +pub trait BallAdapter, Dout: Dataset, P: Params>>: + Cluster + Sized { /// Adapts this `Cluster` from a `Ball` tree. - fn from_ball_tree(ball: Ball, data: Din) -> (Self, Dout); + fn from_ball_tree>(ball: Ball, data: Din, metric: &M) -> (Self, Dout); } -/// Parallel version of the `BallAdapter` trait. +/// Parallel version of the [`BallAdapter`](crate::core::cluster::adapter::BallAdapter) +/// trait. 
+/// +/// # Examples +/// +/// See: +/// +/// - [`PermutedBall`](crate::cakes::PermutedBall) +/// - [`SquishyBall`](crate::pancakes::SquishyBall) #[allow(clippy::module_name_repetitions)] pub trait ParBallAdapter< I: Send + Sync, - U: Number, - Din: ParDataset, - Dout: ParDataset, - P: ParParams>, ->: ParCluster + BallAdapter + T: Number, + Din: ParDataset, + Dout: ParDataset, + P: ParParams>, +>: ParCluster + BallAdapter { - /// Parallel version of the `from_ball_tree` method. - fn par_from_ball_tree(ball: Ball, data: Din) -> (Self, Dout); -} - -/// A trait for the parameters to use for adapting a `Ball` into another `Cluster`. -/// -/// # Type Parameters: -/// -/// - I: The type of instances. -/// - U: The type of distance values. -/// - Din: The type of `Dataset` that the tree was originally built on. -/// - Dout: The type of the `Dataset` that the adapted tree will use. -/// - S: The type of `Cluster` that the tree was originally built on. -pub trait Params, Dout: Dataset, S: Cluster>: Default { - /// Given the `S` that was adapted into a `Cluster`, returns parameters - /// to use for adapting the children of `S`. - #[must_use] - fn child_params(&self, children: &[S]) -> Vec; + /// Parallel version of [`BallAdapter::from_ball_tree`](crate::core::cluster::adapter::BallAdapter::from_ball_tree). + fn par_from_ball_tree>(ball: Ball, data: Din, metric: &M) -> (Self, Dout); } /// A trait for adapting one `Cluster` type into another `Cluster` type. @@ -56,29 +93,31 @@ pub trait Params, Dout: Dataset, S: Clust /// /// # Type Parameters: /// -/// - I: The type of instances. -/// - U: The type of distance values. -/// - Din: The type of `Dataset` that the tree was originally built on. -/// - Dout: The type of the `Dataset` that the adapted tree will use. -/// - S: The type of `Cluster` that the tree was originally built on. -/// - P: The type of `Params` to use for adapting the tree. -pub trait Adapter< - I, - U: Number, - Din: Dataset, - Dout: Dataset, - S: Cluster, - P: Params, ->: Cluster +/// - I: The items. +/// - T: The distance values. +/// - Din: The `Dataset` that the tree was originally built on. +/// - Dout: The the `Dataset` that the adapted tree will use. +/// - S: The `Cluster` that the tree was originally built on. +/// - P: The `Params` to use for adapting the tree. +/// +/// # Examples +/// +/// See: +/// +/// - [`PermutedBall`](crate::cakes::PermutedBall) +/// - [`SquishyBall`](crate::pancakes::SquishyBall) +pub trait Adapter, Dout: Dataset, S: Cluster, P: Params>: + Cluster + Sized { - /// Creates a new `Cluster` that was adapted from a `S` and a list of children. - fn new_adapted(source: S, children: Vec<(usize, U, Box)>, params: P) -> Self; + /// Creates a new `Cluster` that was adapted from a `S` and a list of + /// children. + fn new_adapted>(source: S, children: Vec>, params: P, data: &Din, metric: &M) -> Self; /// Performs a task after recursively traversing the tree. fn post_traversal(&mut self); - /// Returns the `Cluster` that was adapted into this `Cluster`. This should not - /// have any children. + /// Returns the `Cluster` that was adapted into this `Cluster`. This should + /// not have any children. fn source(&self) -> &S; /// Returns the `Ball` mutably that was adapted into this `Cluster`. This @@ -98,173 +137,160 @@ pub trait Adapter< /// - `source`: The `S` to adapt. /// - `params`: The parameters to use for adapting `S`. If `None`, assume /// that `S` is a root `Cluster` and use the default parameters. - fn adapt_tree(source: S, params: Option
<P>
) -> Self { + /// - `data`: The `Dataset` that the tree was built on. + /// - `metric`: The `Metric` to use for distance calculations. + fn adapt_tree>(source: S, params: Option
<P>
, data: &Din, metric: &M) -> Self { let params = params.unwrap_or_default(); - let mut cluster = Self::traverse(source, params); + let mut cluster = Self::traverse(source, params, data, metric); cluster.post_traversal(); cluster } - /// Recursively adapts a tree of `S`s into a `Cluster` without any pre- or post- - /// traversal operations. - fn traverse(mut source: S, params: P) -> Self { - let children = source.take_children(); + /// Recursively adapts a tree of `S`s into a `Cluster` without any pre- or + /// post- traversal operations. + fn traverse>(mut source: S, params: P, data: &Din, metric: &M) -> Self { + let children = source.take_children().into_iter().map(|c| *c).collect::>(); if children.is_empty() { - Self::new_adapted(source, Vec::new(), params) + Self::new_adapted(source, Vec::new(), params, data, metric) } else { - let (arg_extrema, others) = children - .into_iter() - .map(|(a, b, c)| (a, (b, c))) - .unzip::<_, _, Vec<_>, Vec<_>>(); - let (extents, children) = others.into_iter().map(|(e, c)| (e, *c)).unzip::<_, _, Vec<_>, Vec<_>>(); let children = params - .child_params(&children) + .child_params(&children, data, metric) .into_iter() .zip(children) - .map(|(p, c)| Self::adapt_tree(c, Some(p))) - .zip(arg_extrema) - .zip(extents) - .map(|((c, i), d)| (i, d, Box::new(c))) + .map(|(p, c)| Self::adapt_tree(c, Some(p), data, metric)) + .map(Box::new) .collect(); - Self::new_adapted(source, children, params) + Self::new_adapted(source, children, params, data, metric) } } - /// Adapts the tree of `S`s into this `Cluster` in a such a way that we bypass - /// the recursion limit in Rust. - fn adapt_tree_iterative(mut source: S, params: Option
<P>
) -> Self { - let target_depth = source.depth() + source.max_recursion_depth(); + /// Adapts the tree of `S`s into this `Cluster` in a such a way that we + /// bypass the recursion limit in Rust. + fn adapt_tree_iterative>(mut source: S, params: Option
<P>
, data: &Din, metric: &M) -> Self { + let target_depth = source.depth() + utils::max_recursion_depth(); let trimmings = source.trim_at_depth(target_depth); - let params = params.unwrap_or_default(); - let mut root = Self::traverse(source, params); + let mut root = Self::adapt_tree(source, params, data, metric); let leaf_params = root .leaves() .into_iter() .filter(|l| l.depth() == target_depth) - .map(Self::params); + .map(Self::params) + .collect::>(); let trimmings = trimmings .into_iter() .zip(leaf_params) .map(|(children, params)| { - let (others, children) = children - .into_iter() - .map(|(i, d, c)| ((i, d), *c)) - .unzip::<_, _, Vec<_>, Vec<_>>(); - + let children = children.into_iter().map(|c| *c).collect::>(); params - .child_params(&children) + .child_params(&children, data, metric) .into_iter() .zip(children) - .zip(others) - .map(|((p, c), (i, d))| { - let c = Self::adapt_tree_iterative(c, Some(p)); - (i, d, Box::new(c)) - }) + .map(|(p, c)| Self::adapt_tree_iterative(c, Some(p), data, metric)) + .map(Box::new) .collect::>() }) .collect::>(); root.graft_at_depth(target_depth, trimmings); - root.post_traversal(); root } /// Recover the source `Cluster` tree that was adapted into this `Cluster`. fn recover_source_tree(mut self) -> S { - let indices = self.source().indices().collect(); + let indices = self.source().indices(); let children = self .take_children() .into_iter() - .map(|(i, d, c)| (i, d, Box::new(c.recover_source_tree()))) + .map(|c| c.recover_source_tree()) + .map(Box::new) .collect(); let mut source = self.take_source(); - source.set_indices(indices); + source.set_indices(&indices); source.set_children(children); source } } -/// Parallel version of the `Params` trait. -pub trait ParParams, Dout: ParDataset, S: ParCluster>: - Params + Send + Sync -{ - /// Parallel version of the `child_params` method. - #[must_use] - fn par_child_params(&self, children: &[S]) -> Vec; -} - -/// Parallel version of the `Adapter` trait. +/// Parallel version of [`Adapter`](crate::core::cluster::adapter::Adapter). +/// +/// # Examples +/// +/// See: +/// +/// - [`PermutedBall`](crate::cakes::PermutedBall) +/// - [`SquishyBall`](crate::pancakes::SquishyBall) #[allow(clippy::module_name_repetitions)] pub trait ParAdapter< I: Send + Sync, - U: Number, - Din: ParDataset, - Dout: ParDataset, - S: ParCluster, - P: ParParams, ->: ParCluster + Adapter + T: Number, + Din: ParDataset, + Dout: ParDataset, + S: ParCluster, + P: ParParams, +>: ParCluster + Adapter { - /// Parallel version of the `post_traversal` method. - fn par_post_traversal(&mut self); - - /// Parallel version of the `adapt` method. - fn par_adapt_tree(mut source: S, params: Option
<P>
) -> Self { - let children = source.take_children(); + /// Parallel version of [`Adapter::new_adapted`](crate::core::cluster::adapter::Adapter::new_adapted). + fn par_new_adapted>( + source: S, + children: Vec>, + params: P, + data: &Din, + metric: &M, + ) -> Self; + + /// Parallel version of [`Adapter::adapt_tree`](crate::core::cluster::adapter::Adapter::adapt_tree). + fn par_adapt_tree>(mut source: S, params: Option
<P>
, data: &Din, metric: &M) -> Self { + let children = source.take_children().into_iter().map(|c| *c).collect::>(); let params = params.unwrap_or_default(); let mut cluster = if children.is_empty() { - Self::new_adapted(source, Vec::new(), params) + Self::par_new_adapted(source, Vec::new(), params, data, metric) } else { - let (arg_extrema, others) = children - .into_iter() - .map(|(a, b, c)| (a, (b, c))) - .unzip::<_, _, Vec<_>, Vec<_>>(); - let (extents, children) = others.into_iter().map(|(e, c)| (e, *c)).unzip::<_, _, Vec<_>, Vec<_>>(); let children = params - .child_params(&children) + .child_params(&children, data, metric) .into_par_iter() .zip(children) - .map(|(p, c)| Self::par_adapt_tree(c, Some(p))) - .zip(arg_extrema) - .zip(extents) - .map(|((c, i), d)| (i, d, Box::new(c))) - .collect::>(); - Self::new_adapted(source, children, params) + .map(|(p, c)| Self::par_adapt_tree(c, Some(p), data, metric)) + .map(Box::new) + .collect(); + Self::par_new_adapted(source, children, params, data, metric) }; - cluster.par_post_traversal(); + cluster.post_traversal(); cluster } - /// Recover the source `Cluster` tree that was adapted into this `Cluster`. + /// Parallel version of [`Adapter::recover_source_tree`](crate::core::cluster::adapter::Adapter::recover_source_tree). fn par_recover_source_tree(mut self) -> S { - let indices = self.source().indices().collect(); + let indices = self.source().indices(); let children = self .take_children() .into_par_iter() - .map(|(i, d, c)| (i, d, Box::new(c.par_recover_source_tree()))) + .map(|c| c.par_recover_source_tree()) + .map(Box::new) .collect(); let mut source = self.take_source(); - source.set_indices(indices); + source.set_indices(&indices); source.set_children(children); source } - /// Adapts the tree of `S`s into this `Cluster` in a such a way that we bypass - /// the recursion limit in Rust. - fn par_adapt_tree_iterative(mut source: S, params: Option
<P>
) -> Self { - let target_depth = source.depth() + source.max_recursion_depth(); - let children = source.trim_at_depth(target_depth); - let mut root = Self::adapt_tree(source, params); + /// Parallel version of [`Adapter::adapt_tree_iterative`](crate::core::cluster::adapter::Adapter::adapt_tree_iterative). + fn par_adapt_tree_iterative>(mut source: S, params: Option
<P>
, data: &Din, metric: &M) -> Self { + let target_depth = source.depth() + utils::max_recursion_depth(); + let trimmings = source.trim_at_depth(target_depth); + + let mut root = Self::par_adapt_tree(source, params, data, metric); + let leaf_params = root .leaves() .into_par_iter() @@ -272,28 +298,23 @@ pub trait ParAdapter< .map(Self::params) .collect::>(); - let children = children + let trimmings = trimmings .into_par_iter() .zip(leaf_params) .map(|(children, params)| { - let (others, children) = children - .into_iter() - .map(|(i, d, c)| ((i, d), *c)) - .unzip::<_, _, Vec<_>, Vec<_>>(); + let children = children.into_iter().map(|c| *c).collect::>(); params - .child_params(&children) + .child_params(&children, data, metric) .into_par_iter() .zip(children) - .zip(others) - .map(|((p, c), (i, d))| { - let c = Self::par_adapt_tree_iterative(c, Some(p)); - (i, d, Box::new(c)) - }) + .map(|(p, c)| Self::par_adapt_tree_iterative(c, Some(p), data, metric)) + .map(Box::new) .collect::>() }) .collect::>(); - root.graft_at_depth(target_depth, children); + root.graft_at_depth(target_depth, trimmings); + root } } diff --git a/crates/abd-clam/src/core/cluster/balanced_ball.rs b/crates/abd-clam/src/core/cluster/balanced_ball.rs index f4f3cfaf9..a6c9a45fb 100644 --- a/crates/abd-clam/src/core/cluster/balanced_ball.rs +++ b/crates/abd-clam/src/core/cluster/balanced_ball.rs @@ -1,181 +1,258 @@ -//! A `Cluster` that provides a balanced clustering. +//! A `BalancedBall` is a data structure that represents a balanced binary tree. use distances::Number; use rayon::prelude::*; -use serde::{Deserialize, Serialize}; -use crate::{dataset::ParDataset, Dataset}; +use crate::{dataset::ParDataset, metric::ParMetric, Dataset, Metric}; -use super::{partition::ParPartition, Ball, Cluster, ParCluster, Partition}; +use super::{Ball, Cluster, ParCluster, ParPartition, Partition}; -/// A `Cluster` that provides a balanced clustering. -#[derive(Clone, Serialize, Deserialize)] -pub struct BalancedBall> { - /// The inner `Ball` of the `BalancedBall`. - pub(crate) ball: Ball, - /// The children of the `BalancedBall`. - pub(crate) children: Vec<(usize, U, Box)>, +/// A `BalancedBall` is a data structure that represents a balanced binary tree. +#[derive(Clone)] +pub struct BalancedBall(Ball, Vec>); + +impl BalancedBall { + /// Converts the `BalancedBall` into a `Ball`. 
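
// NOTE (editorial sketch, not part of this patch): the
// `adapt_tree_iterative` methods in the adapter.rs hunks above dodge
// Rust's recursion limit by trimming the source tree at a bounded depth,
// adapting the shallow top, adapting each trimmed subtree separately, and
// grafting the results back on. The shape of that trim/adapt/graft
// round-trip on a toy n-ary tree (all names below are illustrative):

/// A toy n-ary tree node.
struct Node {
    children: Vec<Box<Node>>,
}

/// Detach every set of subtrees hanging below `depth`, in DFS order.
fn trim(node: &mut Node, depth: usize) -> Vec<Vec<Box<Node>>> {
    if depth == 0 {
        vec![core::mem::take(&mut node.children)]
    } else {
        node.children.iter_mut().flat_map(|c| trim(c, depth - 1)).collect()
    }
}

/// Re-attach the trimmings to the nodes at `depth`, in the same DFS order.
fn graft(node: &mut Node, depth: usize, trimmings: &mut Vec<Vec<Box<Node>>>) {
    if depth == 0 {
        node.children = trimmings.remove(0);
    } else {
        for c in &mut node.children {
            graft(c, depth - 1, trimmings);
        }
    }
}
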
+ pub fn into_ball(mut self) -> Ball { + if !self.1.is_empty() { + let children = self.1.into_iter().map(|c| c.into_ball()).map(Box::new).collect(); + self.0.set_children(children); + } + self.0 + } } -impl> core::fmt::Debug for BalancedBall { +impl core::fmt::Debug for BalancedBall { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.debug_struct("Ball") - .field("depth", &self.ball.depth()) - .field("cardinality", &self.ball.cardinality()) - .field("radius", &self.ball.radius()) - .field("lfd", &self.ball.lfd()) - .field("arg_center", &self.ball.arg_center()) - .field("arg_radial", &self.ball.arg_radial()) - .field("indices", &self.ball.indices) - .field("children", &self.children.is_empty()) + f.debug_struct("BalancedBall") + .field("depth", &self.depth()) + .field("cardinality", &self.cardinality()) + .field("radius", &self.radius()) + .field("lfd", &self.lfd()) + .field("arg_center", &self.arg_center()) + .field("arg_radial", &self.arg_radial()) + .field("indices", &self.indices()) + .field("extents", &self.extents()) + .field("children", &!self.is_leaf()) .finish() } } -impl> PartialEq for BalancedBall { +impl PartialEq for BalancedBall { fn eq(&self, other: &Self) -> bool { - self.ball.eq(&other.ball) + self.0 == other.0 && self.1 == other.1 } } -impl> Eq for BalancedBall {} +impl Eq for BalancedBall {} -impl> PartialOrd for BalancedBall { - fn partial_cmp(&self, other: &Self) -> Option { +impl PartialOrd for BalancedBall { + fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } -impl> Ord for BalancedBall { - fn cmp(&self, other: &Self) -> core::cmp::Ordering { - self.ball.cmp(&other.ball) +impl Ord for BalancedBall { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.0.cmp(&other.0) } } -impl> std::hash::Hash for BalancedBall { - fn hash(&self, state: &mut H) { - self.ball.hash(state); +impl core::hash::Hash for BalancedBall { + fn hash(&self, state: &mut H) { + self.0.hash(state); } } -impl> Cluster for BalancedBall { +impl Cluster for BalancedBall { fn depth(&self) -> usize { - self.ball.depth() + self.0.depth() } fn cardinality(&self) -> usize { - self.ball.cardinality() + self.0.cardinality() } fn arg_center(&self) -> usize { - self.ball.arg_center() + self.0.arg_center() } fn set_arg_center(&mut self, arg_center: usize) { - self.ball.set_arg_center(arg_center); + self.0.set_arg_center(arg_center); } - fn radius(&self) -> U { - self.ball.radius() + fn radius(&self) -> T { + self.0.radius() } fn arg_radial(&self) -> usize { - self.ball.arg_radial() + self.0.arg_radial() } fn set_arg_radial(&mut self, arg_radial: usize) { - self.ball.set_arg_radial(arg_radial); + self.0.set_arg_radial(arg_radial); } fn lfd(&self) -> f32 { - self.ball.lfd() + self.0.lfd() } - fn indices(&self) -> impl Iterator + '_ { - self.ball.indices() + fn contains(&self, idx: usize) -> bool { + self.0.contains(idx) } - fn set_indices(&mut self, indices: Vec) { - self.ball.set_indices(indices); + fn indices(&self) -> Vec { + self.0.indices() } - fn children(&self) -> &[(usize, U, Box)] { - &self.children + fn set_indices(&mut self, indices: &[usize]) { + self.0.set_indices(indices); } - fn children_mut(&mut self) -> &mut [(usize, U, Box)] { - &mut self.children + fn extents(&self) -> &[(usize, T)] { + self.0.extents() } - fn set_children(&mut self, children: Vec<(usize, U, Box)>) { - self.children = children; + fn extents_mut(&mut self) -> &mut [(usize, T)] { + self.0.extents_mut() } - fn take_children(&mut self) -> Vec<(usize, U, Box)> { - 
core::mem::take(&mut self.children) + fn add_extent(&mut self, idx: usize, extent: T) { + self.0.add_extent(idx, extent); } - fn distances_to_query(&self, data: &D, query: &I) -> Vec<(usize, U)> { - self.ball.distances_to_query(data, query) + fn take_extents(&mut self) -> Vec<(usize, T)> { + self.0.take_extents() } - fn is_descendant_of(&self, other: &Self) -> bool { - self.ball.is_descendant_of(&other.ball) + fn children(&self) -> Vec<&Self> { + self.1.iter().map(AsRef::as_ref).collect() } -} -impl> ParCluster for BalancedBall { - fn par_distances_to_query(&self, data: &D, query: &I) -> Vec<(usize, U)> { - self.ball.par_distances_to_query(data, query) + fn children_mut(&mut self) -> Vec<&mut Self> { + self.1.iter_mut().map(AsMut::as_mut).collect() + } + + fn set_children(&mut self, children: Vec>) { + self.1 = children; } -} -impl> Partition for BalancedBall { - fn new(data: &D, indices: &[usize], depth: usize, seed: Option) -> Self { - let ball = Ball::new(data, indices, depth, seed); - let children = Vec::new(); - Self { ball, children } + fn take_children(&mut self) -> Vec> { + core::mem::take(&mut self.1) } - fn find_extrema(&self, data: &D) -> Vec { - self.ball.find_extrema(data) + fn is_descendant_of(&self, other: &Self) -> bool { + self.0.is_descendant_of(&other.0) } +} - fn split_by_extrema(&self, data: &D, extrema: &[usize]) -> (Vec>, Vec) { - let mut instances = self.indices().filter(|&i| !extrema.contains(&i)).collect::>(); +impl ParCluster for BalancedBall { + fn par_indices(&self) -> impl rayon::prelude::ParallelIterator { + self.0.par_indices() + } +} - // Calculate the number of instances per child for a balanced split - let num_per_child = instances.len() / extrema.len(); - let last_child_size = if instances.len() % extrema.len() == 0 { - num_per_child +impl Partition for BalancedBall { + fn new, M: Metric>( + data: &D, + metric: &M, + indices: &[usize], + depth: usize, + seed: Option, + ) -> Result { + Ball::new(data, metric, indices, depth, seed).map(|ball| Self(ball, Vec::new())) + } + + fn find_extrema, M: Metric>(&mut self, data: &D, metric: &M) -> Vec { + self.0.find_extrema(data, metric) + } + + #[allow(clippy::similar_names)] + fn split_by_extrema, M: Metric>( + &self, + data: &D, + metric: &M, + extrema: &[usize], + ) -> (Vec>, Vec) { + let [l, r] = [extrema[0], extrema[1]]; + let lr = data.one_to_one(l, r, metric); + + let items = self + .indices() + .into_iter() + .filter(|&i| !(i == l || i == r)) + .collect::>(); + + // Find the distances from each extremum to each item. + let l_distances = data.one_to_many(l, &items, metric); + let r_distances = data.one_to_many(r, &items, metric); + + let child_stacks = if metric.obeys_triangle_inequality() { + let lr = lr.as_f32(); + + // Find the distance from `l` to each item projected onto the line + // connecting `l` and `r`. + let lr_distances = { + let mut lr_distances = l_distances + .map(|(a, d)| (a, d.as_f32())) + .zip(r_distances.map(|(_, d)| d.as_f32())) + .map(|((a, al), ar)| { + let cos = ar.mul_add(-ar, lr.mul_add(lr, al.powi(2))) / (2.0 * al * lr); + (a, al * cos) + }) + .collect::>(); + lr_distances.sort_by(|(_, a), (_, b)| a.total_cmp(b)); + lr_distances + .into_iter() + .map(|(a, d)| (a, T::from(d))) + .collect::>() + }; + + // Half of the items will be assigned to the left child and the + // other half to the right child. 
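
// NOTE (editorial sketch, not part of this patch): the split above ranks
// each item `a` by its scalar projection onto the axis through the two
// extrema. With `al = d(l, a)`, `ar = d(r, a)`, and `lr = d(l, r)`, the
// law of cosines gives cos(angle at l) = (lr^2 + al^2 - ar^2) / (2 * al * lr),
// so the projected coordinate is `al * cos`. A standalone check of that
// formula (hypothetical `projection` helper):

fn projection(al: f32, ar: f32, lr: f32) -> f32 {
    // Same expression as in `split_by_extrema` above.
    let cos = ar.mul_add(-ar, lr.mul_add(lr, al.powi(2))) / (2.0 * al * lr);
    al * cos
}

fn main() {
    // Point (3, 4) against poles l = (0, 0) and r = (10, 0): projecting
    // onto the x-axis must recover the x-coordinate, 3.
    let (al, ar, lr) = (5.0_f32, (49.0_f32 + 16.0).sqrt(), 10.0);
    assert!((projection(al, ar, lr) - 3.0).abs() < 1e-5);
}
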
+ let mid = if lr_distances.len() % 2 == 0 { + lr_distances.len() / 2 + } else { + 1 + lr_distances.len() / 2 + }; + let (ls, rs) = lr_distances.split_at(mid); + let l_stack = core::iter::once((l, T::ZERO)) + .chain(ls.iter().copied()) + .collect::>(); + let r_stack = core::iter::once((r, T::ZERO)) + .chain(rs.iter().copied()) + .collect::>(); + + vec![l_stack, r_stack] } else { - num_per_child + 1 + // If the metric does not obey the triangle inequality, we just sort + // the items by their distance to `l`. + let l_distances = { + let mut l_distances = l_distances.collect::>(); + l_distances.sort_by(|(_, a), (_, b)| a.total_cmp(b)); + l_distances + }; + + // Half of the items will be assigned to the left child and the + // other half to the right child. + let (ls, rs) = l_distances.split_at(l_distances.len() / 2); + let l_stack = core::iter::once((l, T::ZERO)) + .chain(ls.iter().copied()) + .collect::>(); + let r_stack = core::iter::once((r, T::ZERO)) + .chain(rs.iter().map(|&(a, d)| (a, lr - d))) + .collect::>(); + + vec![l_stack, r_stack] }; - let child_sizes = - core::iter::once(last_child_size).chain(core::iter::repeat(num_per_child).take(extrema.len() - 1)); - // Initialize the child stacks with the extrema - let mut child_stacks = extrema.iter().map(|&e| vec![(e, U::ZERO)]).collect::>(); - for (child_stack, s) in child_stacks.iter_mut().zip(child_sizes) { - // Calculate the distances to the instances from the extremum - let mut distances = Dataset::one_to_many(data, child_stack[0].0, &instances); - distances.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(core::cmp::Ordering::Less)); - - // Remove the closest instances from the distances and add them to the child stack - child_stack.extend(distances.split_off(instances.len() - s)); - - // Update the instances for the next child - instances = distances.into_iter().map(|(i, _)| i).collect(); - } - - // Unzip the child stacks into the indices and calculate the extent of each child child_stacks .into_iter() .map(|stack| { let (indices, distances) = stack.into_iter().unzip::<_, _, Vec<_>, Vec<_>>(); let extent = distances .into_iter() - .max_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Less)) + .max_by(Number::total_cmp) .unwrap_or_else(|| unreachable!("Cannot find the maximum distance")); (indices, extent) }) @@ -183,151 +260,111 @@ impl> Partition for BalancedBall> ParPartition for BalancedBall { - fn par_new(data: &D, indices: &[usize], depth: usize, seed: Option) -> Self { - let ball = Ball::par_new(data, indices, depth, seed); - let children = Vec::new(); - Self { ball, children } - } - - fn par_find_extrema(&self, data: &D) -> Vec { - self.ball.par_find_extrema(data) - } - - fn par_split_by_extrema(&self, data: &D, extrema: &[usize]) -> (Vec>, Vec) { - let mut instances = self.indices().filter(|&i| !extrema.contains(&i)).collect::>(); - - // Calculate the number of instances per child for a balanced split - let num_per_child = instances.len() / extrema.len(); - let last_child_size = if instances.len() % extrema.len() == 0 { - num_per_child +impl ParPartition for BalancedBall { + fn par_new, M: ParMetric>( + data: &D, + metric: &M, + indices: &[usize], + depth: usize, + seed: Option, + ) -> Result { + Ball::par_new(data, metric, indices, depth, seed).map(|ball| Self(ball, Vec::new())) + } + + fn par_find_extrema, M: ParMetric>( + &mut self, + data: &D, + metric: &M, + ) -> Vec { + self.0.par_find_extrema(data, metric) + } + + #[allow(clippy::similar_names)] + fn par_split_by_extrema, M: ParMetric>( + &self, + data: 
&D, + metric: &M, + extrema: &[usize], + ) -> (Vec>, Vec) { + let [l, r] = [extrema[0], extrema[1]]; + let lr = data.par_one_to_one(l, r, metric); + + let items = self.par_indices().filter(|&i| !(i == l || i == r)).collect::>(); + + // Find the distances from each extremum to each item. + let l_distances = data.par_one_to_many(l, &items, metric).collect::>(); + let r_distances = data.par_one_to_many(r, &items, metric).collect::>(); + + let child_stacks = if metric.obeys_triangle_inequality() { + let lr = lr.as_f32(); + + // Find the distance from `l` to each item projected onto the line + // connecting `l` and `r`. + let lr_distances = { + let mut lr_distances = l_distances + .into_par_iter() + .map(|(a, d)| (a, d.as_f32())) + .zip(r_distances.into_par_iter().map(|(_, d)| d.as_f32())) + .map(|((a, al), ar)| { + let cos = ar.mul_add(-ar, lr.mul_add(lr, al.powi(2))) / (2.0 * al * lr); + (a, al * cos) + }) + .collect::>(); + lr_distances.sort_by(|(_, a), (_, b)| a.total_cmp(b)); + lr_distances + .into_par_iter() + .map(|(a, d)| (a, T::from(d))) + .collect::>() + }; + + // Half of the items will be assigned to the left child and the + // other half to the right child. + let mid = if lr_distances.len() % 2 == 0 { + lr_distances.len() / 2 + } else { + 1 + lr_distances.len() / 2 + }; + let (ls, rs) = lr_distances.split_at(mid); + let l_stack = core::iter::once((l, T::ZERO)) + .chain(ls.iter().copied()) + .collect::>(); + let r_stack = core::iter::once((r, T::ZERO)) + .chain(rs.iter().copied()) + .collect::>(); + + vec![l_stack, r_stack] } else { - num_per_child + 1 + // If the metric does not obey the triangle inequality, we just sort + // the items by their distance to `l`. + let l_distances = { + let mut l_distances = l_distances; + l_distances.sort_by(|(_, a), (_, b)| a.total_cmp(b)); + l_distances + }; + + // Half of the items will be assigned to the left child and the + // other half to the right child. 
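
// NOTE (editorial sketch, not part of this patch): the `mid` computation
// above, `if n % 2 == 0 { n / 2 } else { 1 + n / 2 }`, is just the ceiling
// of n / 2, so the odd extra item goes to the left child; in recent Rust
// (1.73+) it could also be written with the standard `usize::div_ceil`.
// By contrast, the non-triangle-inequality branch just below splits at a
// plain `len / 2` (floor), so there the odd extra item lands on the right.

fn mid_point(n: usize) -> usize {
    n.div_ceil(2) // == if n % 2 == 0 { n / 2 } else { 1 + n / 2 }
}
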
+ let (ls, rs) = l_distances.split_at(l_distances.len() / 2); + let l_stack = core::iter::once((l, T::ZERO)) + .chain(ls.iter().copied()) + .collect::>(); + let r_stack = core::iter::once((r, T::ZERO)) + .chain(rs.iter().map(|&(a, d)| (a, lr - d))) + .collect::>(); + + vec![l_stack, r_stack] }; - let child_sizes = - core::iter::once(last_child_size).chain(core::iter::repeat(num_per_child).take(extrema.len() - 1)); - - // Initialize the child stacks with the extrema - let mut child_stacks = extrema.iter().map(|&e| vec![(e, U::ZERO)]).collect::>(); - for (child_stack, s) in child_stacks.iter_mut().zip(child_sizes) { - // Calculate the distances to the instances from the extremum - let mut distances = ParDataset::par_one_to_many(data, child_stack[0].0, &instances); - distances.par_sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(core::cmp::Ordering::Less)); - - // Remove the closest instances from the distances and add them to the child stack - child_stack.extend(distances.split_off(instances.len() - s)); - // Update the instances for the next child - instances = distances.into_iter().map(|(i, _)| i).collect(); - } - - // Unzip the child stacks into the indices and calculate the extent of each child child_stacks .into_par_iter() .map(|stack| { let (indices, distances) = stack.into_iter().unzip::<_, _, Vec<_>, Vec<_>>(); let extent = distances .into_iter() - .max_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Less)) + .max_by(Number::total_cmp) .unwrap_or_else(|| unreachable!("Cannot find the maximum distance")); (indices, extent) }) .unzip() } } - -#[cfg(test)] -mod tests { - use crate::{partition::ParPartition, Cluster, Dataset, FlatVec, Metric, Partition}; - - use super::BalancedBall; - - type F = FlatVec, i32, usize>; - type B = BalancedBall, i32, F>; - - fn gen_tiny_data() -> Result, i32, usize>, String> { - let instances = vec![vec![1, 2], vec![3, 4], vec![5, 6], vec![7, 8], vec![11, 12]]; - let distance_function = |a: &Vec, b: &Vec| distances::vectors::manhattan(a, b); - let metric = Metric::new(distance_function, false); - FlatVec::new_array(instances.clone(), metric) - } - - #[test] - fn new() -> Result<(), String> { - let data = gen_tiny_data()?; - - let indices = (0..data.cardinality()).collect::>(); - let seed = Some(42); - let root = B::new(&data, &indices, 0, seed); - let arg_r = root.arg_radial(); - - assert_eq!(arg_r, data.cardinality() - 1); - assert_eq!(root.depth(), 0); - assert_eq!(root.cardinality(), 5); - assert_eq!(root.arg_center(), 2); - assert_eq!(root.radius(), 12); - assert_eq!(root.arg_radial(), arg_r); - assert!(root.children().is_empty()); - assert_eq!(root.indices().collect::>(), indices); - - let root = B::par_new(&data, &indices, 0, seed); - let arg_r = root.arg_radial(); - - assert_eq!(arg_r, data.cardinality() - 1); - assert_eq!(root.depth(), 0); - assert_eq!(root.cardinality(), 5); - assert_eq!(root.arg_center(), 2); - assert_eq!(root.radius(), 12); - assert_eq!(root.arg_radial(), arg_r); - assert!(root.children().is_empty()); - assert_eq!(root.indices().collect::>(), indices); - - Ok(()) - } - - fn check_partition(root: &B) -> bool { - let indices = root.indices().collect::>(); - - assert!(!root.children().is_empty()); - assert_eq!(indices, &[0, 2, 1, 4, 3]); - - let children = root.child_clusters().collect::>(); - assert_eq!(children.len(), 2); - for &c in &children { - assert_eq!(c.depth(), 1); - assert!(c.children().is_empty()); - } - - let (left, right) = (children[0], children[1]); - - assert_eq!(left.cardinality(), 3); - 
assert_eq!(left.arg_center(), 1); - assert_eq!(left.radius(), 4); - assert!([0, 2].contains(&left.arg_radial())); - - assert_eq!(right.cardinality(), 2); - assert_eq!(right.radius(), 8); - assert!([3, 4].contains(&right.arg_center())); - assert!([3, 4].contains(&right.arg_radial())); - - true - } - - #[test] - fn tree() -> Result<(), String> { - let data = gen_tiny_data()?; - - let seed = Some(42); - let criteria = |c: &B| c.depth() < 1; - - let root = B::new_tree(&data, &criteria, seed); - assert_eq!(root.indices().count(), data.cardinality(), "{root:?}"); - assert!(check_partition(&root)); - - let root = B::par_new_tree(&data, &criteria, seed); - assert_eq!(root.indices().count(), data.cardinality(), "{root:?}"); - assert!(check_partition(&root)); - - Ok(()) - } -} diff --git a/crates/abd-clam/src/core/cluster/ball.rs b/crates/abd-clam/src/core/cluster/ball.rs index f5aacb9f8..349ac4d90 100644 --- a/crates/abd-clam/src/core/cluster/ball.rs +++ b/crates/abd-clam/src/core/cluster/ball.rs @@ -1,94 +1,88 @@ //! The most basic representation of a `Cluster` is a metric-`Ball`. -use core::fmt::Debug; - -use rayon::prelude::*; -use std::{hash::Hash, marker::PhantomData}; +use core::{ + cmp::Ordering, + hash::{Hash, Hasher}, +}; use distances::Number; -use serde::{Deserialize, Serialize}; - -use crate::{dataset::ParDataset, utils, Dataset}; +use rayon::prelude::*; -use super::{ - partition::{ParPartition, Partition}, - BalancedBall, Cluster, ParCluster, LFD, +use crate::{ + core::{dataset::ParDataset, metric::ParMetric, Dataset, Metric}, + utils, }; -/// A metric-`Ball` is a collection of instances that are within a certain -/// distance of a center. -#[derive(Clone, Serialize, Deserialize)] -pub struct Ball> { +use super::{partition::ParPartition, Cluster, ParCluster, Partition, LFD}; + +/// A metric-`Ball` is a collection of items that are within a certain distance +/// of a center. +/// +/// # Example +/// +/// ```rust +/// use abd_clam::{ +/// cluster::{Partition, ParPartition}, +/// metric::AbsoluteDifference, +/// Ball, Cluster, Dataset, FlatVec +/// }; +/// +/// let items = (0..=100).collect::>(); +/// let data = FlatVec::new(items).unwrap(); +/// let metric = AbsoluteDifference; +/// +/// // We will create a `Ball` with all the items in the `data`. +/// let indices = data.indices().collect::>(); +/// let ball = Ball::new(&data, &metric, &indices, 0, None).unwrap(); +/// +/// assert_eq!(ball.depth(), 0); +/// assert_eq!(ball.cardinality(), 101); +/// assert_eq!(ball.indices(), indices); +/// assert_eq!(ball.arg_center(), 50); +/// assert_eq!(ball.radius(), 50); +/// assert!([0, 100].contains(&ball.arg_radial())); +/// assert!((ball.lfd() - 1.0).abs() < 0.1); +/// +/// assert!(ball.is_leaf()); +/// +/// // We will now create a tree of `Ball`s with leaves being singletons. +/// let partition_criteria = |ball: &Ball<_>| ball.cardinality() > 1; +/// let root = Ball::new_tree(&data, &metric, &partition_criteria, None); +/// assert!(!root.is_leaf()); +/// +/// // We can also use the equivalent parallelized methods to create the tree. +/// let root = Ball::par_new_tree(&data, &metric, &partition_criteria, None); +/// assert!(!root.is_leaf()); +/// ``` +#[derive(Clone)] +#[cfg_attr( + feature = "disk-io", + derive(bitcode::Encode, bitcode::Decode, serde::Deserialize, serde::Serialize) +)] +#[cfg_attr(feature = "disk-io", bitcode(recursive))] +pub struct Ball { /// Parameters used for creating the `Ball`. depth: usize, - /// The number of instances in the `Ball`. 
+ /// The number of items in the `Ball`. cardinality: usize, /// The radius of the `Ball`. - radius: U, + radius: T, /// The local fractal dimension of the `Ball`. lfd: f32, - /// The index of the center instance. + /// The index of the center item. arg_center: usize, - /// The index of the instance that is the furthest from the center. + /// The index of the item that is the furthest from the center. arg_radial: usize, - /// The indices of the instances in the `Ball`. - pub(crate) indices: Vec, + /// The indices of the items in the `Ball`. + indices: Vec, + /// The extents of the `Ball`. + extents: Vec<(usize, T)>, /// The children of the `Ball`. - children: Vec<(usize, U, Box)>, - /// Phantom data to satisfy the compiler. - _id: PhantomData<(I, D)>, -} - -impl> Ball { - /// Creates a new `Ball` from a `BalancedBall`. - pub fn from_balanced_ball(balanced_ball: BalancedBall) -> Self { - let mut ball = balanced_ball.ball; - let children = balanced_ball - .children - .into_iter() - .map(|(e, d, b)| (e, d, Box::new(Self::from_balanced_ball(*b)))) - .collect(); - ball.children = children; - ball - } - - /// Changes the associated `Dataset` type. - pub fn with_dataset_type>(self) -> Ball { - let children = self - .children - .into_iter() - .map(|(e, d, b)| (e, d, Box::new(b.with_dataset_type()))) - .collect(); - Ball { - depth: self.depth, - cardinality: self.cardinality, - radius: self.radius, - lfd: self.lfd, - arg_center: self.arg_center, - arg_radial: self.arg_radial, - indices: self.indices, - children, - _id: PhantomData, - } - } -} - -impl> Ball { - /// Creates a new `Ball` from a `BalancedBall`. - pub fn par_from_balanced_ball(balanced_ball: BalancedBall) -> Self { - let mut ball = balanced_ball.ball; - let children = balanced_ball - .children - .into_par_iter() - .map(|(e, d, b)| (e, d, Box::new(Self::from_balanced_ball(*b)))) - .collect(); - ball.children = children; - ball - } + children: Vec>, } -impl> Debug for Ball { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl core::fmt::Debug for Ball { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.debug_struct("Ball") .field("depth", &self.depth) .field("cardinality", &self.cardinality) @@ -97,27 +91,28 @@ impl> Debug for Ball { .field("arg_center", &self.arg_center) .field("arg_radial", &self.arg_radial) .field("indices", &self.indices) + .field("extents", &self.extents) .field("children", &!self.children.is_empty()) .finish() } } -impl> PartialEq for Ball { +impl PartialEq for Ball { fn eq(&self, other: &Self) -> bool { self.depth == other.depth && self.cardinality == other.cardinality && self.indices == other.indices } } -impl> Eq for Ball {} +impl Eq for Ball {} -impl> PartialOrd for Ball { - fn partial_cmp(&self, other: &Self) -> Option { +impl PartialOrd for Ball { + fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } -impl> Ord for Ball { - fn cmp(&self, other: &Self) -> core::cmp::Ordering { +impl Ord for Ball { + fn cmp(&self, other: &Self) -> Ordering { self.depth .cmp(&other.depth) .then_with(|| self.cardinality.cmp(&other.cardinality)) @@ -125,14 +120,14 @@ impl> Ord for Ball { } } -impl> Hash for Ball { - fn hash(&self, state: &mut H) { +impl Hash for Ball { + fn hash(&self, state: &mut H) { // We hash the `indices` field self.indices.hash(state); } } -impl> Cluster for Ball { +impl Cluster for Ball { fn depth(&self) -> usize { self.depth } @@ -149,7 +144,7 @@ impl> Cluster for Ball { self.arg_center = arg_center; } - fn radius(&self) -> U { + fn 
radius(&self) -> T { self.radius } @@ -165,88 +160,93 @@ impl> Cluster for Ball { self.lfd } - fn indices(&self) -> impl Iterator + '_ { - self.indices.iter().copied() + fn contains(&self, index: usize) -> bool { + self.indices.contains(&index) } - fn set_indices(&mut self, indices: Vec) { - self.indices = indices; + fn indices(&self) -> Vec { + self.indices.clone() } - fn children(&self) -> &[(usize, U, Box)] { - self.children.as_slice() + fn set_indices(&mut self, indices: &[usize]) { + self.indices = indices.to_vec(); } - fn children_mut(&mut self) -> &mut [(usize, U, Box)] { - self.children.as_mut_slice() + fn extents(&self) -> &[(usize, T)] { + &self.extents } - fn set_children(&mut self, children: Vec<(usize, U, Box)>) { - self.children = children; + fn extents_mut(&mut self) -> &mut [(usize, T)] { + &mut self.extents } - fn take_children(&mut self) -> Vec<(usize, U, Box)> { - core::mem::take(&mut self.children) + fn add_extent(&mut self, index: usize, extent: T) { + self.extents.push((index, extent)); + } + + fn take_extents(&mut self) -> Vec<(usize, T)> { + core::mem::take(&mut self.extents) + } + + fn children(&self) -> Vec<&Self> { + self.children.iter().map(AsRef::as_ref).collect() + } + + fn children_mut(&mut self) -> Vec<&mut Self> { + self.children.iter_mut().map(AsMut::as_mut).collect() + } + + fn set_children(&mut self, children: Vec>) { + self.children = children; } - fn distances_to_query(&self, data: &D, query: &I) -> Vec<(usize, U)> { - data.query_to_many(query, &self.indices) + fn take_children(&mut self) -> Vec> { + core::mem::take(&mut self.children) } fn is_descendant_of(&self, other: &Self) -> bool { - let indices = other.indices().collect::>(); - self.indices().all(|i| indices.contains(&i)) + self.indices.iter().all(|i| other.indices.contains(i)) } } -impl> ParCluster for Ball { - fn par_distances_to_query(&self, data: &D, query: &I) -> Vec<(usize, U)> { - data.par_query_to_many(query, &self.indices().collect::>()) +impl ParCluster for Ball { + fn par_indices(&self) -> impl ParallelIterator { + self.indices.par_iter().copied() } } -impl> Partition for Ball { - fn new(data: &D, indices: &[usize], depth: usize, seed: Option) -> Self - where - Self: Sized, - { +impl Partition for Ball { + fn new, M: Metric>( + data: &D, + metric: &M, + indices: &[usize], + depth: usize, + seed: Option, + ) -> Result { if indices.is_empty() { - unreachable!("Cannot create a Ball with no instances") + return Err("Cannot create a Ball with no items".to_string()); } let cardinality = indices.len(); - let samples = if cardinality < 100 { indices.to_vec() } else { - #[allow(clippy::cast_possible_truncation)] - let n = if cardinality < 10_100 { - // We use the square root of the cardinality as the number of samples - (cardinality - 100).as_f64().sqrt().as_u64() as usize - } else { - // We use the logarithm of the cardinality as the number of samples - #[allow(clippy::cast_possible_truncation)] - let n = (cardinality - 10_100).as_f64().log2().as_u64() as usize; - n + 100 - }; - - let n = n + 100; - Dataset::choose_unique(data, indices, n, seed) + let num_samples = utils::num_samples(cardinality, 100, 10_000); + data.choose_unique(indices, num_samples, seed, metric) }; - let arg_center = Dataset::median(data, &samples); + let arg_center = data.median(&samples, metric); - let distances = Dataset::one_to_many(data, arg_center, indices); + let distances = data.one_to_many(arg_center, indices, metric).collect::>(); let &(arg_radial, radius) = distances .iter() - .max_by(|(_, a), (_, b)| 
a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Less)) + .max_by(|(_, a), (_, b)| a.total_cmp(b)) .unwrap_or_else(|| unreachable!("Cannot find the maximum distance")); let distances = distances.into_iter().map(|(_, d)| d).collect::>(); - let lfd_scale = radius.half(); - let lfd = LFD::from_radial_distances(&distances, lfd_scale); + let lfd = LFD::from_radial_distances(&distances, radius.half()); - Self { + Ok(Self { depth, cardinality, radius, @@ -254,64 +254,55 @@ impl> Partition for Ball { arg_center, arg_radial, indices: indices.to_vec(), + extents: vec![(arg_center, radius)], children: Vec::new(), - _id: PhantomData, - } + }) } - fn find_extrema(&self, data: &D) -> Vec { - let l_distances = Dataset::one_to_many(data, self.arg_radial, &self.indices); - - let &(arg_l, _) = l_distances - .iter() - .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Less)) + fn find_extrema, M: Metric>(&mut self, data: &D, metric: &M) -> Vec { + let (arg_l, d) = data + .one_to_many(self.arg_radial, &self.indices, metric) + .max_by(|(_, a), (_, b)| a.total_cmp(b)) .unwrap_or_else(|| unreachable!("Cannot find the maximum distance")); + self.add_extent(self.arg_radial, d); + vec![arg_l, self.arg_radial] } } -impl> ParPartition for Ball { - fn par_new(data: &D, indices: &[usize], depth: usize, seed: Option) -> Self - where - Self: Sized, - { +impl ParPartition for Ball { + fn par_new, M: ParMetric>( + data: &D, + metric: &M, + indices: &[usize], + depth: usize, + seed: Option, + ) -> Result { if indices.is_empty() { - unreachable!("Cannot create a Ball with no instances") + return Err("Cannot create a Ball with no items".to_string()); } let cardinality = indices.len(); - let samples = if cardinality < 100 { indices.to_vec() } else { - #[allow(clippy::cast_possible_truncation)] - let n = if cardinality < 10_100 { - // We use the square root of the cardinality as the number of samples - (cardinality - 100).as_f64().sqrt().as_u64() as usize - } else { - // We use the logarithm of the cardinality as the number of samples - #[allow(clippy::cast_possible_truncation)] - let n = (cardinality - 10_100).as_f64().log2().as_u64() as usize; - n + 100 - }; - - let n = n + 100; - ParDataset::par_choose_unique(data, indices, n, seed) + let num_samples = utils::num_samples(cardinality, 100, 10_000); + data.choose_unique(indices, num_samples, seed, metric) }; - let arg_center = ParDataset::par_median(data, &samples); + let arg_center = data.par_median(&samples, metric); - let distances = ParDataset::par_one_to_many(data, arg_center, indices); + let distances = data.par_one_to_many(arg_center, indices, metric).collect::>(); let &(arg_radial, radius) = distances .iter() - .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Less)) + .max_by(|(_, a), (_, b)| a.total_cmp(b)) .unwrap_or_else(|| unreachable!("Cannot find the maximum distance")); let distances = distances.into_iter().map(|(_, d)| d).collect::>(); - let lfd = utils::compute_lfd(radius, &distances); + let lfd = LFD::from_radial_distances(&distances, radius.half()); - Self { + Ok(Self { depth, cardinality, radius, @@ -319,25 +310,29 @@ impl> ParPartition for B arg_center, arg_radial, indices: indices.to_vec(), + extents: vec![(arg_center, radius)], children: Vec::new(), - _id: PhantomData, - } + }) } - fn par_find_extrema(&self, data: &D) -> Vec { - let l_distances = ParDataset::par_one_to_many(data, self.arg_radial, &self.indices); - - let &(arg_l, _) = l_distances - .iter() - .max_by(|(_, a), (_, b)| 
a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Less)) + fn par_find_extrema, M: ParMetric>( + &mut self, + data: &D, + metric: &M, + ) -> Vec { + let (arg_l, d) = data + .par_one_to_many(self.arg_radial, &self.indices, metric) + .max_by(|(_, a), (_, b)| a.total_cmp(b)) .unwrap_or_else(|| unreachable!("Cannot find the maximum distance")); + self.add_extent(self.arg_radial, d); + vec![arg_l, self.arg_radial] } } -#[cfg(feature = "csv")] -impl> super::WriteCsv for Ball { +#[cfg(feature = "disk-io")] +impl super::Csv for Ball { fn header(&self) -> Vec { vec![ "depth".to_string(), @@ -363,197 +358,11 @@ impl> super::WriteCsv for Ball } } -#[cfg(test)] -mod tests { - use distances::number::{Addition, Multiplication}; - - use crate::{partition::ParPartition, Cluster, Dataset, FlatVec, Metric, Partition}; - - use super::Ball; - - type F = FlatVec, i32, usize>; - type B = Ball, i32, F>; - - fn gen_tiny_data() -> Result, i32, usize>, String> { - let instances = vec![vec![1, 2], vec![3, 4], vec![5, 6], vec![7, 8], vec![11, 12]]; - let distance_function = |a: &Vec, b: &Vec| distances::vectors::manhattan(a, b); - let metric = Metric::new(distance_function, false); - FlatVec::new_array(instances.clone(), metric) - } - - fn gen_pathological_line() -> FlatVec { - let min_delta = 1e-12; - let mut delta = min_delta; - let mut line = vec![0_f64]; - - while line.len() < 900 { - let last = *line.last().unwrap(); - line.push(last + delta); - delta *= 2.0; - delta += min_delta; - } - - let distance_fn = |x: &f64, y: &f64| x.abs_diff(*y); - let metric = Metric::new(distance_fn, false); - FlatVec::new(line, metric).unwrap() - } - - #[test] - fn new() -> Result<(), String> { - let data = gen_tiny_data()?; - - let indices = (0..data.cardinality()).collect::>(); - let seed = Some(42); - let root = Ball::new(&data, &indices, 0, seed); - let arg_r = root.arg_radial(); +#[cfg(feature = "disk-io")] +impl super::ParCsv for Ball {} - assert_eq!(arg_r, data.cardinality() - 1); - assert_eq!(root.depth(), 0); - assert_eq!(root.cardinality(), 5); - assert_eq!(root.arg_center(), 2); - assert_eq!(root.radius(), 12); - assert_eq!(root.arg_radial(), arg_r); - assert!(root.children().is_empty()); - assert_eq!(root.indices().collect::>(), indices); +#[cfg(feature = "disk-io")] +impl super::ClusterIO for Ball {} - let root = Ball::par_new(&data, &indices, 0, seed); - let arg_r = root.arg_radial(); - - assert_eq!(arg_r, data.cardinality() - 1); - assert_eq!(root.depth(), 0); - assert_eq!(root.cardinality(), 5); - assert_eq!(root.arg_center(), 2); - assert_eq!(root.radius(), 12); - assert_eq!(root.arg_radial(), arg_r); - assert!(root.children().is_empty()); - assert_eq!(root.indices().collect::>(), indices); - - Ok(()) - } - - fn check_partition(root: &B) -> bool { - let indices = root.indices().collect::>(); - - assert!(!root.children().is_empty()); - assert_eq!(indices, &[0, 1, 2, 4, 3]); - - let children = root.child_clusters().collect::>(); - assert_eq!(children.len(), 2); - for &c in &children { - assert_eq!(c.depth(), 1); - assert!(c.children().is_empty()); - } - - let (left, right) = (children[0], children[1]); - - assert_eq!(left.cardinality(), 3); - assert_eq!(left.arg_center(), 1); - assert_eq!(left.radius(), 4); - assert!([0, 2].contains(&left.arg_radial())); - - assert_eq!(right.cardinality(), 2); - assert_eq!(right.radius(), 8); - assert!([3, 4].contains(&right.arg_center())); - assert!([3, 4].contains(&right.arg_radial())); - - true - } - - #[test] - fn tree() -> Result<(), String> { - let data = gen_tiny_data()?; - - 
let seed = Some(42); - let criteria = |c: &B| c.depth() < 1; - - let root = Ball::new_tree(&data, &criteria, seed); - assert_eq!(root.indices().count(), data.cardinality()); - assert!(check_partition(&root)); - - let root = Ball::par_new_tree(&data, &criteria, seed); - assert_eq!(root.indices().count(), data.cardinality()); - assert!(check_partition(&root)); - - Ok(()) - } - - #[test] - fn partition_further() -> Result<(), String> { - let data = gen_tiny_data()?; - - let seed = Some(42); - let criteria_one = |c: &B| c.depth() < 1; - let criteria_two = |c: &B| c.depth() < 2; - - let mut root = Ball::new_tree(&data, &criteria_one, seed); - for leaf in root.leaves() { - assert_eq!(leaf.depth(), 1); - } - root.partition_further(&data, &criteria_two, seed); - for leaf in root.leaves() { - assert_eq!(leaf.depth(), 2); - } - - let mut root = Ball::par_new_tree(&data, &criteria_one, seed); - for leaf in root.leaves() { - assert_eq!(leaf.depth(), 1); - } - root.par_partition_further(&data, &criteria_two, seed); - for leaf in root.leaves() { - assert_eq!(leaf.depth(), 2); - } - - Ok(()) - } - - #[test] - fn tree_iterative() { - let data = gen_pathological_line(); - - let seed = Some(42); - let criteria = |c: &Ball<_, _, _>| c.cardinality() > 1; - - let indices = (0..data.cardinality()).collect::>(); - let mut root = Ball::new(&data, &indices, 0, seed); - - let mut intermediate_depth = root.max_recursion_depth(); - let intermediate_criteria = |c: &Ball<_, _, _>| c.depth() < intermediate_depth && criteria(c); - root.partition(&data, &intermediate_criteria, seed); - - while root.leaves().into_iter().any(|l| !l.is_singleton()) { - intermediate_depth += root.max_recursion_depth(); - let intermediate_criteria = |c: &Ball<_, _, _>| c.depth() < intermediate_depth && criteria(c); - root.partition_further(&data, &intermediate_criteria, seed); - } - - assert!(!root.is_leaf()); - } - - #[test] - fn trim_and_graft() -> Result<(), String> { - let line = (0..1024).collect(); - let distance_fn = |x: &u32, y: &u32| x.abs_diff(*y); - let metric = Metric::new(distance_fn, false); - let data = FlatVec::new(line, metric)?; - - let seed = Some(42); - let criteria = |c: &Ball<_, _, _>| c.cardinality() > 1; - let root = Ball::new_tree(&data, &criteria, seed); - - let target_depth = 4; - let mut grafted_root = root.clone(); - let children = grafted_root.trim_at_depth(target_depth); - - let leaves = grafted_root.leaves(); - assert_eq!(leaves.len(), 2.powi(target_depth as i32)); - assert_eq!(leaves.len(), children.len()); - - grafted_root.graft_at_depth(target_depth, children); - assert_eq!(grafted_root, root); - for (l, c) in root.subtree().into_iter().zip(grafted_root.subtree()) { - assert_eq!(l, c); - } - - Ok(()) - } -} +#[cfg(feature = "disk-io")] +impl super::ParClusterIO for Ball {} diff --git a/crates/abd-clam/src/core/cluster/csv.rs b/crates/abd-clam/src/core/cluster/csv.rs deleted file mode 100644 index 3f7156675..000000000 --- a/crates/abd-clam/src/core/cluster/csv.rs +++ /dev/null @@ -1,146 +0,0 @@ -//! Writing `Cluster` trees to CSV files. - -use std::io::Write; - -use distances::Number; - -use crate::Dataset; - -use super::Cluster; - -/// Write a `Cluster` to a CSV file. -#[allow(clippy::module_name_repetitions)] -pub trait WriteCsv: Cluster -where - U: Number, - D: Dataset, -{ - /// Returns the names of the columns in the CSV file. - fn header(&self) -> Vec; - - /// Returns a row, corresponding to the `Cluster`, for the CSV file. - fn row(&self) -> Vec; - - /// Write to a CSV file, all the clusters in the tree. 
- /// - /// # Errors - /// - /// - If the file cannot be created. - /// - If the file cannot be written to. - /// - If the header cannot be written to the file. - /// - If any row cannot be written to the file. - fn write_to_csv>(&self, path: &P) -> Result<(), String> { - let line = |items: Vec| { - let mut line = items.join(","); - line.push('\n'); - line - }; - - // Create the file and write the header. - let mut file = std::fs::File::create(path).map_err(|e| e.to_string())?; - file.write_all(line(self.header()).as_bytes()) - .map_err(|e| e.to_string())?; - - // Write each row to the file. - for row in self.subtree().into_iter().map(Self::row).map(line) { - file.write_all(row.as_bytes()).map_err(|e| e.to_string())?; - } - - Ok(()) - } -} - -// /// Returns the subtree, with unique integers to reference the parents and children of each `Cluster` in a `HashMap`. -// /// -// /// The Vec contains tuples of: -// /// -// /// - The `Cluster` itself. -// /// - The index of the `Cluster` in the Vec. -// /// - The position of the `Cluster` among its siblings. -// /// - A Vec of tuples of: -// /// - The index of the parent `Cluster` in the Vec. -// /// - A tuple of: -// /// - The index of the parent `Cluster` in the Vec. -// /// - The extent of the child. -// #[allow(clippy::type_complexity)] -// fn take_subtree(mut self) -> Vec<(Self, usize, usize, Vec<(usize, (usize, U))>)> { -// let children = self.take_children(); -// let mut clusters = vec![(self, 0, 0, vec![])]; - -// for (e, d, children) in children.into_iter().map(|(e, d, c)| (e, d, c.take_subtree())) { -// let offset = clusters.len(); - -// for (ci, (child, parent_index, _, children_indices)) in children.into_iter().enumerate() { -// let parent_index = parent_index + offset; -// let children_indices = children_indices.into_iter().map(|(pi, ed)| (pi + offset, ed)).collect(); -// clusters.push((child, parent_index, ci, children_indices)); -// } - -// clusters[0].3.push((offset, (e, d))); -// } - -// clusters -// } - -// /// Returns the subtree as a list of `Cluster`s, with the indices required -// /// to go from a parent to a child and vice versa. -// /// -// /// The Vec contains tuples of: -// /// -// /// - The `Cluster` itself. -// /// - The position of the `Cluster` among its siblings. -// /// - A Vec of tuples of: -// /// - The index of the parent `Cluster` in the Vec. -// /// - A tuple of: -// /// - The index of the parent `Cluster` in the Vec. -// /// - The extent of the child. 
-// #[allow(clippy::type_complexity)] -// fn unstack_tree(self) -> Vec<(Self, usize, Vec<(usize, (usize, U))>)> { -// let mut subtree = self.take_subtree(); -// subtree.sort_by_key(|(_, i, _, _)| *i); -// subtree -// .into_iter() -// .map(|(c, _, ci, children)| (c, ci, children)) -// .collect() -// } - -// #[test] -// fn numbered_subtree() { -// let data = (0..1024).collect::>(); -// let distance_fn = |x: &u32, y: &u32| x.abs_diff(*y); -// let metric = Metric::new(distance_fn, false); -// let data = FlatVec::new(data, metric).unwrap(); - -// let seed = Some(42); -// let criteria = |c: &Ball<_, _, _>| c.cardinality() > 1; -// let root = Ball::new_tree(&data, &criteria, seed); - -// let mut numbered_subtree = root.clone().take_subtree(); -// let indices = numbered_subtree.iter().map(|(_, i, _, _)| *i).collect::>(); -// assert_eq!(indices.len(), numbered_subtree.len()); - -// for i in 0..numbered_subtree.len() { -// assert!(indices.contains(&i)); -// } - -// numbered_subtree.sort_by(|(_, i, _, _), (_, j, _, _)| i.cmp(j)); -// let numbered_subtree = numbered_subtree -// .into_iter() -// .map(|(a, _, _, b)| (a, b)) -// .collect::>(); - -// for (ball, child_indices) in &numbered_subtree { -// for &(i, _) in child_indices { -// let (child, _) = &numbered_subtree[i]; -// assert!(child.is_descendant_of(ball), "{ball:?} is not parent of {child:?}"); -// } -// } - -// // let root_list = root.clone().as_indexed_list(); -// // let re_root = Ball::from_indexed_list(root_list); -// // assert_eq!(root, re_root); - -// // for (l, r) in root.subtree().into_iter().zip(re_root.subtree()) { -// // assert_eq!(l, r); -// // } -// } diff --git a/crates/abd-clam/src/core/cluster/io.rs b/crates/abd-clam/src/core/cluster/io.rs new file mode 100644 index 000000000..90b1039ce --- /dev/null +++ b/crates/abd-clam/src/core/cluster/io.rs @@ -0,0 +1,147 @@ +//! Writing `Cluster` trees to CSV files. + +use std::io::Write; + +use distances::Number; +use rayon::prelude::*; + +use super::{Cluster, ParCluster}; + +#[cfg(feature = "disk-io")] +/// Write a tree to a CSV file. +pub trait Csv: Cluster { + /// Returns the names of the columns in the CSV file. + fn header(&self) -> Vec; + + /// Returns a row, corresponding to the `Cluster`, for the CSV file. + fn row(&self) -> Vec; + + /// Write to a CSV file, all the clusters in the tree. + /// + /// # Errors + /// + /// - If the file cannot be created. + /// - If the file cannot be written to. + /// - If the header cannot be written to the file. + /// - If any row cannot be written to the file. + fn write_to_csv>(&self, path: &P) -> Result<(), String> { + let line = |items: Vec| { + let mut line = items.join(","); + line.push('\n'); + line + }; + + // Create the file and write the header. + let mut file = std::fs::File::create(path).map_err(|e| e.to_string())?; + file.write_all(line(self.header()).as_bytes()) + .map_err(|e| e.to_string())?; + + // Write each row to the file. + for row in self.subtree().into_iter().map(Self::row).map(line) { + file.write_all(row.as_bytes()).map_err(|e| e.to_string())?; + } + + Ok(()) + } +} + +#[cfg(feature = "disk-io")] +/// Parallel version of [`Csv`](crate::core::cluster::io::Csv). +pub trait ParCsv: Csv + ParCluster { + /// Parallel version of [`Csv::write_to_csv`](crate::core::cluster::Csv::write_to_csv). + /// + /// # Errors + /// + /// See [`Csv::write_to_csv`](crate::core::cluster::Csv::write_to_csv). 
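For orientation, a minimal usage sketch for the new `Csv` trait; this assumes the `disk-io` feature is enabled, `Euclidean` is a hypothetical metric type implementing `Metric<Vec<f32>, f32>`, and the exact import paths are assumptions, since the re-exports live outside this hunk:

    use abd_clam::{Ball, Cluster, Csv, Dataset, FlatVec, Partition};

    // Build a tree over the data, then write one CSV row per cluster in its subtree.
    fn write_tree(data: &FlatVec<Vec<f32>, usize>, metric: &Euclidean) -> Result<(), String> {
        let criteria = |c: &Ball<f32>| c.cardinality() > 1;
        let root = Ball::new_tree(data, metric, &criteria, Some(42));
        root.write_to_csv(&"tree.csv")
    }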
+ fn par_write_to_csv>(&self, path: &P) -> Result<(), String> { + let line = |items: Vec| { + let mut line = items.join(","); + line.push('\n'); + line + }; + + // Create the file and write the header. + let mut file = std::fs::File::create(path).map_err(|e| e.to_string())?; + file.write_all(line(self.header()).as_bytes()) + .map_err(|e| e.to_string())?; + + let rows = self + .subtree() + .into_par_iter() + .map(Self::row) + .map(line) + .collect::>(); + + // Write each row to the file. + for row in rows { + file.write_all(row.as_bytes()).map_err(|e| e.to_string())?; + } + + Ok(()) + } +} + +#[cfg(feature = "disk-io")] +/// Reading and writing `Cluster` trees to disk using `bitcode`. +pub trait ClusterIO: Cluster { + /// Writes the `Cluster` to disk in binary format using `bitcode`. + /// + /// # Errors + /// + /// - If the cluster cannot be encoded. + /// - If the file cannot be written. + fn write_to>(&self, path: &P) -> Result<(), String> + where + Self: bitcode::Encode, + { + let bytes = bitcode::encode(self).map_err(|e| e.to_string())?; + std::fs::write(path, bytes).map_err(|e| e.to_string()) + } + + /// Reads the `Cluster` from disk in binary format using `bitcode`. + /// + /// # Errors + /// + /// - If the file cannot be read. + /// - If the cluster cannot be decoded. + fn read_from>(path: &P) -> Result + where + Self: bitcode::Decode, + { + let bytes = std::fs::read(path).map_err(|e| e.to_string())?; + bitcode::decode(&bytes).map_err(|e| e.to_string()) + } +} + +#[cfg(feature = "disk-io")] +/// Parallel version of [`ClusterIO`](crate::core::cluster::io::ClusterIO). +pub trait ParClusterIO: ParCluster + ClusterIO { + /// Parallel version of [`ClusterIO::write_to`](crate::core::cluster::ClusterIO::write_to). + /// + /// The default implementation offers no parallelism. + /// + /// # Errors + /// + /// See [`ClusterIO::write_to`](crate::core::cluster::ClusterIO::write_to). + fn par_write_to>(&self, path: &P) -> Result<(), String> + where + Self: bitcode::Encode, + { + self.write_to(path) + } + + /// Parallel version of [`ClusterIO::read_from`](crate::core::cluster::ClusterIO::read_from). + /// + /// The default implementation offers no parallelism. + /// + /// # Errors + /// + /// See [`ClusterIO::read_from`](crate::core::cluster::ClusterIO::read_from). + fn par_read_from>(path: &P) -> Result + where + Self: bitcode::Decode, + { + let bytes = std::fs::read(path).map_err(|e| e.to_string())?; + bitcode::decode(&bytes).map_err(|e| e.to_string()) + } +} diff --git a/crates/abd-clam/src/core/cluster/lfd.rs b/crates/abd-clam/src/core/cluster/lfd.rs index 85f183319..48abb32ee 100644 --- a/crates/abd-clam/src/core/cluster/lfd.rs +++ b/crates/abd-clam/src/core/cluster/lfd.rs @@ -2,7 +2,7 @@ use distances::{number::Addition, Number}; -/// Local Fractal Dimension. +/// Helpers for dealing with Local Fractal Dimension (LFD) calculations. /// /// The LFD of a `Cluster` is a measure of the fractal dimension of the /// `Cluster` at a local scale. It is calculated as the logarithm of the ratio @@ -36,9 +36,9 @@ impl LFD { /// Calculate LFD from distances and a given scale. /// - /// This is calculated as the ratio of the logarithms of the number of points - /// in the cluster and the number of points within the given scale to the - /// logarithm of the two scales. + /// This is calculated as the ratio of the logarithms of the number of + /// points in the cluster and the number of points within the given scale to + /// the logarithm of the two scales. 
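In symbols, with N the cardinality of the cluster, N_s the number of points within the given scale s of the center, and r the radius, the description above amounts to

    lfd = (log N - log N_s) / (log r - log s)

which, at the scale s = r/2 that `Ball` passes to `LFD::from_radial_distances`, reduces to log2(N / N_{r/2}); this is a reading of the prose here, not a formula quoted from the source.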
/// /// # Arguments /// diff --git a/crates/abd-clam/src/core/cluster/mod.rs b/crates/abd-clam/src/core/cluster/mod.rs index e880113f0..b7109adc5 100644 --- a/crates/abd-clam/src/core/cluster/mod.rs +++ b/crates/abd-clam/src/core/cluster/mod.rs @@ -1,132 +1,150 @@ -//! A `Cluster` is a collection of "similar" instances in a dataset. +//! A `Cluster` is a collection of "similar" items in a dataset. + +use distances::{number::Float, Number}; +use rayon::prelude::*; pub mod adapter; mod balanced_ball; mod ball; mod lfd; -pub mod partition; - -#[cfg(feature = "csv")] -mod csv; - -use distances::Number; - -use super::{dataset::ParDataset, Dataset, MetricSpace}; +mod partition; pub use balanced_ball::BalancedBall; pub use ball::Ball; pub use lfd::LFD; -pub use partition::Partition; +pub use partition::{ParPartition, Partition}; -#[cfg(feature = "csv")] -pub use csv::WriteCsv; +#[cfg(feature = "disk-io")] +mod io; -/// A `Cluster` is a collection of "similar" instances in a dataset. +#[cfg(feature = "disk-io")] +#[allow(clippy::module_name_repetitions)] +pub use io::{ClusterIO, Csv, ParClusterIO, ParCsv}; + +use super::{dataset::ParDataset, metric::ParMetric, Dataset, Metric}; + +/// A `Cluster` is a collection of "similar" items in a dataset. +/// +/// It represents a metric ball in a metric space. All items in a `Cluster` are +/// within a certain distance of a center. The `Cluster` may have children, +/// which are `Cluster`s of the same type. /// /// # Type Parameters /// -/// - `I`: The type of the instances in the dataset. -/// - `U`: The type of the distance values between instances. -/// - `D`: The type of the dataset. +/// - `T`: The type of the distance values between items. /// /// # Remarks /// /// A `Cluster` must have certain properties to be useful in CLAM. These are: /// /// - `depth`: The depth of the `Cluster` in the tree. -/// - `cardinality`: The number of instances in the `Cluster`. -/// - `indices`: The indices of the instances in the `Cluster`. -/// - `arg_center`: The index of the geometric median of the instances in the -/// `Cluster`. This may be computed exactly, using all instances in the -/// `Cluster`, or approximately, using a subset of the instances. -/// - `radius`: The distance from the center to the farthest instance in the +/// - `cardinality`: The number of items in the `Cluster`. +/// - `indices`: The indices into a dataset of the items in the `Cluster`. +/// - `arg_center`: The index of the geometric median of the items in the +/// `Cluster`. This may be computed exactly, using all items in the `Cluster`, +/// or approximately, using a subset of the items. +/// - `radius`: The distance from the center to the farthest item in the /// `Cluster`. -/// - `arg_radial`: The index of the instance that is farthest from the center. +/// - `arg_radial`: The index of the item that is farthest from the center. /// - `lfd`: The Local Fractional Dimension of the `Cluster`. /// /// A `Cluster` may have two or more children, which are `Cluster`s of the same /// type. The children should be stored as a tuple with: /// -/// - The index of the extremal instance in the `Cluster` that was used to +/// - The index of the extremal item in the `Cluster` that was used to /// create the child. -/// - The distance from that extremal instance to the farthest instance that was +/// - The distance from that extremal item to the farthest item that was /// assigned to the child. We refer to this as the "extent" of the child. /// - The child `Cluster`. 
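To make the reworked shape of the trait concrete, here is a sketch of a generic traversal that uses only methods declared below; it assumes `Cluster` takes a single `T: Number` type parameter, as the `Ball` definition earlier in this patch suggests:

    use abd_clam::Cluster;
    use distances::Number;

    // Walk the subtree depth-first and report each cluster's core properties.
    fn describe<T: Number, C: Cluster<T>>(root: &C) {
        for c in root.subtree() {
            println!(
                "depth={} cardinality={} lfd={:.2} extents={}",
                c.depth(),
                c.cardinality(),
                c.lfd(),
                c.extents().len()
            );
        }
    }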
-pub trait Cluster>: Ord + core::hash::Hash + Sized { - /// Returns the depth os the `Cluster` in the tree. +/// +/// # Examples +/// +/// See: +/// +/// - [`Ball`](crate::core::cluster::Ball) +/// - [`BalancedBall`](crate::core::cluster::BalancedBall) +/// - [`PermutedBall`](crate::cakes::PermutedBall) +/// - [`SquishyBall`](crate::pancakes::SquishyBall) +pub trait Cluster: PartialEq + Eq + PartialOrd + Ord + core::hash::Hash { + /// Returns the depth of the `Cluster` in the tree. fn depth(&self) -> usize; /// Returns the cardinality of the `Cluster`. fn cardinality(&self) -> usize; - /// Returns the index of the center instance in the `Cluster`. + /// Returns the index of the center item in the `Cluster`. fn arg_center(&self) -> usize; - /// Sets the index of the center instance in the `Cluster`. + /// Sets the index of the center item in the `Cluster`. /// - /// This is used to find the center instance after permutation. + /// This is used to find the center item after permutation. fn set_arg_center(&mut self, arg_center: usize); /// Returns the radius of the `Cluster`. - fn radius(&self) -> U; + fn radius(&self) -> T; - /// Returns the index of the radial instance in the `Cluster`. + /// Returns the index of the radial item in the `Cluster`. fn arg_radial(&self) -> usize; - /// Sets the index of the radial instance in the `Cluster`. + /// Sets the index of the radial item in the `Cluster`. /// - /// This is used to find the radial instance after permutation. + /// This is used to find the radial item after permutation. fn set_arg_radial(&mut self, arg_radial: usize); /// Returns the Local Fractional Dimension (LFD) of the `Cluster`. fn lfd(&self) -> f32; - /// Gets the indices of the instances in the `Cluster`. - fn indices(&self) -> impl Iterator + '_; + /// Returns whether this `Cluster` contains the given `index`ed point. + fn contains(&self, idx: usize) -> bool; + + /// Gets the indices of the items in the `Cluster`. + fn indices(&self) -> Vec; + + /// Sets the indices of the items in the `Cluster`. + fn set_indices(&mut self, indices: &[usize]); + + /// The `extents` of a cluster are pairs of an index of an item in the + /// cluster and the distance from that item to the farthest item in the + /// cluster. + fn extents(&self) -> &[(usize, T)]; - /// Sets the indices of the instances in the `Cluster`. - fn set_indices(&mut self, indices: Vec); + /// Returns the extents as a mutable slice. + fn extents_mut(&mut self) -> &mut [(usize, T)]; + + /// Adds an extent to the `Cluster`. + fn add_extent(&mut self, idx: usize, extent: T); + + /// Clears the extents of the `Cluster` and returns the old extents. + fn take_extents(&mut self) -> Vec<(usize, T)>; /// Returns the children of the `Cluster`. #[must_use] - fn children(&self) -> &[(usize, U, Box)]; + fn children(&self) -> Vec<&Self>; /// Returns the children of the `Cluster` as mutable references. #[must_use] - fn children_mut(&mut self) -> &mut [(usize, U, Box)]; + fn children_mut(&mut self) -> Vec<&mut Self>; /// Sets the children of the `Cluster`. - fn set_children(&mut self, children: Vec<(usize, U, Box)>); + fn set_children(&mut self, children: Vec>); /// Returns the owned children and sets the cluster's children to an empty vector. - fn take_children(&mut self) -> Vec<(usize, U, Box)>; - - /// Computes the distances from the `query` to all instances in the `Cluster`. 
- fn distances_to_query(&self, data: &D, query: &I) -> Vec<(usize, U)>; + fn take_children(&mut self) -> Vec>; /// Returns whether the `Cluster` is a descendant of another `Cluster`. fn is_descendant_of(&self, other: &Self) -> bool; - /// Reads the `MAX_RECURSION_DEPTH` environment variable to determine the - /// stride for iterative partition and adaptation. - fn max_recursion_depth(&self) -> usize { - std::env::var("MAX_RECURSION_DEPTH") - .ok() - .and_then(|s| s.parse().ok()) - .unwrap_or(256) - } - /// Clears the indices stored with every cluster in the tree. fn clear_indices(&mut self) { if !self.is_leaf() { - self.child_clusters_mut().for_each(Self::clear_indices); + self.children_mut().into_iter().for_each(Self::clear_indices); } - self.set_indices(Vec::new()); + self.set_indices(&[]); } /// Trims the tree at the given depth. Returns the trimmed roots in the same /// order as the leaves of the trimmed tree at that depth. - fn trim_at_depth(&mut self, depth: usize) -> Vec)>> { + fn trim_at_depth(&mut self, depth: usize) -> Vec>> { let mut queue = vec![self]; let mut stack = Vec::new(); @@ -134,7 +152,7 @@ pub trait Cluster>: Ord + core::hash::Hash + Size if c.depth() == depth { stack.push(c); } else { - queue.extend(c.child_clusters_mut()); + queue.extend(c.children_mut()); } } @@ -142,7 +160,7 @@ pub trait Cluster>: Ord + core::hash::Hash + Size } /// Inverts the `trim_at_depth` method. - fn graft_at_depth(&mut self, depth: usize, trimmings: Vec)>>) { + fn graft_at_depth(&mut self, depth: usize, trimmings: Vec>>) { let mut queue = vec![self]; let mut stack = Vec::new(); @@ -150,7 +168,7 @@ pub trait Cluster>: Ord + core::hash::Hash + Size if c.depth() == depth { stack.push(c); } else { - queue.extend(c.child_clusters_mut()); + queue.extend(c.children_mut()); } } @@ -160,36 +178,22 @@ pub trait Cluster>: Ord + core::hash::Hash + Size .for_each(|(c, children)| c.set_children(children)); } - /// Gets the child `Cluster`s. - fn child_clusters<'a>(&'a self) -> impl Iterator - where - U: 'a, - { - self.children().iter().map(|(_, _, child)| child.as_ref()) - } - - /// Gets the child `Cluster`s as mutable references. - fn child_clusters_mut<'a>(&'a mut self) -> impl Iterator - where - U: 'a, - { - self.children_mut().iter_mut().map(|(_, _, child)| child.as_mut()) - } - /// Returns all `Cluster`s in the subtree of this `Cluster`, in depth-first order. fn subtree<'a>(&'a self) -> Vec<&'a Self> where - U: 'a, + T: 'a, { let mut clusters = vec![self]; - self.child_clusters().for_each(|child| clusters.extend(child.subtree())); + self.children() + .into_iter() + .for_each(|child| clusters.extend(child.subtree())); clusters } /// Returns all leaf `Cluster`s in the subtree of this `Cluster`, in depth-first order. fn leaves<'a>(&'a self) -> Vec<&'a Self> where - U: 'a, + T: 'a, { let mut queue = vec![self]; let mut stack = vec![]; @@ -198,7 +202,7 @@ pub trait Cluster>: Ord + core::hash::Hash + Size if cluster.is_leaf() { stack.push(cluster); } else { - queue.extend(cluster.child_clusters()); + queue.extend(cluster.children()); } } @@ -208,7 +212,7 @@ pub trait Cluster>: Ord + core::hash::Hash + Size /// Returns mutable references to all leaf `Cluster`s in the subtree of this `Cluster`, in depth-first order. 
fn leaves_mut<'a>(&'a mut self) -> Vec<&'a mut Self> where - U: 'a, + T: 'a, { let mut queue = vec![self]; let mut stack = vec![]; @@ -217,7 +221,7 @@ pub trait Cluster>: Ord + core::hash::Hash + Size if cluster.is_leaf() { stack.push(cluster); } else { - queue.extend(cluster.child_clusters_mut()); + queue.extend(cluster.children_mut()); } } @@ -231,24 +235,95 @@ pub trait Cluster>: Ord + core::hash::Hash + Size /// Whether the `Cluster` is a singleton. fn is_singleton(&self) -> bool { - self.cardinality() == 1 || self.radius() < U::EPSILON + self.cardinality() == 1 || self.radius() < T::EPSILON + } + + /// Returns the expected radii for `k` items from each cluster center using + /// clusters whose cardinality is no greater than `k` but whose parents have + /// cardinality greater than `k`. + fn radii_for_k(&self, k: usize) -> Vec { + if self.cardinality() <= k { + vec![F::from(self.radius())] + } else { + self.children() + .into_iter() + .map(|c| (c.cardinality(), c.radii_for_k::(k))) + .flat_map(|(car, radii)| radii.into_iter().map(move |r| (car, r))) + .map(|(c, r)| if c == k { r } else { r * F::from(c) / F::from(k) }) + .collect() + } } - /// Computes the distance from the `Cluster`'s center to a given `query`. - fn distance_to_center(&self, data: &D, query: &I) -> U { - let center = data.get(self.arg_center()); - MetricSpace::one_to_one(data, center, query) + /// Returns whether this cluster has any overlap with a query and a radius, + /// and the distances to the cluster extrema. + fn overlaps_with, M: Metric>( + &self, + data: &D, + metric: &M, + query: &I, + radius: T, + ) -> (bool, Vec) { + let (extrema, extents): (Vec<_>, Vec<_>) = self.extents().iter().copied().unzip(); + let distances = data + .query_to_many(query, &extrema, metric) + .map(|(_, d)| d) + .collect::>(); + (distances.iter().zip(extents).all(|(&d, e)| d <= e + radius), distances) } - /// Computes the distance from the `Cluster`'s center to another `Cluster`'s center. - fn distance_to_other(&self, data: &D, other: &Self) -> U { - Dataset::one_to_one(data, self.arg_center(), other.arg_center()) + /// Returns only those children of the `Cluster` that overlap with a query + /// and a radius. + fn overlapping_children, M: Metric>( + &self, + data: &D, + metric: &M, + query: &I, + radius: T, + ) -> Vec<(&Self, Vec)> { + self.children() + .into_iter() + .map(|c| (c, c.overlaps_with(data, metric, query, radius))) + .filter(|&(_, (o, _))| o) + .map(|(c, (_, ds))| (c, ds)) + .collect() } } /// A parallelized version of the `Cluster` trait. #[allow(clippy::module_name_repetitions)] -pub trait ParCluster>: Cluster + Send + Sync { - /// Parallelized version of the `distances_to_query` method. - fn par_distances_to_query(&self, data: &D, query: &I) -> Vec<(usize, U)>; +pub trait ParCluster: Cluster + Send + Sync { + /// Parallel version of [`Cluster::indices`](crate::core::cluster::Cluster::indices). + fn par_indices(&self) -> impl ParallelIterator; + + /// Parallel version of [`Cluster::overlaps_with`](crate::core::cluster::Cluster::overlaps_with). 
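Note that the acceptance test in `overlaps_with` is the triangle inequality applied once per extent: if a cluster stores the pair (x, e), meaning every item of the cluster lies within distance e of the item x, then the cluster can intersect the query ball B(q, rho) only if d(q, x) <= e + rho. The method reports an overlap exactly when this holds for every stored (x, e) pair, and it returns the distances d(q, x) so that search routines can reuse them instead of recomputing.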
+ fn par_overlaps_with, M: ParMetric>( + &self, + data: &D, + metric: &M, + query: &I, + radius: T, + ) -> (bool, Vec) { + let (extrema, extents): (Vec<_>, Vec<_>) = self.extents().iter().copied().unzip(); + let distances = data + .par_query_to_many(query, &extrema, metric) + .map(|(_, d)| d) + .collect::>(); + (distances.iter().zip(extents).all(|(&d, e)| d <= e + radius), distances) + } + + /// Parallel version of [`Cluster::overlapping_children`](crate::core::cluster::Cluster::overlapping_children). + fn par_overlapping_children, M: ParMetric>( + &self, + data: &D, + metric: &M, + query: &I, + radius: T, + ) -> Vec<(&Self, Vec)> { + self.children() + .into_par_iter() + .map(|c| (c, c.par_overlaps_with(data, metric, query, radius))) + .filter(|&(_, (o, _))| o) + .map(|(c, (_, ds))| (c, ds)) + .collect() + } } diff --git a/crates/abd-clam/src/core/cluster/partition.rs b/crates/abd-clam/src/core/cluster/partition.rs index 6541f8836..1fe4e5848 100644 --- a/crates/abd-clam/src/core/cluster/partition.rs +++ b/crates/abd-clam/src/core/cluster/partition.rs @@ -3,7 +3,13 @@ use distances::Number; use rayon::prelude::*; -use crate::{dataset::ParDataset, Dataset}; +use crate::{ + core::{ + dataset::{Dataset, ParDataset}, + metric::{Metric, ParMetric}, + }, + utils, +}; use super::{Cluster, ParCluster}; @@ -12,45 +18,71 @@ use super::{Cluster, ParCluster}; /// /// # Type Parameters /// -/// - `I`: The type of the instances in the `Dataset`. -/// - `U`: The type of the distance values. -/// - `D`: The type of the `Dataset`. -pub trait Partition>: Cluster { +/// - `T`: The type of the distance values. +/// +/// # Examples +/// +/// See: +/// +/// - [`Ball`](crate::core::cluster::Ball) +/// - [`BalancedBall`](crate::core::cluster::BalancedBall) +pub trait Partition: Cluster + Sized { /// Creates a new `Cluster`. /// /// # Arguments /// - /// - `data`: The dataset containing the instances. - /// - `indices`: The indices of instances in the `Cluster`. + /// - `data`: The dataset containing the items. + /// - `metric`: The metric to use for distance calculations. + /// - `indices`: The indices of items in the `Cluster`. /// - `depth`: The depth of the `Cluster` in the tree. /// - `seed`: An optional seed for random number generation. /// - /// # Returns + /// # Type Parameters /// - /// - The new `Cluster`. - fn new(data: &D, indices: &[usize], depth: usize, seed: Option) -> Self; + /// - `I`: The items in the dataset. + /// - `D`: The dataset. + /// - `M`: The metric. + /// + /// # Errors + /// + /// - If the `indices` are empty. + /// - Any error that occurs when creating the `Cluster` depending on the + /// implementation. + fn new, M: Metric>( + data: &D, + metric: &M, + indices: &[usize], + depth: usize, + seed: Option, + ) -> Result; /// Finds the extrema of the `Cluster`. /// - /// The extrema are meant to be well-separated instances that can be used to + /// The extrema are meant to be well-separated items that can be used to /// partition the `Cluster` into some number of child `Cluster`s. The number /// of children will be equal to the number of extrema determined by this /// method. /// + /// There will be panics if this method returns less than two extrema. + /// /// # Arguments /// - /// - `data`: The dataset containing the instances. + /// - `data`: The dataset containing the items. + /// - `metric`: The metric to use for distance calculations. /// - /// # Returns + /// # Type Parameters /// - /// The extrema to use for partitioning the `Cluster`. 
- fn find_extrema(&self, data: &D) -> Vec; + /// - `I`: The items in the dataset. + /// - `D`: The dataset. + /// - `M`: The metric. + fn find_extrema, M: Metric>(&mut self, data: &D, metric: &M) -> Vec; /// Creates a new `Cluster` tree. /// /// # Arguments /// - /// - `data`: The dataset containing the instances. + /// - `data`: The dataset containing the items. + /// - `metric`: The metric to use for distance calculations. /// - `criteria`: The function to use for determining when a `Cluster` /// should be partitioned. A `Cluster` will only be partitioned if it is /// not a singleton and this function returns `true`. @@ -59,10 +91,58 @@ pub trait Partition>: Cluster { /// # Returns /// /// - The root `Cluster` of the tree. - fn new_tree bool>(data: &D, criteria: &C, seed: Option) -> Self { + /// + /// # Type Parameters + /// + /// - `I`: The items in the dataset. + /// - `D`: The dataset. + /// - `M`: The metric. + /// - `C`: The criteria function for partitioning. + fn new_tree, M: Metric, C: Fn(&Self) -> bool>( + data: &D, + metric: &M, + criteria: &C, + seed: Option, + ) -> Self { let indices = (0..data.cardinality()).collect::>(); - let mut root = Self::new(data, &indices, 0, seed); - root.partition(data, criteria, seed); + let mut root = Self::new(data, metric, &indices, 0, seed) + .unwrap_or_else(|e| unreachable!("We ensured that the indices are not empty: {e}")); + root.partition(data, metric, criteria, seed); + root + } + + /// Creates a new `Cluster` tree using an iterative partitioning method to + /// avoid stack overflows from the lack of tail call optimization. + fn new_tree_iterative, M: Metric, C: Fn(&Self) -> bool>( + data: &D, + metric: &M, + criteria: &C, + seed: Option, + depth_stride: usize, + ) -> Self { + let mut target_depth = depth_stride; + let stride_criteria = |c: &Self| c.depth() < target_depth && criteria(c); + + let mut root = Self::new_tree(data, metric, &stride_criteria, seed); + + let mut stride_leaves = root + .leaves_mut() + .into_iter() + .filter(|c| (c.depth() == depth_stride) && criteria(c)) + .collect::>(); + while !stride_leaves.is_empty() { + target_depth += depth_stride; + let stride_criteria = |c: &Self| c.depth() < target_depth && criteria(c); + stride_leaves + .into_iter() + .for_each(|c| c.partition(data, metric, &stride_criteria, seed)); + stride_leaves = root + .leaves_mut() + .into_iter() + .filter(|c| (c.depth() == target_depth) && criteria(c)) + .collect::>(); + } + root } @@ -73,26 +153,42 @@ pub trait Partition>: Cluster { /// /// # Arguments /// - /// - `data`: The dataset containing the instances. + /// - `data`: The dataset containing the items. + /// - `metric`: The metric to use for distance calculations. /// - `extrema`: The indices of extrema for partitioning the `Cluster`. - /// - `instances`: The indices of instances in the `Cluster`. + /// - `items`: The indices of items in the `Cluster`. /// /// # Returns /// - /// - The indices of instances with which to initialize the children. The + /// - The indices of items with which to initialize the children. The /// 0th element of each inner `Vec` is the index of the corresponding /// extremum. - /// - The distance from each extremum to the farthest instance assigned to + /// - The distance from each extremum to the farthest item assigned to /// that child, i.e. the "extent" of the child. 
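The depth-stride mechanism above may be easier to see in use; in this sketch, `AbsDiff` is a hypothetical metric on `f64` values, and the stride of 128 is an arbitrary illustrative choice:

    // Build the tree at most 128 levels at a time so that no recursive call
    // chain exceeds the stride; this guards against stack overflow on
    // adversarial data such as points spaced along a line.
    fn build_deep_tree(data: &FlatVec<f64, usize>, metric: &AbsDiff) -> Ball<f64> {
        let criteria = |c: &Ball<f64>| c.cardinality() > 1;
        Ball::new_tree_iterative(data, metric, &criteria, Some(42), 128)
    }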
- fn split_by_extrema(&self, data: &D, extrema: &[usize]) -> (Vec>, Vec) { - let instances = self.indices().filter(|i| !extrema.contains(i)).collect::>(); + /// + /// # Type Parameters + /// + /// - `I`: The items in the dataset. + /// - `D`: The dataset. + /// - `M`: The metric. + fn split_by_extrema, M: Metric>( + &self, + data: &D, + metric: &M, + extrema: &[usize], + ) -> (Vec>, Vec) { + let items = self + .indices() + .into_iter() + .filter(|i| !extrema.contains(i)) + .collect::>(); - // Find the distances from each extremum to each instance. - let extremal_distances = Dataset::many_to_many(data, extrema, &instances); + // Find the distances from each extremum to each item. + let extremal_distances = data.many_to_many(extrema, &items, metric); // Convert the distances from row-major to column-major. let distances = { - let mut distances = vec![vec![U::ZERO; extrema.len()]; instances.len()]; + let mut distances = vec![vec![T::ZERO; extrema.len()]; items.len()]; for (r, row) in extremal_distances.into_iter().enumerate() { for (c, (_, _, d)) in row.into_iter().enumerate() { distances[c][r] = d; @@ -102,17 +198,13 @@ pub trait Partition>: Cluster { }; // Initialize a child stack for each extremum. - let mut child_stacks = extrema.iter().map(|&p| vec![(p, U::ZERO)]).collect::>(); + let mut child_stacks = extrema.iter().map(|&p| vec![(p, T::ZERO)]).collect::>(); - // For each extremum, find the instances that are closer to it than to + // For each extremum, find the items that are closer to it than to // any other extremum. - for (col, instance) in distances.into_iter().zip(instances) { - let (e_index, d) = col - .into_iter() - .enumerate() - .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Greater)) - .unwrap_or_else(|| unreachable!("Cannot find the minimum distance")); - child_stacks[e_index].push((instance, d)); + for (col, item) in distances.into_iter().zip(items) { + let (e_index, d) = utils::arg_min(&col).unwrap_or_else(|| unreachable!("Cannot find the minimum distance")); + child_stacks[e_index].push((item, d)); } child_stacks @@ -121,18 +213,50 @@ pub trait Partition>: Cluster { let (indices, distances) = stack.into_iter().unzip::<_, _, Vec<_>, Vec<_>>(); let extent = distances .into_iter() - .max_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Less)) + .max_by(Number::total_cmp) .unwrap_or_else(|| unreachable!("Cannot find the maximum distance")); (indices, extent) }) .unzip() } + /// Partitions the `Cluster` once instead of recursively. + fn partition_once, M: Metric>( + &mut self, + data: &D, + metric: &M, + seed: Option, + ) -> Vec> { + // Find the extrema. + let extrema = self.find_extrema(data, metric); + // Split the items by the extrema. + let (child_stacks, child_extents) = self.split_by_extrema(data, metric, &extrema); + // Increment the depth for the children. + let depth = self.depth() + 1; + + // Create the children. + child_stacks + .into_iter() + .map(|child_indices| { + Self::new(data, metric, &child_indices, depth, seed) + .unwrap_or_else(|e| unreachable!("We ensured that the indices are not empty: {e}")) + }) + .map(Box::new) + .zip(extrema) + .zip(child_extents) + .map(|((mut c, i), d)| { + c.add_extent(i, d); + c + }) + .collect() + } + /// Recursively partitions the `Cluster` into a tree. /// /// # Arguments /// - /// - `data`: The dataset containing the instances. + /// - `data`: The dataset containing the items. + /// - `metric`: The metric to use for distance calculations. 
/// - `criteria`: The function to use for determining when a `Cluster` /// should be partitioned. /// - `seed`: An optional seed for random number generation. @@ -140,50 +264,39 @@ pub trait Partition>: Cluster { /// # Returns /// /// - The root `Cluster` of the tree. - /// - The instances in the `Cluster` in depth-first order of traversal of + /// - The items in the `Cluster` in depth-first order of traversal of /// the tree. - fn partition bool>(&mut self, data: &D, criteria: &C, seed: Option) { + /// + /// # Type Parameters + /// + /// - `I`: The items in the dataset. + /// - `D`: The dataset. + /// - `M`: The metric. + /// - `C`: The criteria function for partitioning. + fn partition, M: Metric, C: Fn(&Self) -> bool>( + &mut self, + data: &D, + metric: &M, + criteria: &C, + seed: Option, + ) { if !self.is_singleton() && criteria(self) { ftlog::trace!( - "Starting `partition` of a cluster at depth {}, with {} instances.", + "Starting `partition` of a cluster at depth {}, with {} items.", self.depth(), self.cardinality() ); - // Find the extrema. - let extrema = self.find_extrema(data); - // Split the instances by the extrema. - let (child_stacks, child_extents) = self.split_by_extrema(data, &extrema); - // Increment the depth for the children. - let depth = self.depth() + 1; - // Create the children. - let (children, other) = child_stacks - .into_iter() - .map(|child_indices| { - let mut child = Self::new(data, &child_indices, depth, seed); - let arg_r = child.arg_radial(); - child.partition(data, criteria, seed); - let child_indices = child.indices().collect::>(); - (Box::new(child), (arg_r, child_indices)) - }) - .unzip::<_, _, Vec<_>, Vec<_>>(); - let (arg_extrema, child_stacks) = other.into_iter().unzip::<_, _, Vec<_>, Vec<_>>(); - - // Recombine the indices. - self.set_indices(child_stacks.into_iter().flatten().collect()); - - // Combine the children with the extrema and extents. - let children = arg_extrema - .into_iter() - .zip(child_extents) - .zip(children) - .map(|((a, b), c)| (a, b, c)) - .collect(); - // Update the `Cluster`'s children. + let mut children = self.partition_once(data, metric, seed); + for child in &mut children { + child.partition(data, metric, criteria, seed); + } + let indices = children.iter().flat_map(|c| c.indices()).collect::>(); + self.set_indices(&indices); self.set_children(children); ftlog::trace!( - "Finished `partition` of a cluster at depth {}, with {} instances.", + "Finished `partition` of a cluster at depth {}, with {} items.", self.depth(), self.cardinality() ); @@ -192,41 +305,127 @@ pub trait Partition>: Cluster { /// Partitions the leaf `Cluster`s the tree even further using a different /// criteria. - fn partition_further bool>(&mut self, data: &D, criteria: &C, seed: Option) { + /// + /// # Type Parameters + /// + /// - `I`: The items in the dataset. + /// - `D`: The dataset. + /// - `M`: The metric. + /// - `C`: The criteria function for partitioning. + fn partition_further, M: Metric, C: Fn(&Self) -> bool>( + &mut self, + data: &D, + metric: &M, + criteria: &C, + seed: Option, + ) { self.leaves_mut() .into_iter() - .for_each(|child| child.partition(data, criteria, seed)); + .for_each(|child| child.partition(data, metric, criteria, seed)); } } /// `Cluster`s that use and provide parallelized methods. 
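A matching sketch for the parallel path, under the same assumed types as the serial examples; the call shape is identical, and the rayon parallelism supplied by the `Par*` traits requires the criteria closure to be `Send + Sync`:

    fn build_par_tree(data: &FlatVec<f64, usize>, metric: &AbsDiff) -> Ball<f64> {
        // A capture-free closure is automatically Send + Sync.
        let criteria = |c: &Ball<f64>| c.cardinality() > 1;
        Ball::par_new_tree(data, metric, &criteria, Some(42))
    }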
+/// +/// # Examples +/// +/// See: +/// +/// - [`Ball`](crate::core::cluster::Ball) +/// - [`BalancedBall`](crate::core::cluster::BalancedBall) #[allow(clippy::module_name_repetitions)] -pub trait ParPartition>: - ParCluster + Partition -{ - /// Parallelized version of the `new` method. - fn par_new(data: &D, indices: &[usize], depth: usize, seed: Option) -> Self; +pub trait ParPartition: ParCluster + Partition { + /// Parallelized version of [`Partition::new`](crate::core::cluster::Partition::new). + /// + /// # Errors + /// + /// See [`Partition::new`](crate::core::cluster::Partition::new). + fn par_new, M: ParMetric>( + data: &D, + metric: &M, + indices: &[usize], + depth: usize, + seed: Option, + ) -> Result; - /// Parallelized version of the `find_extrema` method. - fn par_find_extrema(&self, data: &D) -> Vec; + /// Parallelized version of [`Partition::find_extrema`](crate::core::cluster::Partition::find_extrema). + fn par_find_extrema, M: ParMetric>( + &mut self, + data: &D, + metric: &M, + ) -> Vec; - /// Parallelized version of the `new_tree` method. - fn par_new_tree bool) + Send + Sync>(data: &D, criteria: &C, seed: Option) -> Self { + /// Parallelized version of [`Partition::new_tree`](crate::core::cluster::Partition::new_tree). + fn par_new_tree, M: ParMetric, C: (Fn(&Self) -> bool) + Send + Sync>( + data: &D, + metric: &M, + criteria: &C, + seed: Option, + ) -> Self { let indices = (0..data.cardinality()).collect::>(); - let mut root = Self::par_new(data, &indices, 0, seed); - root.par_partition(data, criteria, seed); + let mut root = Self::par_new(data, metric, &indices, 0, seed) + .unwrap_or_else(|e| unreachable!("We ensured that the indices are not empty: {e}")); + root.par_partition(data, metric, criteria, seed); + root + } + + /// Parallelized version of [`Partition::new_tree_iterative`](crate::core::cluster::Partition::new_tree_iterative). + fn par_new_tree_iterative< + I: Send + Sync, + D: ParDataset, + M: ParMetric, + C: (Fn(&Self) -> bool) + Send + Sync, + >( + data: &D, + metric: &M, + criteria: &C, + seed: Option, + depth_stride: usize, + ) -> Self { + let mut target_depth = depth_stride; + let stride_criteria = |c: &Self| c.depth() < target_depth && criteria(c); + + let mut root = Self::par_new_tree(data, metric, &stride_criteria, seed); + + let mut stride_leaves = root + .leaves_mut() + .into_par_iter() + .filter(|c| (c.depth() == depth_stride) && criteria(c)) + .collect::>(); + while !stride_leaves.is_empty() { + target_depth += depth_stride; + let stride_criteria = |c: &Self| c.depth() < target_depth && criteria(c); + stride_leaves + .into_par_iter() + .for_each(|c| c.par_partition(data, metric, &stride_criteria, seed)); + stride_leaves = root + .leaves_mut() + .into_par_iter() + .filter(|c| (c.depth() == target_depth) && criteria(c)) + .collect::>(); + } + root } - /// Parallelized version of the `partition_once` method. - fn par_split_by_extrema(&self, data: &D, extrema: &[usize]) -> (Vec>, Vec) { - let instances = self.indices().filter(|i| !extrema.contains(i)).collect::>(); - // Find the distances from each extremum to each instance. - let extremal_distances = ParDataset::par_many_to_many(data, extrema, &instances); + /// Parallelized version of [`Partition::split_by_extrema`](crate::core::cluster::Partition::split_by_extrema). 
+ fn par_split_by_extrema, M: ParMetric>( + &self, + data: &D, + metric: &M, + extrema: &[usize], + ) -> (Vec>, Vec) { + let items = self + .indices() + .into_iter() + .filter(|i| !extrema.contains(i)) + .collect::>(); + // Find the distances from each extremum to each item. + let extremal_distances = data.par_many_to_many(extrema, &items, metric).collect::>(); // Convert the distances from row-major to column-major. let distances = { - let mut distances = vec![vec![U::ZERO; extrema.len()]; instances.len()]; + let mut distances = vec![vec![T::ZERO; extrema.len()]; items.len()]; for (r, row) in extremal_distances.into_iter().enumerate() { for (c, (_, _, d)) in row.into_iter().enumerate() { distances[c][r] = d; @@ -236,91 +435,106 @@ pub trait ParPartition>: }; // Initialize a child stack for each extremum. - let mut child_stacks = extrema.iter().map(|&p| vec![(p, U::ZERO)]).collect::>(); + let mut child_stacks = extrema.iter().map(|&p| vec![(p, T::ZERO)]).collect::>(); - // For each extremum, find the instances that are closer to it than to + // For each extremum, find the items that are closer to it than to // any other extremum. - for (col, instance) in distances.into_iter().zip(instances) { - let (e_index, d) = col - .into_iter() - .enumerate() - .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Greater)) - .unwrap_or_else(|| unreachable!("Cannot find the minimum distance")); - child_stacks[e_index].push((instance, d)); + for (col, item) in distances.into_iter().zip(items) { + let (e_index, d) = utils::arg_min(&col).unwrap_or_else(|| unreachable!("Cannot find the minimum distance")); + child_stacks[e_index].push((item, d)); } - // Find the maximum distance for each child and return the instances. + // Find the maximum distance for each child and return the items. child_stacks .into_par_iter() .map(|stack| { let (indices, distances) = stack.into_iter().unzip::<_, _, Vec<_>, Vec<_>>(); let extent = distances .into_iter() - .max_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Less)) + .max_by(Number::total_cmp) .unwrap_or_else(|| unreachable!("Cannot find the maximum distance")); (indices, extent) }) .unzip() } - /// Parallelized version of the `partition` method. - fn par_partition bool) + Send + Sync>(&mut self, data: &D, criteria: &C, seed: Option) { + /// Parallelized version of [`Partition::partition_once`](crate::core::cluster::Partition::partition_once). + fn par_partition_once, M: ParMetric>( + &mut self, + data: &D, + metric: &M, + seed: Option, + ) -> Vec> { + // Find the extrema. + let extrema = self.par_find_extrema(data, metric); + // Split the items by the extrema. + let (child_stacks, child_extents) = self.par_split_by_extrema(data, metric, &extrema); + // Increment the depth for the children. + let depth = self.depth() + 1; + + // Create the children. + child_stacks + .into_par_iter() + .map(|child_indices| { + Self::par_new(data, metric, &child_indices, depth, seed) + .unwrap_or_else(|e| unreachable!("We ensured that the indices are not empty: {e}")) + }) + .map(Box::new) + .zip(extrema) + .zip(child_extents) + .map(|((mut c, i), d)| { + c.add_extent(i, d); + c + }) + .collect() + } + + /// Parallelized version of [`Partition::partition`](crate::core::cluster::Partition::partition). 
+ fn par_partition, M: ParMetric, C: (Fn(&Self) -> bool) + Send + Sync>( + &mut self, + data: &D, + metric: &M, + criteria: &C, + seed: Option, + ) { if !self.is_singleton() && criteria(self) { ftlog::trace!( - "Starting `par_partition` of a cluster at depth {}, with {} instances.", + "Starting `par_partition` of a cluster at depth {}, with {} items.", self.depth(), self.cardinality() ); - // Find the extrema. - let extrema = self.par_find_extrema(data); - // Split the instances by the extrema. - let (child_stacks, child_extents) = self.par_split_by_extrema(data, &extrema); - // Increment the depth for the children. - let depth = self.depth() + 1; - - // Create the children. - let (children, other) = child_stacks - .into_par_iter() - .map(|child_indices| { - let e_index = child_indices[0]; - let mut child = Self::par_new(data, &child_indices, depth, seed); - child.par_partition(data, criteria, seed); - let child_indices = child.indices().collect::>(); - (Box::new(child), (e_index, child_indices)) - }) - .unzip::<_, _, Vec<_>, Vec<_>>(); - let (arg_extrema, child_stacks) = other.into_iter().unzip::<_, _, Vec<_>, Vec<_>>(); - // Recombine the indices from the children. - self.set_indices(child_stacks.into_iter().flatten().collect()); - - // Combine the children with the extrema and extents. - let children = arg_extrema - .into_iter() - .zip(child_extents) - .zip(children) - .map(|((a, b), c)| (a, b, c)) - .collect(); - // Update the `Cluster`'s children. + let mut children = self.par_partition_once(data, metric, seed); + children.par_iter_mut().for_each(|child| { + child.par_partition(data, metric, criteria, seed); + }); + let indices = children.iter().flat_map(|c| c.indices()).collect::>(); + self.set_indices(&indices); self.set_children(children); ftlog::trace!( - "Finished `par_partition` of a cluster at depth {}, with {} instances.", + "Finished `par_partition` of a cluster at depth {}, with {} items.", self.depth(), self.cardinality() ); }; } - /// Parallelized version of the `partition_further` method. - fn par_partition_further bool) + Send + Sync>( + /// Parallelized version of [`Partition::partition_further`](crate::core::cluster::Partition::partition_further). + fn par_partition_further< + I: Send + Sync, + D: ParDataset, + M: ParMetric, + C: (Fn(&Self) -> bool) + Send + Sync, + >( &mut self, data: &D, + metric: &M, criteria: &C, seed: Option, ) { self.leaves_mut() .into_par_iter() - .for_each(|child| child.par_partition(data, criteria, seed)); + .for_each(|child| child.par_partition(data, metric, criteria, seed)); } } diff --git a/crates/abd-clam/src/core/dataset/associates_metadata.rs b/crates/abd-clam/src/core/dataset/associates_metadata.rs new file mode 100644 index 000000000..02480ddb9 --- /dev/null +++ b/crates/abd-clam/src/core/dataset/associates_metadata.rs @@ -0,0 +1,53 @@ +//! An extension of the `Dataset` trait that provides methods for working with +//! metadata associated with items in a dataset. + +use super::Dataset; + +/// A trait that extends the `Dataset` trait with methods for working with +/// metadata associated with items in a dataset. +/// +/// Each item in the dataset should be associated with a piece of metadata. +/// +/// # Type parameters +/// +/// - `I`: The items in the dataset. +/// - `Me`: The metadata associated with each item in the dataset. +pub trait AssociatesMetadata: Dataset { + /// Returns the all metadata associated with the items in the dataset. 
+    fn metadata(&self) -> &[Me];
+
+    /// Returns the metadata associated with the item at the given index.
+    fn metadata_at(&self, index: usize) -> &Me;
+}
+
+/// An extension of the `AssociatesMetadata` trait that provides methods for
+/// changing the metadata associated with items in a dataset.
+///
+/// # Type parameters
+///
+/// - `I`: The items in the dataset.
+/// - `Me`: The metadata associated with each item in the dataset.
+/// - `Met`: The metadata that can be associated after transformation.
+/// - `D`: The type of the dataset after the transformation.
+#[allow(clippy::module_name_repetitions)]
+pub trait AssociatesMetadataMut<I, Me, Met: Clone, D: AssociatesMetadata<I, Met>>: AssociatesMetadata<I, Me> {
+    /// Returns all the metadata associated with the items in the dataset,
+    /// mutably.
+    fn metadata_mut(&mut self) -> &mut [Me];
+
+    /// Returns the metadata associated with the item at the given index,
+    /// mutably.
+    fn metadata_at_mut(&mut self, index: usize) -> &mut Me;
+
+    /// Changes all metadata associated with the items in the dataset.
+    ///
+    /// # Errors
+    ///
+    /// - If the number of metadata items is not equal to the cardinality of the
+    ///   dataset.
+    fn with_metadata(self, metadata: &[Met]) -> Result<D, String>;
+
+    /// Applies a transformation to the metadata associated with the items in
+    /// the dataset.
+    fn transform_metadata<F: Fn(&Me) -> Met>(self, f: F) -> D;
+}
diff --git a/crates/abd-clam/src/core/dataset/flat_vec.rs b/crates/abd-clam/src/core/dataset/flat_vec.rs
index ae613fb91..cf0ec64f5 100644
--- a/crates/abd-clam/src/core/dataset/flat_vec.rs
+++ b/crates/abd-clam/src/core/dataset/flat_vec.rs
@@ -1,60 +1,63 @@
-//! A `FlatVec` is a dataset that is stored as a flat vector.
+//! A `FlatVec` is a `Dataset` in which the items are stored in a vector.
 
-use distances::Number;
-use serde::{Deserialize, Serialize};
+use rand::Rng;
 
-use super::{metric_space::ParMetricSpace, Dataset, Metric, MetricSpace, ParDataset, Permutable};
+use super::{AssociatesMetadata, AssociatesMetadataMut, Dataset, ParDataset, Permutable};
 
-/// A `FlatVec` is a dataset that is stored as a flat vector.
-///
-/// The instances are stored as a flat vector.
+/// A `FlatVec` is a `Dataset` in which the items are stored in a vector.
 ///
 /// # Type Parameters
 ///
-/// - `I`: The type of the instances in the dataset.
-/// - `U`: The type of the distance values.
-/// - `M`: The type of the metadata associated with the instances.
-#[derive(Clone, Serialize, Deserialize)]
-pub struct FlatVec<I, U, M> {
-    /// The metric space of the dataset.
-    #[serde(skip)]
-    pub(crate) metric: Metric<I, U>,
-    /// The instances in the dataset.
-    pub(crate) instances: Vec<I>,
+/// - `I`: The items in the dataset.
+/// - `Me`: The metadata associated with the items.
+#[derive(Clone)]
+#[cfg_attr(
+    feature = "disk-io",
+    derive(bitcode::Encode, bitcode::Decode, serde::Serialize, serde::Deserialize)
+)]
+pub struct FlatVec<I, Me> {
+    /// The items in the dataset.
+    items: Vec<I>,
     /// A hint for the dimensionality of the dataset.
-    pub(crate) dimensionality_hint: (usize, Option<usize>),
-    /// The permutation of the instances.
-    pub(crate) permutation: Vec<usize>,
-    /// The metadata associated with the instances.
-    pub(crate) metadata: Vec<M>,
+    dimensionality_hint: (usize, Option<usize>),
+    /// The permutation of the items.
+    permutation: Vec<usize>,
+    /// The metadata associated with the items.
+    pub(crate) metadata: Vec<Me>,
     /// The name of the dataset.
-    pub(crate) name: String,
+    name: String,
 }
 
-impl<I, U: Number> FlatVec<I, U, usize> {
+impl<I> FlatVec<I, usize> {
     /// Creates a new `FlatVec`.
     ///
-    /// # Parameters
+    /// The metadata is set to the indices of the items.
+    ///
+    /// # Errors
     ///
-    /// - `instances`: The instances in the dataset.
-    /// - `metric`: The metric space of the dataset.
+    /// * If the `items` are empty.
     ///
-    /// # Returns
+    /// # Example
     ///
-    /// A new `FlatVec`.
+    /// ```rust
+    /// use abd_clam::{Dataset, FlatVec};
     ///
-    /// # Errors
+    /// let items = vec![1, 2, 3];
+    /// let data = FlatVec::new(items).unwrap();
+    /// assert_eq!(data.cardinality(), 3);
     ///
-    /// * If the instances are empty.
-    pub fn new(instances: Vec<I>, metric: Metric<I, U>) -> Result<Self, String> {
-        if instances.is_empty() {
-            Err("The instances are empty.".to_string())
+    /// let items: Vec<usize> = vec![];
+    /// let data = FlatVec::new(items);
+    /// assert!(data.is_err());
+    /// ```
+    pub fn new(items: Vec<I>) -> Result<Self, String> {
+        if items.is_empty() {
+            Err("The items are empty.".to_string())
         } else {
-            let permutation = (0..instances.len()).collect::<Vec<_>>();
+            let permutation = (0..items.len()).collect::<Vec<_>>();
             let metadata = permutation.clone();
             Ok(Self {
-                metric,
-                instances,
+                items,
                 dimensionality_hint: (0, None),
                 permutation,
                 metadata,
@@ -64,67 +67,69 @@ impl<I, U: Number> FlatVec<I, U, usize> {
         }
     }
 }
 
-impl<T: Number, U: Number> FlatVec<Vec<T>, U, usize> {
+impl<T> FlatVec<Vec<T>, usize> {
     /// Creates a new `FlatVec` from tabular data.
     ///
-    /// The data are assumed to be a 2d array where each row is an instance.
-    /// The dimensionality of the dataset is set to the number of columns in the data.
+    /// The metadata is set to the indices of the items.
     ///
-    /// # Parameters
+    /// The items are assumed to all have the same length. This length is used
+    /// as the dimensionality of the dataset.
     ///
-    /// - `instances`: The instances in the dataset.
-    /// - `metric`: The metric space of the dataset.
+    /// # Errors
     ///
-    /// # Returns
+    /// * If the items are empty.
+    /// * If the items do not all have the same length.
     ///
-    /// A new `FlatVec`.
+    /// # Example
     ///
-    /// # Errors
+    /// ```rust
+    /// use abd_clam::{Dataset, FlatVec};
     ///
-    /// * If the instances are empty.
-    pub fn new_array(instances: Vec<Vec<T>>, metric: Metric<Vec<T>, U>) -> Result<Self, String> {
-        if instances.is_empty() {
-            Err("The instances are empty.".to_string())
+    /// let items = vec![vec![1, 2], vec![3, 4]];
+    /// let data = FlatVec::new_array(items).unwrap();
+    /// assert_eq!(data.cardinality(), 2);
+    /// ```
+    pub fn new_array(items: Vec<Vec<T>>) -> Result<Self, String> {
+        if items.is_empty() {
+            Err("The items are empty.".to_string())
         } else {
-            let dimensionality = instances[0].len();
-            let permutation = (0..instances.len()).collect::<Vec<_>>();
-            let metadata = permutation.clone();
-            Ok(Self {
-                metric,
-                instances,
-                dimensionality_hint: (dimensionality, Some(dimensionality)),
-                permutation,
-                metadata,
-                name: "Unknown FlatVec".to_string(),
-            })
+            let (min_len, max_len) = items.iter().fold((usize::MAX, 0), |(min, max), item| {
+                (min.min(item.len()), max.max(item.len()))
+            });
+            if min_len == max_len {
+                let permutation = (0..items.len()).collect::<Vec<_>>();
+                let metadata = permutation.clone();
+                Ok(Self {
+                    items,
+                    dimensionality_hint: (min_len, Some(min_len)),
+                    permutation,
+                    metadata,
+                    name: "Unknown FlatVec".to_string(),
+                })
+            } else {
+                Err(format!(
+                    "The items do not all have the same length. Lengths range from {min_len} to {max_len}."
+                ))
+            }
         }
     }
 }
 
-impl<I, U, M> FlatVec<I, U, M> {
-    /// Deconstructs the `FlatVec` into its members.
+impl<I, Me> FlatVec<I, Me> {
+    /// Sets a lower bound for the dimensionality of the dataset.
     ///
-    /// # Returns
+    /// # Example
     ///
-    /// - The `Metric` of the dataset.
-    /// - The instances in the dataset.
-    /// - A hint for the dimensionality of the dataset.
-    /// - The permutation of the instances.
- /// - The metadata associated with the instances. - #[allow(clippy::type_complexity)] - #[must_use] - pub fn deconstruct(self) -> (Metric, Vec, (usize, Option), Vec, Vec, String) { - ( - self.metric, - self.instances, - self.dimensionality_hint, - self.permutation, - self.metadata, - self.name, - ) - } - - /// Sets a lower bound for the dimensionality of the dataset. + /// ```rust + /// use abd_clam::{Dataset, FlatVec}; + /// + /// let items = vec!["hello", "ciao", "classic"]; + /// let data = FlatVec::new(items).unwrap(); + /// assert_eq!(data.dimensionality_hint(), (0, None)); + /// + /// let data = data.with_dim_lower_bound(3); + /// assert_eq!(data.dimensionality_hint(), (3, None)); + /// ``` #[must_use] pub const fn with_dim_lower_bound(mut self, lower_bound: usize) -> Self { self.dimensionality_hint.0 = lower_bound; @@ -132,83 +137,148 @@ impl FlatVec { } /// Sets an upper bound for the dimensionality of the dataset. + /// + /// # Example + /// + /// ```rust + /// use abd_clam::{Dataset, FlatVec}; + /// + /// let items = vec!["hello", "ciao", "classic"]; + /// let data = FlatVec::new(items).unwrap(); + /// assert_eq!(data.dimensionality_hint(), (0, None)); + /// + /// let data = data.with_dim_upper_bound(5); + /// assert_eq!(data.dimensionality_hint(), (0, Some(5))); + /// ``` #[must_use] pub const fn with_dim_upper_bound(mut self, upper_bound: usize) -> Self { self.dimensionality_hint.1 = Some(upper_bound); self } - /// Returns the metadata associated with the instances. - #[must_use] - pub fn metadata(&self) -> &[M] { - &self.metadata - } - - /// Assigns metadata to the instances. + /// Changes the permutation in the dataset without reordering the items. /// - /// # Parameters + /// # Example /// - /// - `metadata`: The metadata to assign to the instances. + /// ```rust + /// use abd_clam::{Dataset, FlatVec, dataset::Permutable}; /// - /// # Returns + /// let items = vec!["hello", "ciao", "classic"]; + /// let data = FlatVec::new(items).unwrap(); + /// assert_eq!(data.permutation(), vec![0, 1, 2]); /// - /// The dataset with the metadata assigned to the instances. + /// let permutation = vec![1, 0, 2]; + /// let data = data.with_permutation(&permutation); + /// assert_eq!(data.permutation(), permutation); + /// ``` + #[must_use] + pub fn with_permutation(mut self, permutation: &[usize]) -> Self { + self.set_permutation(permutation); + self + } + + /// Get the items in the dataset. /// - /// # Errors + /// # Example /// - /// * If the metadata length does not match the number of instances. - pub fn with_metadata(self, mut metadata: Vec) -> Result, String> { - if metadata.len() == self.instances.len() { - metadata.permute(&self.permutation); - Ok(FlatVec { - metric: self.metric, - instances: self.instances, - dimensionality_hint: self.dimensionality_hint, - permutation: self.permutation, - metadata, - name: self.name, - }) - } else { - Err(format!( - "The metadata length does not match the number of instances. {} vs {}", - metadata.len(), - self.instances.len() - )) - } + /// ```rust + /// use abd_clam::{Dataset, FlatVec}; + /// + /// let items = vec![1, 2, 3]; + /// let data = FlatVec::new(items).unwrap(); + /// assert_eq!(data.items(), &[1, 2, 3]); + /// ``` + #[must_use] + pub fn items(&self) -> &[I] { + &self.items } - /// Get the instances in the dataset. + /// Takes the items out of the dataset. 
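+    ///
+    /// A short sketch of the move semantics (no API beyond what is defined
+    /// in this file is assumed):
+    ///
+    /// ```rust,ignore
+    /// use abd_clam::FlatVec;
+    ///
+    /// let data = FlatVec::new(vec![1, 2, 3]).unwrap();
+    /// let items = data.take_items(); // `data` is consumed here.
+    /// assert_eq!(items, vec![1, 2, 3]);
+    /// ```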
#[must_use] - pub fn instances(&self) -> &[I] { - &self.instances + pub fn take_items(self) -> Vec { + self.items } -} -impl MetricSpace for FlatVec { - fn metric(&self) -> &Metric { - &self.metric + /// Transforms the items in the dataset. + /// + /// # Type Parameters + /// + /// - `It`: The transformed items. + /// - `F`: The transformer function. + /// + /// # Example + /// + /// ```rust + /// use abd_clam::{Dataset, FlatVec}; + /// + /// let items = vec![1, 2, 3]; + /// let data = FlatVec::new(items).unwrap(); + /// assert_eq!(data.get(0), &1); + /// assert_eq!(data.get(1), &2); + /// assert_eq!(data.get(2), &3); + /// + /// let f = |x: i32| x * 2; + /// let data = data.transform_items(f); + /// assert_eq!(data.get(0), &2); + /// assert_eq!(data.get(1), &4); + /// assert_eq!(data.get(2), &6); + /// + /// let f = |x: i32| vec![x, x * 2]; + /// let data = data.transform_items(f); + /// assert_eq!(data.get(0), &[2, 4]); + /// assert_eq!(data.get(1), &[4, 8]); + /// assert_eq!(data.get(2), &[6, 12]); + /// + /// let f = |x: Vec| x.into_iter().sum::().to_string(); + /// let data = data.transform_items(f); + /// assert_eq!(data.get(0), "6"); + /// assert_eq!(data.get(1), "12"); + /// assert_eq!(data.get(2), "18"); + /// ``` + pub fn transform_items It>(self, transformer: F) -> FlatVec { + let items = self.items.into_iter().map(transformer).collect(); + FlatVec { + items, + dimensionality_hint: self.dimensionality_hint, + permutation: self.permutation, + metadata: self.metadata, + name: self.name, + } } +} - fn set_metric(&mut self, metric: Metric) { - self.metric = metric; +impl FlatVec { + /// Creates a subsample of the dataset by sampling without replacement. + /// + /// This will inherit `dimensionality_hint` from the original dataset. The + /// permutation will be set to the identity permutation. + #[must_use] + pub fn random_subsample(&self, rng: &mut R, size: usize) -> Self { + let indices = rand::seq::index::sample(rng, self.items.len(), size).into_vec(); + let items = indices.iter().map(|&i| self.items[i].clone()).collect(); + let metadata = indices.iter().map(|&i| self.metadata[i].clone()).collect(); + Self { + items, + dimensionality_hint: self.dimensionality_hint, + permutation: (0..size).collect(), + metadata, + name: self.name.clone(), + } } } -impl ParMetricSpace for FlatVec {} - -impl Dataset for FlatVec { +impl Dataset for FlatVec { fn name(&self) -> &str { &self.name } - /// Changes the name of the dataset. 
fn with_name(mut self, name: &str) -> Self { self.name = name.to_string(); self } fn cardinality(&self) -> usize { - self.instances.len() + self.items.len() } fn dimensionality_hint(&self) -> (usize, Option) { @@ -216,13 +286,64 @@ impl Dataset for FlatVec { } fn get(&self, index: usize) -> &I { - &self.instances[index] + &self.items[index] } } -impl ParDataset for FlatVec {} +impl ParDataset for FlatVec {} + +impl AssociatesMetadata for FlatVec { + fn metadata(&self) -> &[Me] { + &self.metadata + } + + fn metadata_at(&self, index: usize) -> &Me { + &self.metadata[index] + } +} + +impl AssociatesMetadataMut> for FlatVec { + fn metadata_mut(&mut self) -> &mut [Me] { + &mut self.metadata + } + + fn metadata_at_mut(&mut self, index: usize) -> &mut Me { + &mut self.metadata[index] + } + + fn with_metadata(self, metadata: &[Met]) -> Result, String> { + if metadata.len() == self.items.len() { + let mut metadata = metadata.to_vec(); + metadata.permute(&self.permutation); + Ok(FlatVec { + items: self.items, + dimensionality_hint: self.dimensionality_hint, + permutation: self.permutation, + metadata, + name: self.name, + }) + } else { + Err(format!( + "The metadata length does not match the number of items. {} vs {}", + metadata.len(), + self.items.len() + )) + } + } + + fn transform_metadata Met>(self, f: F) -> FlatVec { + let metadata = self.metadata.iter().map(f).collect(); + FlatVec { + items: self.items, + dimensionality_hint: self.dimensionality_hint, + permutation: self.permutation, + metadata, + name: self.name, + } + } +} -impl Permutable for FlatVec { +impl Permutable for FlatVec { fn permutation(&self) -> Vec { self.permutation.clone() } @@ -232,71 +353,94 @@ impl Permutable for FlatVec { } fn swap_two(&mut self, i: usize, j: usize) { - self.instances.swap(i, j); + self.items.swap(i, j); self.permutation.swap(i, j); self.metadata.swap(i, j); } } -#[cfg(feature = "ndarray-bindings")] -impl FlatVec, U, usize> { - /// Reads a `VecDataset` from a `.npy` file. +#[cfg(feature = "disk-io")] +impl super::DatasetIO + for FlatVec +{ +} + +#[cfg(feature = "disk-io")] +impl + super::ParDatasetIO for FlatVec +{ +} + +#[cfg(feature = "disk-io")] +impl FlatVec, usize> { + /// Reads a `FlatVec` from a `.npy` file. + /// + /// The name of the dataset is set to the name of the file without the + /// extension. /// /// # Parameters /// /// - `path`: The path to the `.npy` file. - /// - `metric`: The metric space of the dataset. - /// - `name`: The name of the dataset. If `None`, the name of the file is used. /// /// # Errors /// /// * If the path is invalid. /// * If the file cannot be read. - /// * If the instances cannot be converted to a `Vec`. - pub fn read_npy>(path: P, metric: Metric, U>) -> Result { - let arr: ndarray::Array2 = ndarray_npy::read_npy(path).map_err(|e| e.to_string())?; - let instances = arr.axis_iter(ndarray::Axis(0)).map(|row| row.to_vec()).collect(); - Self::new_array(instances, metric) + pub fn read_npy>(path: &P) -> Result { + let name = path + .as_ref() + .file_stem() + .and_then(|s| s.to_str()) + .unwrap_or("") + .to_string(); + let arr: ndarray::Array2 = ndarray_npy::read_npy(path) + .map_err(|e| format!("Could not read npy file: {e}, path: {:?}", path.as_ref()))?; + let items = arr.axis_iter(ndarray::Axis(0)).map(|row| row.to_vec()).collect(); + + Self::new_array(items).map(|data| data.with_name(&name)) } } -#[cfg(feature = "ndarray-bindings")] -impl FlatVec, U, usize> { - /// Writes the `VecDataset` to a `.npy` file in the given directory. 
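+/// A sketch of a `.npy` round trip through `write_npy` and `read_npy`
+/// (assumes the `disk-io` feature and a writable path; not taken from this
+/// patch's tests):
+///
+/// ```rust,ignore
+/// use abd_clam::{Dataset, FlatVec};
+///
+/// let data = FlatVec::new_array(vec![vec![1.0_f32, 2.0], vec![3.0, 4.0]]).unwrap();
+/// data.write_npy(&"data.npy").unwrap();
+/// let data2: FlatVec<Vec<f32>, usize> = FlatVec::read_npy(&"data.npy").unwrap();
+/// assert_eq!(data2.cardinality(), 2);
+/// ```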
+#[cfg(feature = "disk-io")] +impl FlatVec, Me> { + /// Writes the `FlatVec` to a `.npy` file in the given directory. /// /// # Parameters /// - /// - `dir`: The directory in which to write the `.npy` file. - /// - `name`: The name of the file. If `None`, the name of the dataset is used. - /// - /// # Returns - /// - /// The path to the written file. + /// - `path`: The path in which to write the dataset. /// /// # Errors /// /// * If the path is invalid. /// * If the file cannot be created. - /// * If the instances cannot be converted to an `Array2`. + /// * If the items cannot be converted to an `Array2`. /// * If the `Array2` cannot be written. - pub fn write_npy>(&self, dir: P, name: &str) -> Result { - let path = dir.as_ref().join(name); - let shape = (self.instances.len(), self.dimensionality_hint.0); - let v = self.instances.iter().flat_map(|row| row.iter().copied()).collect(); - let arr: ndarray::Array2 = ndarray::Array2::from_shape_vec(shape, v).map_err(|e| e.to_string())?; - ndarray_npy::write_npy(&path, &arr).map_err(|e| e.to_string())?; - Ok(path) + pub fn write_npy>(&self, path: &P) -> Result<(), String> { + let (min_dim, max_dim) = self.dimensionality_hint; + let max_dim = max_dim.ok_or_else(|| "Cannot write FlatVec with unknown dimensionality to npy".to_string())?; + if min_dim != max_dim { + return Err("Cannot write FlatVec with variable dimensionality to npy".to_string()); + } + + let shape = (self.items.len(), min_dim); + let v = self.items.iter().flat_map(|row| row.iter().copied()).collect(); + let arr = ndarray::Array2::::from_shape_vec(shape, v) + .map_err(|e| format!("Could not convert items to Array2: {e}"))?; + ndarray_npy::write_npy(path, &arr) + .map_err(|e| e.to_string()) + .map_err(|e| format!("Could not write npy file: {e}, path: {:?}", path.as_ref()))?; + + Ok(()) } } -#[cfg(feature = "csv")] -impl FlatVec, U, usize> { - /// Reads a `VecDataset` from a `.csv` file. +#[cfg(feature = "disk-io")] +impl FlatVec, usize> { + /// Reads a `FlatVec` from a `.csv` file. /// /// # Parameters /// /// - `path`: The path to the `.csv` file. - /// - `metric`: The metric space of the dataset. /// - `delimiter`: The delimiter used in the `.csv` file. /// - `has_headers`: Whether to treat the first row as headers. /// @@ -304,11 +448,13 @@ impl FlatVec, U, usize> { /// /// * If the path is invalid. /// * If the file cannot be read. - /// * If the types in the file are not convertible to `T` using string parsing. - /// * If the instances cannot be converted to a `Vec`. - pub fn read_csv>(path: P, metric: Metric, U>) -> Result { - let mut reader = csv::ReaderBuilder::new().from_path(path).map_err(|e| e.to_string())?; - let instances = reader + /// * If the types in the file are not parsable as `T`. + /// * If the items cannot be converted to a `Vec`. + pub fn read_csv>(path: &P) -> Result { + let mut reader = csv::ReaderBuilder::new() + .from_path(path) + .map_err(|e| format!("Could not start reading csv file: {e}, path: {:?}", path.as_ref()))?; + let items = reader .records() .map(|record| { record.map_err(|e| e.to_string()).and_then(|record| { @@ -323,13 +469,13 @@ impl FlatVec, U, usize> { }) }) .collect::, _>>()?; - Self::new_array(instances, metric) + Self::new_array(items) } } -#[cfg(feature = "csv")] -impl FlatVec, U, M> { - /// Writes the `VecDataset` to a `.csv` file with the given path. +#[cfg(feature = "disk-io")] +impl FlatVec, M> { + /// Writes the `FlatVec` to a `.csv` file with the given path. 
/// /// # Parameters /// @@ -340,267 +486,16 @@ impl FlatVec, U, M> { /// /// * If the path is invalid. /// * If the file cannot be created. - pub fn to_csv>(&self, path: P, delimiter: u8) -> Result<(), String> { + pub fn to_csv>(&self, path: &P, delimiter: u8) -> Result<(), String> { let mut writer = csv::WriterBuilder::new() .delimiter(delimiter) .from_path(path) - .map_err(|e| e.to_string())?; - for instance in &self.instances { + .map_err(|e| format!("Could not start csv file: {e}, path: {:?}", path.as_ref()))?; + for item in &self.items { writer - .write_record(instance.iter().map(T::to_string)) - .map_err(|e| e.to_string())?; - } - Ok(()) - } -} - -/// Tests for the `FlatVec` struct. -#[cfg(test)] -mod tests { - use crate::{dataset::ParDataset, Dataset, Metric, MetricSpace, Permutable}; - - use super::FlatVec; - - #[test] - fn creation() -> Result<(), String> { - let instances = vec![vec![1, 2], vec![3, 4], vec![5, 6]]; - let distance_function = |a: &Vec, b: &Vec| distances::vectors::manhattan(a, b); - let metric = Metric::new(distance_function, false); - - let dataset = FlatVec::new(instances.clone(), metric.clone())?; - assert_eq!(dataset.cardinality(), 3); - assert_eq!(dataset.dimensionality_hint(), (0, None)); - - let dataset = FlatVec::new_array(instances, metric.clone())?; - assert_eq!(dataset.cardinality(), 3); - assert_eq!(dataset.dimensionality_hint(), (2, Some(2))); - - Ok(()) - } - - #[cfg(feature = "ndarray-bindings")] - #[test] - fn npy_io() -> Result<(), String> { - let instances = vec![vec![1, 2], vec![3, 4], vec![5, 6]]; - let distance_function = |a: &Vec, b: &Vec| distances::vectors::manhattan(a, b); - let metric = Metric::new(distance_function, false); - let dataset = FlatVec::new_array(instances, metric.clone())?; - - let tmp_dir = tempdir::TempDir::new("testing").map_err(|e| e.to_string())?; - let path = dataset.write_npy(&tmp_dir, "test.npy")?; - - let new_dataset = FlatVec::read_npy(&path, metric.clone())?; - assert_eq!(new_dataset.cardinality(), 3); - assert_eq!(new_dataset.dimensionality_hint(), (2, Some(2))); - for i in 0..dataset.cardinality() { - assert_eq!(dataset.get(i), new_dataset.get(i)); - } - - let path = dataset.write_npy(&tmp_dir, "test-test.npy")?; - assert_eq!(path.file_name().unwrap().to_str().unwrap(), "test-test.npy"); - - let new_dataset = FlatVec::read_npy(&path, metric.clone())?; - assert_eq!(new_dataset.cardinality(), 3); - assert_eq!(new_dataset.dimensionality_hint(), (2, Some(2))); - for i in 0..dataset.cardinality() { - assert_eq!(dataset.get(i), new_dataset.get(i)); - } - - let new_dataset = FlatVec::read_npy(&path, metric)?; - assert_eq!(new_dataset.cardinality(), 3); - assert_eq!(new_dataset.dimensionality_hint(), (2, Some(2))); - for i in 0..dataset.cardinality() { - assert_eq!(dataset.get(i), new_dataset.get(i)); - } - - Ok(()) - } - - #[test] - fn metricity() -> Result<(), String> { - let instances = vec![vec![1, 2], vec![3, 4], vec![5, 6]]; - let distance_function = |a: &Vec, b: &Vec| distances::vectors::manhattan(a, b); - let metric = Metric::new(distance_function, false); - let dataset = FlatVec::new_array(instances, metric)?; - - assert_eq!(dataset.get(0), &vec![1, 2]); - assert_eq!(dataset.get(1), &vec![3, 4]); - assert_eq!(dataset.get(2), &vec![5, 6]); - - assert_eq!(Dataset::one_to_one(&dataset, 0, 1), 4); - assert_eq!(Dataset::one_to_one(&dataset, 1, 2), 4); - assert_eq!(Dataset::one_to_one(&dataset, 2, 0), 8); - assert_eq!( - Dataset::one_to_many(&dataset, 0, &[0, 1, 2]), - vec![(0, 0), (1, 4), (2, 8)] - ); - assert_eq!( 
- Dataset::many_to_many(&dataset, &[0, 1], &[1, 2]), - vec![vec![(0, 1, 4), (0, 2, 8)], vec![(1, 1, 0), (1, 2, 4)]] - ); - - assert_eq!(dataset.query_to_one(&vec![0, 0], 0), 3); - assert_eq!(dataset.query_to_one(&vec![0, 0], 1), 7); - assert_eq!(dataset.query_to_one(&vec![0, 0], 2), 11); - assert_eq!( - dataset.query_to_many(&vec![0, 0], &[0, 1, 2]), - vec![(0, 3), (1, 7), (2, 11)] - ); - - Ok(()) - } - - #[test] - fn permutations() -> Result<(), String> { - struct SwapTracker { - data: FlatVec, i32, usize>, - count: usize, + .write_record(item.iter().map(T::to_string)) + .map_err(|e| format!("Could not write record to csv: {e}"))?; } - - impl MetricSpace, i32> for SwapTracker { - fn metric(&self) -> &Metric, i32> { - &self.data.metric - } - - fn set_metric(&mut self, metric: Metric, i32>) { - self.data.metric = metric; - } - } - - impl Dataset, i32> for SwapTracker { - fn name(&self) -> &str { - self.data.name() - } - - fn with_name(mut self, name: &str) -> Self { - self.data = self.data.with_name(name); - self - } - - fn cardinality(&self) -> usize { - self.data.cardinality() - } - - fn dimensionality_hint(&self) -> (usize, Option) { - self.data.dimensionality_hint() - } - - fn get(&self, index: usize) -> &Vec { - self.data.get(index) - } - } - - impl Permutable for SwapTracker { - fn permutation(&self) -> Vec { - self.data.permutation() - } - - fn set_permutation(&mut self, permutation: &[usize]) { - self.data.set_permutation(permutation); - } - - fn swap_two(&mut self, i: usize, j: usize) { - self.data.swap_two(i, j); - self.count += 1; - } - } - - let instances = vec![ - vec![1, 2], - vec![3, 4], - vec![5, 6], - vec![7, 8], - vec![9, 10], - vec![11, 12], - ]; - let distance_function = |a: &Vec, b: &Vec| distances::vectors::manhattan(a, b); - let metric = Metric::new(distance_function, false); - let data = FlatVec::new_array(instances.clone(), metric.clone())?; - let mut swap_tracker = SwapTracker { data, count: 0 }; - - swap_tracker.swap_two(0, 2); - assert_eq!(swap_tracker.permutation(), &[2, 1, 0, 3, 4, 5]); - assert_eq!(swap_tracker.count, 1); - for (i, &j) in swap_tracker.permutation().iter().enumerate() { - assert_eq!(swap_tracker.get(i), &instances[j]); - } - - swap_tracker.swap_two(0, 4); - assert_eq!(swap_tracker.permutation(), &[4, 1, 0, 3, 2, 5]); - assert_eq!(swap_tracker.count, 2); - for (i, &j) in swap_tracker.permutation().iter().enumerate() { - assert_eq!(swap_tracker.get(i), &instances[j]); - } - - let data = FlatVec::new_array(instances.clone(), metric)?; - let mut data = SwapTracker { data, count: 0 }; - let permutation = vec![2, 1, 0, 5, 4, 3]; - data.permute(&permutation); - assert_eq!(data.permutation(), permutation); - assert_eq!(data.count, 2); - for (i, &j) in data.permutation().iter().enumerate() { - assert_eq!(data.get(i), &instances[j]); - } - - Ok(()) - } - - #[test] - fn linear_search() -> Result<(), String> { - let instances = vec![ - vec![1, 2], - vec![3, 4], - vec![5, 6], - vec![7, 8], - vec![9, 10], - vec![11, 12], - ]; - let distance_function = |a: &Vec, b: &Vec| distances::vectors::manhattan(a, b); - let metric = Metric::new(distance_function, false); - - let dataset = FlatVec::new_array(instances, metric)?; - let query = vec![3, 3]; // distances: [3, 1, 5, 9, 13, 17] - - let mut result = dataset.knn(&query, 2); - result.sort_unstable_by_key(|x| x.0); - assert_eq!(result, vec![(0, 3), (1, 1)]); - - let mut result = dataset.rnn(&query, 5); - result.sort_unstable_by_key(|x| x.0); - assert_eq!(result, vec![(0, 3), (1, 1), (2, 5)]); - - let mut result = 
dataset.par_knn(&query, 2); - result.sort_unstable_by_key(|x| x.0); - assert_eq!(result, vec![(0, 3), (1, 1)]); - - let mut result = dataset.par_rnn(&query, 5); - result.sort_unstable_by_key(|x| x.0); - assert_eq!(result, vec![(0, 3), (1, 1), (2, 5)]); - - Ok(()) - } - - #[test] - fn ser_de() -> Result<(), String> { - type Fv = FlatVec, i32, usize>; - - let instances = vec![vec![1, 2], vec![3, 4], vec![5, 6]]; - let distance_function = |a: &Vec, b: &Vec| distances::vectors::manhattan(a, b); - let metric = Metric::new(distance_function, false); - let dataset: Fv = FlatVec::new_array(instances.clone(), metric.clone())?; - - let serialized: Vec = bincode::serialize(&dataset).map_err(|e| e.to_string())?; - let mut deserialized: Fv = bincode::deserialize(&serialized).map_err(|e| e.to_string())?; - deserialized.set_metric(metric); - - assert_eq!(dataset.cardinality(), deserialized.cardinality()); - assert_eq!(dataset.dimensionality_hint(), deserialized.dimensionality_hint()); - assert_eq!(dataset.permutation(), deserialized.permutation()); - assert_eq!(dataset.metadata(), deserialized.metadata()); - for i in 0..dataset.cardinality() { - assert_eq!(dataset.get(i), deserialized.get(i)); - } - Ok(()) } } diff --git a/crates/abd-clam/src/core/dataset/io.rs b/crates/abd-clam/src/core/dataset/io.rs new file mode 100644 index 000000000..2b561b5a9 --- /dev/null +++ b/crates/abd-clam/src/core/dataset/io.rs @@ -0,0 +1,55 @@ +//! Traits for Disk IO operations with datasets. + +use super::{Dataset, ParDataset}; + +#[cfg(feature = "disk-io")] +/// For writing and reading datasets to and from disk. +pub trait DatasetIO: Dataset + bitcode::Encode + bitcode::Decode { + /// Writes the `Dataset` to disk in binary format using `bitcode`. + /// + /// # Errors + /// + /// - If the dataset cannot be encoded. + /// - If the file cannot be written. + fn write_to>(&self, path: &P) -> Result<(), String> { + let bytes = bitcode::encode(self).map_err(|e| e.to_string())?; + std::fs::write(path, bytes).map_err(|e| e.to_string()) + } + + /// Reads the `Dataset` from disk in binary format using `bitcode`. + /// + /// # Errors + /// + /// - If the file cannot be read. + /// - If the dataset cannot be decoded. + fn read_from>(path: &P) -> Result { + let bytes = std::fs::read(path).map_err(|e| e.to_string())?; + bitcode::decode(&bytes).map_err(|e| e.to_string()) + } +} + +#[cfg(feature = "disk-io")] +/// Parallel version of [`DatasetIO`](crate::core::dataset::io::DatasetIO). +pub trait ParDatasetIO: DatasetIO + ParDataset { + /// Parallel version of [`DatasetIO::write_to`](crate::core::dataset::io::DatasetIO::write_to). + /// + /// The default implementation offers no parallelism. + /// + /// # Errors + /// + /// See [`DatasetIO::write_to`](crate::core::dataset::io::DatasetIO::write_to). + fn par_write_to>(&self, path: &P) -> Result<(), String> { + self.write_to(path) + } + + /// Parallel version of [`DatasetIO::read_from`](crate::core::dataset::io::DatasetIO::read_from). + /// + /// The default implementation offers no parallelism. + /// + /// # Errors + /// + /// See [`DatasetIO::read_from`](crate::core::dataset::io::DatasetIO::read_from). + fn par_read_from>(path: &P) -> Result { + Self::read_from(path) + } +} diff --git a/crates/abd-clam/src/core/dataset/metric.rs b/crates/abd-clam/src/core/dataset/metric.rs deleted file mode 100644 index b6e094bde..000000000 --- a/crates/abd-clam/src/core/dataset/metric.rs +++ /dev/null @@ -1,224 +0,0 @@ -//! A `Metric` is a wrapper for a distance function that provides information -//! 
about the properties of the distance function. - -use serde::{Deserialize, Serialize}; - -/// A `Metric` is a wrapper for a distance function that provides information -/// about the properties of the distance function. -/// -/// # Type Parameters -/// -/// - `I`: The type of inputs to the distance function. -/// - `U`: The type of the distance values. -#[allow(clippy::struct_excessive_bools)] -#[derive(Clone)] -pub struct Metric { - /// Whether the distance function provides an identity. - pub(crate) identity: bool, - /// Whether the distance function is non-negative. - pub(crate) non_negativity: bool, - /// Whether the distance function is symmetric. - pub(crate) symmetry: bool, - /// Whether the distance function satisfies the triangle inequality. - pub(crate) triangle_inequality: bool, - /// Whether the distance function is expensive to compute. - pub(crate) expensive: bool, - /// The distance function. - pub(crate) distance_function: fn(&I, &I) -> U, - /// The name of the distance function. - pub(crate) name: String, -} - -impl Default for Metric { - fn default() -> Self { - Self { - identity: true, - non_negativity: true, - symmetry: true, - triangle_inequality: true, - expensive: false, - distance_function: |_, _| unreachable!("This should never be called."), - name: "Unknown Metric".to_string(), - } - } -} - -#[allow(clippy::missing_fields_in_debug)] -impl core::fmt::Debug for Metric { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Metric") - .field("name", &self.name) - .field("identity", &self.identity) - .field("non_negativity", &self.non_negativity) - .field("symmetry", &self.symmetry) - .field("triangle_inequality", &self.triangle_inequality) - .field("expensive", &self.expensive) - .finish() - } -} - -impl core::fmt::Display for Metric { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{} Metric", self.name) - } -} - -impl Metric { - /// Creates a new `Metric`. - /// - /// This sets the `identity`, `non_negativity`, `symmetry`, and - /// `triangle_inequality` properties to `true`. - /// - /// # Parameters - /// - /// - `distance_function`: The distance function. - /// - `expensive`: Whether the distance function is expensive to compute. - pub fn new(distance_function: fn(&I, &I) -> U, expensive: bool) -> Self { - Self { - identity: true, - non_negativity: true, - symmetry: true, - triangle_inequality: true, - expensive, - distance_function, - name: "Unknown Metric".to_string(), - } - } - - /// Returns the name of the distance function. - #[must_use] - pub fn name(&self) -> &str { - &self.name - } - - /// Sets the name of the distance function. - #[must_use] - pub fn with_name(mut self, name: &str) -> Self { - self.name = name.to_string(); - self - } - - /// Specifies that this distance function provides an identity. - #[must_use] - pub const fn has_identity(mut self) -> Self { - self.identity = true; - self - } - - /// Specifies that this distance function does not provide an identity. - #[must_use] - pub const fn no_identity(mut self) -> Self { - self.identity = false; - self - } - - /// Specifies that this distance function is non-negative. - #[must_use] - pub const fn has_non_negativity(mut self) -> Self { - self.non_negativity = true; - self - } - - /// Specifies that this distance function is not non-negative. - #[must_use] - pub const fn no_non_negativity(mut self) -> Self { - self.non_negativity = false; - self - } - - /// Specifies that this distance function is symmetric. 
- #[must_use] - pub const fn has_symmetry(mut self) -> Self { - self.symmetry = true; - self - } - - /// Specifies that this distance function is not symmetric. - #[must_use] - pub const fn no_symmetry(mut self) -> Self { - self.symmetry = false; - self - } - - /// Specifies that this distance function satisfies the triangle inequality. - #[must_use] - pub const fn has_triangle_inequality(mut self) -> Self { - self.triangle_inequality = true; - self - } - - /// Specifies that this distance function does not satisfy the triangle - /// inequality. - #[must_use] - pub const fn no_triangle_inequality(mut self) -> Self { - self.triangle_inequality = false; - self - } - - /// Specifies that this distance function is expensive to compute. - #[must_use] - pub const fn is_expensive(mut self) -> Self { - self.expensive = true; - self - } - - /// Specifies that this distance function is not expensive to compute. - #[must_use] - pub const fn is_not_expensive(mut self) -> Self { - self.expensive = false; - self - } -} - -/// A `SerdeMetric` is a helper for serializing and deserializing a `Metric`. -#[allow(clippy::struct_excessive_bools)] -#[derive(Serialize, Deserialize)] -struct SerdeMetric { - /// Whether the distance function provides an identity. - identity: bool, - /// Whether the distance function is non-negative. - non_negativity: bool, - /// Whether the distance function is symmetric. - symmetry: bool, - /// Whether the distance function satisfies the triangle inequality. - triangle_inequality: bool, - /// Whether the distance function is expensive to compute. - expensive: bool, - /// The name of the distance function. - name: String, - /// A phantom data field to ensure that the compiler is satisfied. - _phantom: std::marker::PhantomData<(I, U)>, -} - -impl Serialize for Metric { - fn serialize(&self, serializer: S) -> Result { - let serde_metric = SerdeMetric:: { - identity: self.identity, - non_negativity: self.non_negativity, - symmetry: self.symmetry, - triangle_inequality: self.triangle_inequality, - expensive: self.expensive, - name: self.name.clone(), - _phantom: std::marker::PhantomData, - }; - serde_metric.serialize(serializer) - } -} - -impl<'de, I, U> Deserialize<'de> for Metric { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - let serde_metric = SerdeMetric::::deserialize(deserializer)?; - Ok(Self { - identity: serde_metric.identity, - non_negativity: serde_metric.non_negativity, - symmetry: serde_metric.symmetry, - triangle_inequality: serde_metric.triangle_inequality, - expensive: serde_metric.expensive, - distance_function: |_, _| unreachable!("This should never be called."), - name: serde_metric.name, - }) - } -} diff --git a/crates/abd-clam/src/core/dataset/metric_space.rs b/crates/abd-clam/src/core/dataset/metric_space.rs deleted file mode 100644 index 22a412ca1..000000000 --- a/crates/abd-clam/src/core/dataset/metric_space.rs +++ /dev/null @@ -1,346 +0,0 @@ -//! `MetricSpace` is a trait for datasets that have a distance function. - -use distances::Number; -use rand::prelude::*; -use rayon::prelude::*; - -use super::Metric; - -/// `MetricSpace` is a trait for datasets that have a distance function. -/// -/// # Type Parameters -/// -/// * `I`: The type of the instances, i.e. each data point in the dataset. -/// * `U`: The type of the distance values. -pub trait MetricSpace { - /// Returns the underlying metric. - fn metric(&self) -> &Metric; - - /// Changes the underlying metric. 
- fn set_metric(&mut self, metric: Metric); - - /// Whether the distance function provides an identity. - /// - /// Identity is defined as: `d(x, y) = 0 <=> x = y`. - fn identity(&self) -> bool { - self.metric().identity - } - - /// Whether the distance function is non-negative. - /// - /// Non-negativity is defined as: `d(x, y) >= 0`. - fn non_negativity(&self) -> bool { - self.metric().non_negativity - } - - /// Whether the distance function is symmetric. - /// - /// Symmetry is defined as: `d(x, y) = d(y, x) for all x, y`. - fn symmetry(&self) -> bool { - self.metric().symmetry - } - - /// Whether the distance function satisfies the triangle inequality. - /// - /// The triangle inequality is defined as: `d(x, y) + d(y, z) >= d(x, z) for all x, y, z`. - fn triangle_inequality(&self) -> bool { - self.metric().triangle_inequality - } - - /// Whether the distance function is expensive to compute. - /// - /// A distance function is expensive if its asymptotic complexity is greater - /// than O(n) where n is the dimensionality of the dataset. - fn expensive(&self) -> bool { - self.metric().expensive - } - - /// Returns the distance function - fn distance_function(&self) -> fn(&I, &I) -> U { - self.metric().distance_function - } - - /// Calculates the distance between two instances. - fn one_to_one(&self, a: &I, b: &I) -> U { - self.distance_function()(a, b) - } - - /// Whether two instances are equal. - /// - /// This will always return `false` when the distance function does not - /// provide an identity. - fn equal(&self, a: &I, b: &I) -> bool { - self.identity() && (self.one_to_one(a, b) == U::ZERO) - } - - /// Calculates the distances between an instance and a collection of instances. - fn one_to_many(&self, a: &I, b: &[(A, &I)]) -> Vec<(A, U)> { - b.iter().map(|&(i, x)| (i, self.one_to_one(a, x))).collect() - } - - /// Calculates the distances between two collections of instances. - fn many_to_many(&self, a: &[(A, &I)], b: &[(B, &I)]) -> Vec> { - a.iter() - .map(|&(i, x)| b.iter().map(move |&(j, y)| (i, j, self.one_to_one(x, y))).collect()) - .collect() - } - - /// Calculates the distances between all given pairs of instances. - fn pairs(&self, pairs: &[(A, &I, B, &I)]) -> Vec<(A, B, U)> { - pairs - .iter() - .map(|&(i, a, j, b)| (i, j, self.one_to_one(a, b))) - .collect() - } - - /// Calculates all pairwise distances between instances. - fn pairwise(&self, a: &[(A, &I)]) -> Vec> { - if self.symmetry() { - let mut matrix = a - .iter() - .map(|&(i, _)| a.iter().map(move |&(j, _)| (i, j, U::ZERO)).collect::>()) - .collect::>(); - - for (i, &(p, x)) in a.iter().enumerate() { - let pairs = a.iter().skip(i + 1).map(|&(q, y)| (p, x, q, y)).collect::>(); - let distances = self.pairs(&pairs); - distances - .into_iter() - .enumerate() - .map(|(j, d)| (j + i + 1, d)) - .for_each(|(j, (p, q, d))| { - matrix[i][j] = (p, q, d); - matrix[j][i] = (q, p, d); - }); - } - - if !self.identity() { - // compute the diagonal for non-metrics - let pairs = a.iter().map(|&(p, x)| (p, x, p, x)).collect::>(); - let distances = self.pairs(&pairs); - distances - .into_iter() - .enumerate() - .for_each(|(i, (p, q, d))| matrix[i][i] = (p, q, d)); - } - - matrix - } else { - self.many_to_many(a, a) - } - } - - /// Chooses a subset of instances that are unique with respect to the - /// distance function. - /// - /// If the distance function does not provide an identity, this will return - /// a random subset of the instances. - /// - /// # Arguments - /// - /// * `instances` - A collection of instances. 
- /// * `choose` - The number of instances to choose. - /// * `seed` - A seed for the random number generator. - /// - /// # Returns - /// - /// The indices of the chosen instances in `instances`. - fn choose_unique<'a, A: Copy>( - &'a self, - instances: &[(A, &'a I)], - choose: usize, - seed: Option, - ) -> Vec<(A, &'a I)> { - let mut rng = seed.map_or_else(StdRng::from_entropy, StdRng::seed_from_u64); - - if self.identity() { - let mut choices = Vec::with_capacity(choose); - for &(i, a) in instances { - if !choices.iter().any(|&(_, b)| self.equal(a, b)) { - choices.push((i, a)); - } - - if choices.len() == choose { - break; - } - } - choices - } else { - let mut instances = instances.to_vec(); - instances.shuffle(&mut rng); - instances.truncate(choose); - instances - } - } - - /// Calculates the geometric median of a collection of instances. - /// - /// The geometric median is the instance that minimizes the sum of distances - /// to all other instances. - /// - /// # Arguments - /// - /// * `instances` - A collection of instances. - /// - /// # Returns - /// - /// The index of the geometric median in `instances`. - fn median<'a, A: Copy>(&'a self, instances: &[(A, &'a I)]) -> (A, &'a I) { - let distances = self.pairwise(instances); - let mut median = 0; - let mut min_sum = U::MAX; - - for (i, row) in distances.into_iter().enumerate() { - let sum = row.into_iter().map(|(_, _, d)| d).sum(); - if sum < min_sum { - min_sum = sum; - median = i; - } - } - - instances[median] - } -} - -/// An extension of `MetricSpace` that provides parallel implementations of -/// distance calculations. -#[allow(clippy::module_name_repetitions)] -pub trait ParMetricSpace: MetricSpace + Send + Sync { - /// Calculates the distances between an instance and a collection of - /// instances, in parallel. - fn par_one_to_many(&self, a: &I, b: &[(A, &I)]) -> Vec<(A, U)> { - b.par_iter().map(|&(i, x)| (i, self.one_to_one(a, x))).collect() - } - - /// Calculates the distances between two collections of instances, in parallel. - fn par_many_to_many( - &self, - a: &[(A, &I)], - b: &[(B, &I)], - ) -> Vec> { - a.par_iter() - .map(|&(i, x)| b.par_iter().map(move |&(j, y)| (i, j, self.one_to_one(x, y))).collect()) - .collect() - } - - /// Calculates the distances between all given pairs of instances, in parallel. - #[allow(clippy::type_complexity)] - fn par_pairs(&self, pairs: &[(A, &I, B, &I)]) -> Vec<(A, B, U)> { - pairs - .par_iter() - .map(|&(i, a, j, b)| (i, j, self.one_to_one(a, b))) - .collect() - } - - /// Calculates all pairwise distances between instances, in parallel. 
- fn par_pairwise(&self, a: &[(A, &I)]) -> Vec> { - if self.symmetry() { - let mut matrix = a - .iter() - .map(|&(i, _)| a.iter().map(move |&(j, _)| (i, j, U::ZERO)).collect::>()) - .collect::>(); - - for (i, &(p, x)) in a.iter().enumerate() { - let pairs = a - .iter() - .skip(i + 1) - .map(move |&(q, y)| (p, x, q, y)) - .collect::>(); - let distances = self.par_pairs(&pairs); - distances - .into_iter() - .enumerate() - .map(|(j, d)| (j + i + 1, d)) - .for_each(|(j, (a, b, d))| { - matrix[i][j] = (a, b, d); - matrix[j][i] = (a, b, d); - }); - } - - if !self.identity() { - // compute the diagonal for non-metrics - let pairs = a.iter().map(|&(i, x)| (i, x, i, x)).collect::>(); - let distances = self.par_pairs(&pairs); - distances - .into_iter() - .enumerate() - .for_each(|(i, (a, b, d))| matrix[i][i] = (a, b, d)); - } - - matrix - } else { - self.par_many_to_many(a, a) - } - } - - /// Chooses a subset of instances that are unique with respect to the - /// distance function. - /// - /// If the distance function does not provide an identity, this will return - /// a random subset of the instances. - /// - /// # Arguments - /// - /// * `instances` - A collection of instances. - /// * `choose` - The number of instances to choose. - /// * `seed` - A seed for the random number generator. - /// - /// # Returns - /// - /// The indices of the chosen instances in `instances`. - fn par_choose_unique<'a, A: Copy + Send + Sync>( - &'a self, - instances: &[(A, &'a I)], - choose: usize, - seed: Option, - ) -> Vec<(A, &'a I)> { - let mut rng = seed.map_or_else(StdRng::from_entropy, StdRng::seed_from_u64); - - if self.identity() { - let mut choices = Vec::with_capacity(choose); - for &(i, a) in instances { - if !choices.par_iter().any(|&(_, b)| self.equal(a, b)) { - choices.push((i, a)); - } - - if choices.len() == choose { - break; - } - } - choices - } else { - let mut instances = instances.to_vec(); - instances.shuffle(&mut rng); - instances.truncate(choose); - instances - } - } - - /// Calculates the geometric median of a collection of instances. - /// - /// The geometric median is the instance that minimizes the sum of distances - /// to all other instances. - /// - /// # Arguments - /// - /// * `instances` - A collection of instances. - /// - /// # Returns - /// - /// The index of the geometric median in `instances`. - fn par_median<'a, A: Copy + Send + Sync>(&self, instances: &[(A, &'a I)]) -> (A, &'a I) { - let distances = self.par_pairwise(instances); - let mut median = 0; - let mut min_sum = U::MAX; - - for (i, row) in distances.into_iter().enumerate() { - let sum = row.into_iter().map(|(_, _, d)| d).sum(); - if sum < min_sum { - min_sum = sum; - median = i; - } - } - - instances[median] - } -} diff --git a/crates/abd-clam/src/core/dataset/mod.rs b/crates/abd-clam/src/core/dataset/mod.rs index eb4d85766..0193d8752 100644 --- a/crates/abd-clam/src/core/dataset/mod.rs +++ b/crates/abd-clam/src/core/dataset/mod.rs @@ -1,26 +1,41 @@ //! Traits relating to datasets. 
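+//!
+//! After this change, a dataset only stores items; every distance
+//! computation takes the `Metric` as an explicit argument, so the same
+//! dataset can be queried under several metrics without being rebuilt. A
+//! sketch (the `Euclidean` and `Manhattan` metric names are assumed):
+//!
+//! ```rust,ignore
+//! use abd_clam::{metric::{Euclidean, Manhattan}, Dataset, FlatVec};
+//!
+//! let data = FlatVec::new_array(vec![vec![0.0_f32, 0.0], vec![3.0, 4.0]]).unwrap();
+//! let d: f32 = data.one_to_one(0, 1, &Euclidean);
+//! assert_eq!(d, 5.0);
+//! let d: f32 = data.one_to_one(0, 1, &Manhattan);
+//! assert_eq!(d, 7.0);
+//! ```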
+use distances::Number; +use rand::prelude::*; +use rayon::prelude::*; + +use super::{metric::ParMetric, Metric}; + +mod associates_metadata; mod flat_vec; -mod metric; -pub mod metric_space; mod permutable; +mod sized_heap; -use distances::Number; - +pub use associates_metadata::{AssociatesMetadata, AssociatesMetadataMut}; pub use flat_vec::FlatVec; -pub use metric::Metric; -pub use metric_space::MetricSpace; pub use permutable::Permutable; +pub use sized_heap::SizedHeap; -use metric_space::ParMetricSpace; +#[cfg(feature = "disk-io")] +mod io; -/// A dataset is a collection of instances. +#[cfg(feature = "disk-io")] +#[allow(clippy::module_name_repetitions)] +pub use io::{DatasetIO, ParDatasetIO}; + +/// A dataset is a collection of items. /// /// # Type Parameters /// -/// - `I`: The type of the instances. -/// - `U`: The type of the distance values. -pub trait Dataset: MetricSpace { +/// - `I`: The type of the items. +/// +/// # Example +/// +/// See: +/// +/// - [`FlatVec`](crate::core::dataset::FlatVec) +/// - [`CodecData`](crate::pancakes::CodecData) +pub trait Dataset { /// Returns the name of the dataset. fn name(&self) -> &str; @@ -28,7 +43,7 @@ pub trait Dataset: MetricSpace { #[must_use] fn with_name(self, name: &str) -> Self; - /// Returns the number of instances in the dataset. + /// Returns the number of items in the dataset. fn cardinality(&self) -> usize; /// A range of values for the dimensionality of the dataset. @@ -37,259 +52,285 @@ pub trait Dataset: MetricSpace { /// bound. fn dimensionality_hint(&self) -> (usize, Option); - /// Returns the instance at the given index. May panic if the index is out + /// Returns the item at the given index. May panic if the index is out /// of bounds. fn get(&self, index: usize) -> &I; - /// Computes the distance between two instances by their indices. - fn one_to_one(&self, i: usize, j: usize) -> U { - MetricSpace::one_to_one(self, self.get(i), self.get(j)) + /// Returns an iterator over the indices of the items. + fn indices(&self) -> impl Iterator { + 0..self.cardinality() } - /// Computes the distances between a query instance and the given instance by its index. - fn query_to_one(&self, query: &I, j: usize) -> U { - MetricSpace::one_to_one(self, query, self.get(j)) + /// Computes the distance between two items by their indices. + fn one_to_one>(&self, i: usize, j: usize, metric: &M) -> T { + metric.distance(self.get(i), self.get(j)) } - /// Computes the distances between an instance and a collection of instances. - fn one_to_many(&self, i: usize, j: &[usize]) -> Vec<(usize, U)> { - let a = self.get(i); - let b = j.iter().map(|&i| (i, self.get(i))).collect::>(); - MetricSpace::one_to_many(self, a, &b) - } - - /// Computes the distances between the query and the given collection of instances. - fn query_to_many(&self, query: &I, j: &[usize]) -> Vec<(usize, U)> { - let b = j.iter().map(|&i| (i, self.get(i))).collect::>(); - MetricSpace::one_to_many(self, query, &b) + /// Computes the distances between an item and a collection of items. + /// + /// Each tuple `(j, d)` represents the distance between the items at + /// indices `i` and `j`. + fn one_to_many>( + &self, + i: usize, + js: &[usize], + metric: &M, + ) -> impl Iterator { + js.iter().map(move |&j| (j, self.one_to_one(i, j, metric))) } - /// Computes the distances between two collections of instances. 
- fn many_to_many(&self, i: &[usize], j: &[usize]) -> Vec> { - let a = i.iter().map(|&i| (i, self.get(i))).collect::>(); - let b = j.iter().map(|&i| (i, self.get(i))).collect::>(); - MetricSpace::many_to_many(self, &a, &b) + /// Computes the distance between a query and an item. + fn query_to_one>(&self, query: &I, i: usize, metric: &M) -> T { + metric.distance(query, self.get(i)) } - /// Computes the distances between all given pairs of instances. - fn pairs(&self, pairs: &[(usize, usize)]) -> Vec<(usize, usize, U)> { - let pairs = pairs - .iter() - .map(|&(i, j)| (i, self.get(i), j, self.get(j))) - .collect::>(); - MetricSpace::pairs(self, &pairs) + /// Computes the distances between a query and a collection of items. + /// + /// Each tuple `(i, d)` represents the distance between the query and the + /// item at index `i`. + fn query_to_many>( + &self, + query: &I, + is: &[usize], + metric: &M, + ) -> impl Iterator { + is.iter().map(move |&i| (i, self.query_to_one(query, i, metric))) + } + + /// Computes the distances between two collections of items. + /// + /// Each triplet `(i, j, d)` represents the distance between the items at + /// indices `i` and `j`. + fn many_to_many>( + &self, + is: &[usize], + js: &[usize], + metric: &M, + ) -> impl Iterator> { + is.iter() + .map(|&i| js.iter().map(|&j| (i, j, self.one_to_one(i, j, metric))).collect()) + } + + /// Computes the distances between all given pairs of items. + /// + /// Each triplet `(i, j, d)` represents the distance between the items at + /// indices `i` and `j`. + fn pairs>( + &self, + pairs: &[(usize, usize)], + metric: &M, + ) -> impl Iterator { + pairs.iter().map(|&(i, j)| (i, j, self.one_to_one(i, j, metric))) } - /// Computes the distances between all pairs of instances. - fn pairwise(&self, i: &[usize]) -> Vec> { - let pairs = i.iter().map(|&i| (i, self.get(i))).collect::>(); - MetricSpace::pairwise(self, &pairs) - } + /// Computes the distances between all pairs of items. + /// + /// Each triplet `(i, j, d)` represents the distance between the items at + /// indices `i` and `j`. + fn pairwise>(&self, is: &[usize], metric: &M) -> Vec> { + if metric.has_symmetry() { + let mut matrix = is + .iter() + .map(|&i| is.iter().map(move |&j| (i, j, T::ZERO)).collect::>()) + .collect::>(); + + for (i, &p) in is.iter().enumerate() { + let pairs = is.iter().skip(i + 1).map(|&q| (p, q)).collect::>(); + self.pairs(&pairs, metric) + .enumerate() + .map(|(j, d)| (j + i + 1, d)) + .for_each(|(j, (p, q, d))| { + matrix[i][j] = (p, q, d); + matrix[j][i] = (q, p, d); + }); + } - /// Chooses a subset of instances that are unique. - fn choose_unique(&self, indices: &[usize], choose: usize, seed: Option) -> Vec { - let instances = indices.iter().map(|&i| (i, self.get(i))).collect::>(); - let unique = MetricSpace::choose_unique(self, &instances, choose, seed); - unique.iter().map(|&(i, _)| i).collect() - } + if !metric.has_identity() { + // compute the diagonal for non-metrics + let pairs = is.iter().map(|&p| (p, p)).collect::>(); + self.pairs(&pairs, metric) + .enumerate() + .for_each(|(i, (p, q, d))| matrix[i][i] = (p, q, d)); + } - /// Calculates the geometric median of the given instances. - fn median(&self, indices: &[usize]) -> usize { - let instances = indices.iter().map(|&i| (i, self.get(i))).collect::>(); - MetricSpace::median(self, &instances).0 + matrix + } else { + self.many_to_many(is, is, metric).collect() + } } - /// Runs linear KNN search on the dataset. 
- fn knn(&self, query: &I, k: usize) -> Vec<(usize, U)> { - let indices = (0..self.cardinality()).collect::>(); - let mut knn = SizedHeap::new(Some(k)); - self.query_to_many(query, &indices) - .into_iter() - .for_each(|(i, d)| knn.push((d, i))); - knn.items().map(|(d, i)| (i, d)).collect() + /// Chooses a subset of items that are unique. + /// + /// If the metric has an identity, the first `choose` unique items, i.e. + /// items that are not equal to any other item, are chosen. Otherwise, a + /// random subset is chosen. + fn choose_unique>( + &self, + is: &[usize], + choose: usize, + seed: Option, + metric: &M, + ) -> Vec { + let mut rng = seed.map_or_else(StdRng::from_entropy, StdRng::seed_from_u64); + + if metric.has_identity() { + let mut choices = Vec::with_capacity(choose); + for (i, a) in is.iter().map(|&i| (i, self.get(i))) { + if !choices.iter().any(|&(_, b)| metric.is_equal(a, b)) { + choices.push((i, a)); + } + + if choices.len() == choose { + break; + } + } + choices.into_iter().map(|(i, _)| i).collect() + } else { + let mut is = is.to_vec(); + is.shuffle(&mut rng); + is.truncate(choose); + is + } } - /// Runs linear RNN search on the dataset. - fn rnn(&self, query: &I, radius: U) -> Vec<(usize, U)> { - let indices = (0..self.cardinality()).collect::>(); - Dataset::query_to_many(self, query, &indices) + /// Calculates the geometric median of the given items. + /// + /// The geometric median is the item that minimizes the sum of distances + /// to all other items. + fn median>(&self, is: &[usize], metric: &M) -> usize { + let (arg_median, _) = self + .pairwise(is, metric) .into_iter() - .filter(|&(_, d)| d <= radius) - .collect() + .map(|row| row.into_iter().map(|(_, _, d)| d).sum::()) + .enumerate() + .fold( + (0, T::MAX), + |(arg_min, min_sum), (i, sum)| { + if sum < min_sum { + (i, sum) + } else { + (arg_min, min_sum) + } + }, + ); + is[arg_median] } } -/// An extension of `Dataset` that provides parallel implementations of -/// distance calculations. +/// Parallel version of [`Dataset`](crate::core::dataset::Dataset). #[allow(clippy::module_name_repetitions)] -pub trait ParDataset: Dataset + ParMetricSpace { - /// Parallel version of `one_to_one`. - fn par_one_to_many(&self, i: usize, j: &[usize]) -> Vec<(usize, U)> { - let a = self.get(i); - let b = j.iter().map(|&i| (i, self.get(i))).collect::>(); - ParMetricSpace::par_one_to_many(self, a, &b) - } - - /// Parallel version of `query_to_one`. - fn par_query_to_many(&self, query: &I, j: &[usize]) -> Vec<(usize, U)> { - let b = j.iter().map(|&i| (i, self.get(i))).collect::>(); - ParMetricSpace::par_one_to_many(self, query, &b) - } - - /// Parallel version of `many_to_many`. - fn par_many_to_many(&self, i: &[usize], j: &[usize]) -> Vec> { - let a = i.iter().map(|&i| (i, self.get(i))).collect::>(); - let b = j.iter().map(|&i| (i, self.get(i))).collect::>(); - ParMetricSpace::par_many_to_many(self, &a, &b) - } - - /// Parallel version of `pairs`. - fn par_pairs(&self, pairs: &[(usize, usize)]) -> Vec<(usize, usize, U)> { - let pairs = pairs - .iter() - .map(|&(i, j)| (i, self.get(i), j, self.get(j))) - .collect::>(); - ParMetricSpace::par_pairs(self, &pairs) - } - - /// Parallel version of `pairwise`. - fn par_pairwise(&self, i: &[usize]) -> Vec> { - let pairs = i.iter().map(|&i| (i, self.get(i))).collect::>(); - ParMetricSpace::par_pairwise(self, &pairs) +pub trait ParDataset: Dataset + Send + Sync { + /// Parallel version of [`Dataset::one_to_one`](crate::core::dataset::Dataset::one_to_one). 
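+    ///
+    /// A sketch with a string metric (the `Levenshtein` name and its exact
+    /// item bounds are assumed):
+    ///
+    /// ```rust,ignore
+    /// use abd_clam::{dataset::ParDataset, metric::Levenshtein, FlatVec};
+    ///
+    /// let data = FlatVec::new(vec!["kitten".to_string(), "sitting".to_string()]).unwrap();
+    /// let d: u32 = data.par_one_to_one(0, 1, &Levenshtein);
+    /// assert_eq!(d, 3);
+    /// ```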
+ fn par_one_to_one>(&self, i: usize, j: usize, metric: &M) -> T { + metric.par_distance(self.get(i), self.get(j)) } - /// Parallel version of `choose_unique`. - fn par_choose_unique(&self, indices: &[usize], choose: usize, seed: Option) -> Vec { - let instances = indices.iter().map(|&i| (i, self.get(i))).collect::>(); - let unique = ParMetricSpace::par_choose_unique(self, &instances, choose, seed); - unique.iter().map(|&(i, _)| i).collect() + /// Parallel version of [`Dataset::one_to_many`](crate::core::dataset::Dataset::one_to_many). + fn par_one_to_many>( + &self, + i: usize, + js: &[usize], + metric: &M, + ) -> impl ParallelIterator { + js.par_iter().map(move |&j| (j, self.par_one_to_one(i, j, metric))) } - /// Parallel version of `median`. - fn par_median(&self, indices: &[usize]) -> usize { - let instances = indices.iter().map(|&i| (i, self.get(i))).collect::>(); - ParMetricSpace::par_median(self, &instances).0 + /// Computes the distance between a query and an item. + fn par_query_to_one>(&self, query: &I, i: usize, metric: &M) -> T { + metric.par_distance(query, self.get(i)) } - /// Runs linear KNN search on the dataset, in parallel. - fn par_knn(&self, query: &I, k: usize) -> Vec<(usize, U)> { - let indices = (0..self.cardinality()).collect::>(); - let mut knn = SizedHeap::new(Some(k)); - self.par_query_to_many(query, &indices) - .into_iter() - .for_each(|(i, d)| knn.push((d, i))); - knn.items().map(|(d, i)| (i, d)).collect() - } - - /// Runs linear RNN search on the dataset, in parallel. - fn par_rnn(&self, query: &I, radius: U) -> Vec<(usize, U)> { - let indices = (0..self.cardinality()).collect::>(); - self.par_query_to_many(query, &indices) - .into_iter() - .filter(|&(_, d)| d <= radius) - .collect() - } -} - -/// A helper struct for maintaining a max heap of a fixed size. -/// -/// This is useful for maintaining the `k` nearest neighbors in a search algorithm. -pub struct SizedHeap { - /// The heap of items. - heap: std::collections::BinaryHeap>, - /// The maximum size of the heap. - k: usize, -} - -impl SizedHeap { - /// Creates a new `SizedMaxHeap` with a fixed size. - #[must_use] - pub fn new(k: Option) -> Self { - k.map_or_else( - || Self { - heap: std::collections::BinaryHeap::new(), - k: usize::MAX, - }, - |k| Self { - heap: std::collections::BinaryHeap::with_capacity(k), - k, - }, - ) - } - - /// Returns the maximum size of the heap. - #[must_use] - pub const fn k(&self) -> usize { - self.k - } - - /// Pushes an item onto the heap, maintaining the max size. - pub fn push(&mut self, item: T) { - if self.heap.len() < self.k { - self.heap.push(MaxItem(item)); - } else if let Some(top) = self.heap.peek() { - if item < top.0 { - self.heap.pop(); - self.heap.push(MaxItem(item)); + /// Computes the distances between a query and a collection of items. + /// + /// Each tuple `(i, d)` represents the distance between the query and the + /// item at index `i`. + fn par_query_to_many>( + &self, + query: &I, + is: &[usize], + metric: &M, + ) -> impl ParallelIterator { + is.par_iter() + .map(move |&i| (i, self.par_query_to_one(query, i, metric))) + } + + /// Parallel version of [`Dataset::many_to_many`](crate::core::dataset::Dataset::many_to_many). + fn par_many_to_many>( + &self, + is: &[usize], + js: &[usize], + metric: &M, + ) -> impl ParallelIterator> { + is.par_iter().map(|&i| { + js.par_iter() + .map(|&j| (i, j, self.par_one_to_one(i, j, metric))) + .collect() + }) + } + + /// Parallel version of [`Dataset::pairs`](crate::core::dataset::Dataset::pairs). 
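+    ///
+    /// A hedged sketch of expected usage (illustrative names):
+    ///
+    /// ```rust,ignore
+    /// use rayon::prelude::*;
+    ///
+    /// let pairs = [(0, 1), (2, 3)];
+    /// // Each element of the iterator is a triplet `(i, j, d)`.
+    /// let distances: Vec<_> = data.par_pairs(&pairs, &metric).collect();
+    /// ```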
+ fn par_pairs>( + &self, + pairs: &[(usize, usize)], + metric: &M, + ) -> impl ParallelIterator { + pairs.par_iter().map(|&(i, j)| (i, j, self.one_to_one(i, j, metric))) + } + + /// Parallel version of [`Dataset::pairwise`](crate::core::dataset::Dataset::pairwise). + fn par_pairwise>(&self, a: &[usize], metric: &M) -> Vec> { + if metric.has_symmetry() { + let mut matrix = a + .iter() + .map(|&i| a.iter().map(move |&j| (i, j, T::ZERO)).collect::>()) + .collect::>(); + + for (i, &p) in a.iter().enumerate() { + let pairs = a.iter().skip(i + 1).map(|&q| (p, q)).collect::>(); + let distances = self.par_pairs(&pairs, metric).collect::>(); + distances + .into_iter() + .enumerate() + .map(|(j, d)| (j + i + 1, d)) + .for_each(|(j, (p, q, d))| { + matrix[i][j] = (p, q, d); + matrix[j][i] = (q, p, d); + }); } - } - } - - /// Peeks at the top item in the heap. - #[must_use] - pub fn peek(&self) -> Option<&T> { - self.heap.peek().map(|MaxItem(x)| x) - } - /// Pops the top item from the heap. - pub fn pop(&mut self) -> Option { - self.heap.pop().map(|MaxItem(x)| x) - } - - /// Consumes the `SizedMaxHeap` and returns the items in an iterator. - pub fn items(self) -> impl Iterator { - self.heap.into_iter().map(|MaxItem(x)| x) - } - - /// Returns the number of items in the heap. - #[must_use] - pub fn len(&self) -> usize { - self.heap.len() - } - - /// Returns whether the heap is empty. - #[must_use] - pub fn is_empty(&self) -> bool { - self.heap.is_empty() - } - - /// Returns whether the heap is full. - #[must_use] - pub fn is_full(&self) -> bool { - self.heap.len() == self.k - } -} - -/// A wrapper struct for implementing `PartialOrd` and `Ord` on a type to use -/// with `SizedMaxHeap`. -struct MaxItem(T); - -impl PartialEq for MaxItem { - fn eq(&self, other: &Self) -> bool { - self.0 == other.0 - } -} - -impl Eq for MaxItem {} + if !metric.has_identity() { + // compute the diagonal for non-metrics + let pairs = a.iter().map(|&p| (p, p)).collect::>(); + let distances = self.par_pairs(&pairs, metric).collect::>(); + distances + .into_iter() + .enumerate() + .for_each(|(i, (p, q, d))| matrix[i][i] = (p, q, d)); + } -impl PartialOrd for MaxItem { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) + matrix + } else { + self.par_many_to_many(a, a, metric).collect() + } } -} -impl Ord for MaxItem { - fn cmp(&self, other: &Self) -> core::cmp::Ordering { - self.0.partial_cmp(&other.0).unwrap_or(core::cmp::Ordering::Greater) + /// Parallel version of [`Dataset::median`](crate::core::dataset::Dataset::median). + fn par_median>(&self, is: &[usize], metric: &M) -> usize { + let (arg_median, _) = self + .par_pairwise(is, metric) + .into_iter() + .map(|row| row.into_iter().map(|(_, _, d)| d).sum::()) + .enumerate() + .fold( + (0, T::MAX), + |(arg_min, min_sum), (i, sum)| { + if sum < min_sum { + (i, sum) + } else { + (arg_min, min_sum) + } + }, + ); + is[arg_median] } } diff --git a/crates/abd-clam/src/core/dataset/permutable.rs b/crates/abd-clam/src/core/dataset/permutable.rs index 107200cda..7669a5335 100644 --- a/crates/abd-clam/src/core/dataset/permutable.rs +++ b/crates/abd-clam/src/core/dataset/permutable.rs @@ -6,11 +6,11 @@ /// the `CAKES` paper. /// /// We may *not* want to permute the dataset in-place, e.g. for use with -/// `CHAODA` because it needs to deal with a given set of instances under -/// multiple metrics. +/// `CHAODA` because it needs to deal with a given set of items under multiple +/// metrics. 
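+///
+/// Since the trait is built around `swap`s, a permutation can be applied in
+/// place by chasing cycles, moving each item at most a few times. A hedged
+/// sketch of that standard trick (not necessarily this crate's
+/// implementation):
+///
+/// ```rust
+/// /// After this, `items[i]` holds the element originally at `perm[i]`.
+/// /// Assumes `perm` is a valid permutation of `0..items.len()`.
+/// fn apply_permutation<T>(items: &mut [T], perm: &[usize]) {
+///     for i in 0..items.len() {
+///         // Follow the cycle until the source element's current position.
+///         let mut j = perm[i];
+///         while j < i {
+///             j = perm[j];
+///         }
+///         items.swap(i, j);
+///     }
+/// }
+/// ```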
pub trait Permutable { /// Gets the current permutation of the collection, i.e. the ordering of the - /// original instances into the current order. + /// original items into the current order. /// /// Our implementation of this method on `Vec` and `&mut [T]` will always /// return the identity permutation. @@ -19,7 +19,7 @@ pub trait Permutable { /// Sets the permutation of the collection without modifying the collection. fn set_permutation(&mut self, permutation: &[usize]); - /// Swaps the location of two instances in the collection. + /// Swaps the location of two items in the collection. /// /// # Arguments /// diff --git a/crates/abd-clam/src/core/dataset/sized_heap.rs b/crates/abd-clam/src/core/dataset/sized_heap.rs new file mode 100644 index 000000000..423c646a3 --- /dev/null +++ b/crates/abd-clam/src/core/dataset/sized_heap.rs @@ -0,0 +1,154 @@ +//! A helper struct for maintaining a max heap of an optionally fixed size. + +use rayon::prelude::*; + +/// A helper struct for maintaining a max heap of a fixed size. +/// +/// This is useful for maintaining the `k` nearest neighbors in a search algorithm. +pub struct SizedHeap { + /// The heap of items. + heap: std::collections::BinaryHeap>, + /// The maximum size of the heap. + k: usize, +} + +impl FromIterator for SizedHeap { + fn from_iter>(iter: I) -> Self { + let mut heap = Self::new(None); + for item in iter { + heap.push(item); + } + heap + } +} + +impl SizedHeap { + /// Creates a new `SizedHeap` with a fixed size. + #[must_use] + pub fn new(k: Option) -> Self { + k.map_or_else( + || Self { + heap: std::collections::BinaryHeap::new(), + k: usize::MAX, + }, + |k| Self { + heap: std::collections::BinaryHeap::with_capacity(k), + k, + }, + ) + } + + /// Returns the maximum size of the heap. + #[must_use] + pub const fn k(&self) -> usize { + self.k + } + + /// Pushes an item onto the heap, maintaining the max size. + pub fn push(&mut self, item: T) { + if self.heap.len() < self.k { + self.heap.push(MaxItem(item)); + } else if let Some(top) = self.heap.peek() { + if item < top.0 { + self.heap.pop(); + self.heap.push(MaxItem(item)); + } + } + } + + /// Pushes several items onto the heap, maintaining the max size. + pub fn extend>(&mut self, items: I) { + for item in items { + self.heap.push(MaxItem(item)); + } + while self.heap.len() > self.k { + self.heap.pop(); + } + } + + /// Peeks at the top item in the heap. + #[must_use] + pub fn peek(&self) -> Option<&T> { + self.heap.peek().map(|MaxItem(x)| x) + } + + /// Pops the top item from the heap. + pub fn pop(&mut self) -> Option { + self.heap.pop().map(|MaxItem(x)| x) + } + + /// Consumes the `SizedHeap` and returns the items in an iterator. + pub fn items(self) -> impl Iterator { + self.heap.into_iter().map(|MaxItem(x)| x) + } + + /// Returns the number of items in the heap. + #[must_use] + pub fn len(&self) -> usize { + self.heap.len() + } + + /// Returns whether the heap is empty. + #[must_use] + pub fn is_empty(&self) -> bool { + self.heap.is_empty() + } + + /// Returns whether the heap is full. + #[must_use] + pub fn is_full(&self) -> bool { + self.heap.len() == self.k + } + + /// Merge two heaps into one. + pub fn merge(&mut self, other: Self) { + self.extend(other.items()); + } + + /// Retains only the elements that satisfy the predicate. + pub fn retain bool>(&mut self, f: F) { + self.heap.retain(|MaxItem(x)| f(x)); + } +} + +impl SizedHeap { + /// Pushes several items onto the heap, maintaining the max size. 
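+    ///
+    /// Like [`SizedHeap::extend`], but consuming a rayon `ParallelIterator`;
+    /// only the `k` smallest items are retained. A hedged sketch
+    /// (illustrative values):
+    ///
+    /// ```rust,ignore
+    /// use rayon::prelude::*;
+    ///
+    /// let mut heap = SizedHeap::new(Some(3));
+    /// heap.par_extend([7, 2, 9, 4, 1].into_par_iter());
+    /// let mut kept: Vec<_> = heap.items().collect();
+    /// kept.sort_unstable();
+    /// assert_eq!(kept, vec![1, 2, 4]);
+    /// ```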
+ pub fn par_extend>(&mut self, items: I) { + for item in items.collect::>() { + self.heap.push(MaxItem(item)); + } + while self.heap.len() > self.k { + self.heap.pop(); + } + } + + /// Parallel version of [`SizedHeap::items`](crate::core::dataset::SizedHeap::items). + #[must_use] + pub fn par_items(self) -> impl ParallelIterator { + self.heap.into_par_iter().map(|MaxItem(x)| x) + } +} + +/// A wrapper struct for implementing `PartialOrd` and `Ord` on a type to use +/// with `SizedHeap`. +struct MaxItem(T); + +impl PartialEq for MaxItem { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +impl Eq for MaxItem {} + +impl PartialOrd for MaxItem { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for MaxItem { + fn cmp(&self, other: &Self) -> core::cmp::Ordering { + self.0.partial_cmp(&other.0).unwrap_or(core::cmp::Ordering::Less) + } +} diff --git a/crates/abd-clam/src/core/metric/absolute_difference.rs b/crates/abd-clam/src/core/metric/absolute_difference.rs new file mode 100644 index 000000000..262c43aee --- /dev/null +++ b/crates/abd-clam/src/core/metric/absolute_difference.rs @@ -0,0 +1,41 @@ +//! The `AbsoluteDifference` metric. + +use distances::Number; + +use super::{Metric, ParMetric}; + +/// The `AbsoluteDifference` metric measures the absolute difference between two +/// values. It is meant to be used with scalars. +pub struct AbsoluteDifference; + +impl Metric for AbsoluteDifference { + fn distance(&self, a: &T, b: &T) -> T { + a.abs_diff(*b) + } + + fn name(&self) -> &str { + "absolute-difference" + } + + fn has_identity(&self) -> bool { + true + } + + fn has_non_negativity(&self) -> bool { + true + } + + fn has_symmetry(&self) -> bool { + true + } + + fn obeys_triangle_inequality(&self) -> bool { + true + } + + fn is_expensive(&self) -> bool { + false + } +} + +impl ParMetric for AbsoluteDifference {} diff --git a/crates/abd-clam/src/core/metric/cosine.rs b/crates/abd-clam/src/core/metric/cosine.rs new file mode 100644 index 000000000..39167f33f --- /dev/null +++ b/crates/abd-clam/src/core/metric/cosine.rs @@ -0,0 +1,40 @@ +//! The Cosine distance function. + +use distances::number::Float; + +use super::{Metric, ParMetric}; + +/// The Cosine distance function. +pub struct Cosine; + +impl, T: Float> Metric for Cosine { + fn distance(&self, a: &I, b: &I) -> T { + distances::vectors::cosine(a.as_ref(), b.as_ref()) + } + + fn name(&self) -> &str { + "cosine" + } + + fn has_identity(&self) -> bool { + true + } + + fn has_non_negativity(&self) -> bool { + true + } + + fn has_symmetry(&self) -> bool { + true + } + + fn obeys_triangle_inequality(&self) -> bool { + true + } + + fn is_expensive(&self) -> bool { + false + } +} + +impl + Send + Sync, U: Float> ParMetric for Cosine {} diff --git a/crates/abd-clam/src/core/metric/euclidean.rs b/crates/abd-clam/src/core/metric/euclidean.rs new file mode 100644 index 000000000..40ba18291 --- /dev/null +++ b/crates/abd-clam/src/core/metric/euclidean.rs @@ -0,0 +1,40 @@ +//! The `Euclidean` distance metric. + +use distances::number::Float; + +use super::{Metric, ParMetric}; + +/// The `Euclidean` distance metric. 
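+///
+/// For vectors `x` and `y`, this computes `sqrt(sum_i (x_i - y_i)^2)` by
+/// delegating to `distances::vectors::euclidean`.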
+pub struct Euclidean; + +impl, T: Float> Metric for Euclidean { + fn distance(&self, a: &I, b: &I) -> T { + distances::vectors::euclidean(a.as_ref(), b.as_ref()) + } + + fn name(&self) -> &str { + "euclidean" + } + + fn has_identity(&self) -> bool { + true + } + + fn has_non_negativity(&self) -> bool { + true + } + + fn has_symmetry(&self) -> bool { + true + } + + fn obeys_triangle_inequality(&self) -> bool { + true + } + + fn is_expensive(&self) -> bool { + false + } +} + +impl + Send + Sync, U: Float> ParMetric for Euclidean {} diff --git a/crates/abd-clam/src/core/metric/hypotenuse.rs b/crates/abd-clam/src/core/metric/hypotenuse.rs new file mode 100644 index 000000000..7631b17b7 --- /dev/null +++ b/crates/abd-clam/src/core/metric/hypotenuse.rs @@ -0,0 +1,43 @@ +//! The `Hypotenuse` metric. + +use distances::{number::Float, Number}; + +use super::{Metric, ParMetric}; + +/// The `Hypotenuse` is just the `Euclidean` distance between two points in 2D +/// space. +pub struct Hypotenuse; + +impl Metric<(T, T), U> for Hypotenuse { + fn distance(&self, a: &(T, T), b: &(T, T)) -> U { + let a = [a.0, a.1]; + let b = [b.0, b.1]; + distances::vectors::euclidean(&a, &b) + } + + fn name(&self) -> &str { + "hypotenuse" + } + + fn has_identity(&self) -> bool { + true + } + + fn has_non_negativity(&self) -> bool { + true + } + + fn has_symmetry(&self) -> bool { + true + } + + fn obeys_triangle_inequality(&self) -> bool { + true + } + + fn is_expensive(&self) -> bool { + false + } +} + +impl ParMetric<(T, T), U> for Hypotenuse {} diff --git a/crates/abd-clam/src/core/metric/levenshtein.rs b/crates/abd-clam/src/core/metric/levenshtein.rs new file mode 100644 index 000000000..71727c4dc --- /dev/null +++ b/crates/abd-clam/src/core/metric/levenshtein.rs @@ -0,0 +1,40 @@ +//! The `Levenshtein` edit distance metric. + +use distances::number::Int; + +use super::{Metric, ParMetric}; + +/// The `Levenshtein` edit distance metric. +pub struct Levenshtein; + +impl, T: Int> Metric for Levenshtein { + fn distance(&self, a: &I, b: &I) -> T { + T::from(stringzilla::sz::edit_distance(a.as_ref(), b.as_ref())) + } + + fn name(&self) -> &str { + "levenshtein" + } + + fn has_identity(&self) -> bool { + true + } + + fn has_non_negativity(&self) -> bool { + true + } + + fn has_symmetry(&self) -> bool { + true + } + + fn obeys_triangle_inequality(&self) -> bool { + true + } + + fn is_expensive(&self) -> bool { + false + } +} + +impl + Send + Sync, U: Int> ParMetric for Levenshtein {} diff --git a/crates/abd-clam/src/core/metric/manhattan.rs b/crates/abd-clam/src/core/metric/manhattan.rs new file mode 100644 index 000000000..a4d981fe4 --- /dev/null +++ b/crates/abd-clam/src/core/metric/manhattan.rs @@ -0,0 +1,43 @@ +//! The `Manhattan` distance metric. + +use distances::Number; + +use super::{Metric, ParMetric}; + +/// The `Manhattan` distance metric, also known as the city block distance. +/// +/// This is a distance metric that measures the distance between two points in a +/// grid based on the sum of the absolute differences of their coordinates. 
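+///
+/// For example, the Manhattan distance between `(1, 2)` and `(4, 6)` is
+/// `|1 - 4| + |2 - 6| = 7`.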
+pub struct Manhattan; + +impl, T: Number> Metric for Manhattan { + fn distance(&self, a: &I, b: &I) -> T { + distances::vectors::manhattan(a.as_ref(), b.as_ref()) + } + + fn name(&self) -> &str { + "manhattan" + } + + fn has_identity(&self) -> bool { + true + } + + fn has_non_negativity(&self) -> bool { + true + } + + fn has_symmetry(&self) -> bool { + true + } + + fn obeys_triangle_inequality(&self) -> bool { + true + } + + fn is_expensive(&self) -> bool { + false + } +} + +impl + Send + Sync, T: Number> ParMetric for Manhattan {} diff --git a/crates/abd-clam/src/core/metric/mod.rs b/crates/abd-clam/src/core/metric/mod.rs new file mode 100644 index 000000000..1befdd53b --- /dev/null +++ b/crates/abd-clam/src/core/metric/mod.rs @@ -0,0 +1,227 @@ +//! The `Metric` trait is used for all distance computations in CLAM. + +use distances::Number; + +mod absolute_difference; +mod cosine; +mod euclidean; +mod hypotenuse; +mod manhattan; + +pub use absolute_difference::AbsoluteDifference; +pub use cosine::Cosine; +pub use euclidean::Euclidean; +pub use hypotenuse::Hypotenuse; +pub use manhattan::Manhattan; + +#[cfg(feature = "msa")] +mod levenshtein; + +#[cfg(feature = "msa")] +pub use levenshtein::Levenshtein; + +/// The `Metric` trait is used for all distance computations in CLAM. +/// +/// # Type Parameters +/// +/// - `I`: The type of the items. +/// - `T`: The type of the distance values. +/// +/// # Example +/// +/// The following is an example of a `Metric` implementation for the Hamming +/// distance between two sequences of bytes. This implementation short-circuits +/// on the length of the shorter sequence. +/// +/// ```rust +/// use abd_clam::metric::{Metric, ParMetric}; +/// +/// struct Hamming; +/// +/// // `I` is any type that can be dereferenced to a slice of bytes. +/// impl> Metric for Hamming { +/// fn distance(&self, a: &I, b: &I) -> usize { +/// // Count the number of positions where the elements are different. +/// a.as_ref().iter().zip(b.as_ref()).filter(|(x, y)| x != y).count() +/// } +/// +/// fn name(&self) -> &str { +/// "hamming" +/// } +/// +/// fn has_identity(&self) -> bool { +/// // Two sequences are identical if all of their elements are equal. +/// true +/// } +/// +/// fn has_non_negativity(&self) -> bool { +/// // The `usize` type is always non-negative. +/// true +/// } +/// +/// fn has_symmetry(&self) -> bool { +/// // The Hamming distance is symmetric. +/// true +/// } +/// +/// fn obeys_triangle_inequality(&self) -> bool { +/// // The Hamming distance satisfies the triangle inequality. +/// true +/// } +/// +/// fn is_expensive(&self) -> bool { +/// // The Hamming distance is not expensive to compute because it is +/// // linear in the length of the sequences. +/// false +/// } +/// } +/// +/// // There is no real benefit to parallelizing the counting of different +/// // elements in the Hamming distance. +/// impl + Send + Sync> ParMetric for Hamming {} +/// +/// // Test the Hamming distance. +/// let a = b"hello"; +/// let b = b"world"; +/// let metric = Hamming; +/// +/// assert_eq!(metric.distance(&a, &b), 4); +/// assert_eq!(metric.par_distance(&a, &b), 4); +/// ``` +pub trait Metric { + /// Call the metric on two items. + fn distance(&self, a: &I, b: &I) -> T; + + /// The name of the metric. + fn name(&self) -> &str; + + /// Whether the metric provides an identity among the items. + /// + /// Identity is defined as `d(a, b) = 0` if and only if `a = b`. + /// + /// This is used when computing the diagonal of a pairwise distance matrix. 
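+    ///
+    /// When this returns `false`, `pairwise` computes the diagonal explicitly
+    /// instead of assuming it is all zeros.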
+ fn has_identity(&self) -> bool; + + /// Whether the metric only produces non-negative values. + /// + /// Non-negativity is defined as `d(a, b) >= 0` for all items `a` and `b`. + /// + /// This is the most important property of metrics for use in CLAM. + fn has_non_negativity(&self) -> bool; + + /// Whether the metric is symmetric. + /// + /// Symmetry is defined as `d(a, b) = d(b, a)` for all items `a` and `b`. + /// + /// This is used when computing the lower triangle of a pairwise distance + /// matrix. + fn has_symmetry(&self) -> bool; + + /// Whether the metric satisfies the triangle inequality. + /// + /// The triangle inequality is defined as `d(a, b) + d(b, c) >= d(a, c)` for + /// all items `a`, `b`, and `c`. + /// + /// If the distance function satisfies the triangle inequality, then the + /// search results from CAKES will have perfect recall. + fn obeys_triangle_inequality(&self) -> bool; + + /// Whether the metric is expensive to compute. + /// + /// We say that a metric is expensive if it costs more than linear time in + /// the size of the items to compute the distance between two items. + /// + /// When using expensive metrics, we use slightly different parallelism in + /// CLAM. + fn is_expensive(&self) -> bool; + + /// Whether an item is equal to another item. Items can only be equal if the + /// metric provides an identity. + /// + /// This is a convenience function that checks if the distance between two + /// items is zero. + fn is_equal(&self, a: &I, b: &I) -> bool { + self.has_identity() && self.distance(a, b) == T::ZERO + } +} + +/// Parallel version of [`Metric`](crate::core::metric::Metric). +#[allow(clippy::module_name_repetitions)] +pub trait ParMetric: Metric + Send + Sync { + /// Parallel version of [`Metric::distance`](crate::core::metric::Metric::distance). + /// + /// The default implementation calls the non-parallel version of the + /// distance function. + /// + /// This may be used when the distance function itself can be computed with + /// some parallelism. 
+ fn par_distance(&self, a: &I, b: &I) -> T { + self.distance(a, b) + } +} + +impl Metric for Box> { + fn distance(&self, a: &I, b: &I) -> T { + (**self).distance(a, b) + } + + fn name(&self) -> &str { + (**self).name() + } + + fn has_identity(&self) -> bool { + (**self).has_identity() + } + + fn has_non_negativity(&self) -> bool { + (**self).has_non_negativity() + } + + fn has_symmetry(&self) -> bool { + (**self).has_symmetry() + } + + fn obeys_triangle_inequality(&self) -> bool { + (**self).obeys_triangle_inequality() + } + + fn is_expensive(&self) -> bool { + (**self).is_expensive() + } +} + +impl Metric for Box> { + fn distance(&self, a: &I, b: &I) -> T { + (**self).distance(a, b) + } + + fn name(&self) -> &str { + (**self).name() + } + + fn has_identity(&self) -> bool { + (**self).has_identity() + } + + fn has_non_negativity(&self) -> bool { + (**self).has_non_negativity() + } + + fn has_symmetry(&self) -> bool { + (**self).has_symmetry() + } + + fn obeys_triangle_inequality(&self) -> bool { + (**self).obeys_triangle_inequality() + } + + fn is_expensive(&self) -> bool { + (**self).is_expensive() + } +} + +impl ParMetric for Box> { + fn par_distance(&self, a: &I, b: &I) -> T { + (**self).par_distance(a, b) + } +} diff --git a/crates/abd-clam/src/core/mod.rs b/crates/abd-clam/src/core/mod.rs index 077c268cc..24df70057 100644 --- a/crates/abd-clam/src/core/mod.rs +++ b/crates/abd-clam/src/core/mod.rs @@ -2,6 +2,10 @@ pub mod cluster; pub mod dataset; +pub mod metric; +mod tree; -pub use cluster::{adapter, partition, BalancedBall, Ball, Cluster, Partition, LFD}; -pub use dataset::{Dataset, FlatVec, Metric, MetricSpace, Permutable}; +pub use cluster::{Ball, Cluster, LFD}; +pub use dataset::{Dataset, FlatVec, SizedHeap}; +pub use metric::Metric; +pub use tree::Tree; diff --git a/crates/abd-clam/src/core/tree/mod.rs b/crates/abd-clam/src/core/tree/mod.rs new file mode 100644 index 000000000..4af029715 --- /dev/null +++ b/crates/abd-clam/src/core/tree/mod.rs @@ -0,0 +1,259 @@ +//! A `Tree` of `Cluster`s. + +use distances::Number; +use rayon::prelude::*; + +use super::{ + cluster::{ParPartition, Partition}, + dataset::ParDataset, + metric::ParMetric, + Cluster, Dataset, Metric, +}; + +/// A `Tree` of `Cluster`s. +pub struct Tree> { + /// The `Cluster`s of the `Tree` are stored in `levels` where the first + /// `Vec` contains the `Cluster`s at the first level, the second `Vec` + /// contains the `Cluster`s at the second level, and so on. + /// + /// Each `Cluster` is represented by a tuple `(node, a, b, c, d)` where + /// `node` is the `Cluster` and `a` and `b` define the range of indices of + /// the children of the `Cluster` in the next level. + levels: Vec>, + /// The diameter of the root `Cluster` in the `Tree`. + diameter: T, +} + +impl> From for Tree { + fn from(mut c: C) -> Self { + let diameter = c.radius().double(); + if c.is_leaf() { + Self { + levels: vec![vec![(c, 0, 0)]], + diameter, + } + } else { + let mut child_trees = c + .take_children() + .into_iter() + .map(|child| Self::from(*child)) + .collect::>(); + let n_children = child_trees.len(); + + let rest = child_trees.split_off(1); + let first = child_trees + .pop() + .unwrap_or_else(|| unreachable!("We checked that the `Cluster` is not a leaf.")); + let mut tree = rest.into_iter().fold(first, Self::merge); + + tree.levels.insert(0, vec![(c, 0, n_children)]); + tree.diameter = diameter; + + tree + } + } +} + +impl> Tree { + /// Returns the root `Cluster` of the `Tree`. 
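+    ///
+    /// The root is the single entry of the first level; the accompanying
+    /// `(a, b)` pair is the index range of its children in the second level.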
+    pub fn root(&self) -> (&C, usize, usize) {
+        self.levels
+            .first()
+            .and_then(|level| level.first().map(|(node, a, b)| (node, *a, *b)))
+            .unwrap_or_else(|| unreachable!("The `Tree` is empty."))
+    }
+
+    /// Returns the `Cluster` at the given `depth` and `index`.
+    pub fn get(&self, depth: usize, index: usize) -> Option<&C> {
+        self.levels
+            .get(depth)
+            .and_then(|level| level.get(index).map(|(node, _, _)| node))
+    }
+
+    /// Finds the `Cluster` in the `Tree` that is equal to the given `Cluster`.
+    ///
+    /// If the `Cluster` is found, the method returns:
+    ///
+    /// - The depth of the `Cluster` in the `Tree`.
+    /// - The index of the `Cluster` in its level of the `Tree`.
+    /// - The start index of the children of the `Cluster` in the next level.
+    /// - The end index of the children of the `Cluster` in the next level.
+    pub fn find(&self, c: &C) -> Option<(usize, usize, usize, usize)> {
+        self.levels
+            .get(c.depth())
+            .and_then(|level| {
+                let pos = level.iter().position(|(node, _, _)| node == c);
+                pos.map(|index| (index, &level[index]))
+            })
+            .map(|(index, (_, a, b))| (c.depth(), index, *a, *b))
+    }
+
+    /// Returns the children of the `Cluster` at the given `depth` and
+    /// `index`.
+    pub fn children_of(&self, depth: usize, index: usize) -> Vec<(&C, usize, usize)> {
+        self.levels
+            .get(depth)
+            .and_then(|level| {
+                level.get(index).and_then(|&(_, a, b)| {
+                    self.levels
+                        .get(depth + 1)
+                        .map(|level| level[a..b].iter().map(|(node, a, b)| (node, *a, *b)).collect())
+                })
+            })
+            .unwrap_or_default()
+    }
+
+    /// Returns the diameter of the `Tree`.
+    pub const fn diameter(&self) -> T {
+        self.diameter
+    }
+
+    /// Returns the levels of the `Tree`.
+    pub fn levels(&self) -> &[Vec<(C, usize, usize)>] {
+        &self.levels
+    }
+
+    /// Returns the clusters in the tree in breadth-first order.
+    pub fn bft(&self) -> impl Iterator<Item = &C> {
+        self.levels
+            .iter()
+            .flat_map(|level| level.iter().map(|(node, _, _)| node))
+    }
+
+    /// Merges the `Tree` with another `Tree`, consuming both `Tree`s and
+    /// keeping the diameter of the first `Tree`.
+    fn merge(mut self, mut other: Self) -> Self {
+        let mut s_levels = core::mem::take(&mut self.levels);
+        let mut o_levels = core::mem::take(&mut other.levels);
+
+        let zipped_levels = s_levels.iter_mut().skip(1).zip(o_levels.iter_mut()).collect::<Vec<_>>();
+        for (s_level, o_level) in zipped_levels.into_iter().rev() {
+            let n = s_level.len();
+            for (_, a, b) in o_level {
+                *a += n;
+                *b += n;
+            }
+        }
+
+        let remaining_levels = if s_levels.len() < o_levels.len() {
+            o_levels.split_off(s_levels.len())
+        } else {
+            s_levels.split_off(o_levels.len())
+        };
+        self.levels = s_levels
+            .into_iter()
+            .zip(o_levels)
+            .map(|(mut s_level, mut o_level)| {
+                s_level.append(&mut o_level);
+                s_level
+            })
+            .chain(remaining_levels)
+            .collect();
+
+        self
+    }
+}
+
+impl<T: Number, C: Partition<T>> Tree<T, C> {
+    /// Constructs a `Tree` from the given `data` and `metric`.
+    ///
+    /// # Errors
+    ///
+    /// Any error from `C::new` is propagated.
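+    ///
+    /// A hedged sketch of expected usage (illustrative names and types):
+    ///
+    /// ```rust,ignore
+    /// // Partition until clusters are singletons.
+    /// let criteria = |c: &Ball<f32>| c.cardinality() > 1;
+    /// let tree = Tree::new(&data, &metric, &criteria, Some(42))?;
+    /// ```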
+ pub fn new, M: Metric, F: Fn(&C) -> bool>( + data: &D, + metric: &M, + criteria: &F, + seed: Option, + ) -> Result { + let indices = (0..data.cardinality()).collect::>(); + let root = C::new(data, metric, &indices, 0, seed)?; + let diameter = root.radius().double(); + + let mut levels = vec![vec![(root, 0, 0)]]; + loop { + let mut last_level = levels + .last_mut() + .unwrap_or_else(|| unreachable!("We inserted the root.")) + .iter_mut() + .filter_map(|(node, a, b)| if criteria(node) { Some((node, a, b)) } else { None }) + .collect::>(); + + let children = last_level + .iter_mut() + .map(|(node, _, _)| { + node.partition_once(data, metric, seed) + .into_iter() + .map(|child| *child) + .collect::>() + }) + .collect::>(); + + if children.is_empty() { + break; + } + + let mut next_level = Vec::new(); + for ((_, a, b), children) in last_level.into_iter().zip(children) { + *a = next_level.len(); + *b = next_level.len() + children.len(); + next_level.extend(children.into_iter().map(|child| (child, 0, 0))); + } + levels.push(next_level); + } + + Ok(Self { levels, diameter }) + } +} + +impl> Tree { + /// Parallel version of [`Tree::new`](crate::core::tree::Tree::new). + /// + /// # Errors + /// + /// Any error from `C::par_new` is propagated. + pub fn par_new, M: ParMetric, F: (Fn(&C) -> bool) + Send + Sync>( + data: &D, + metric: &M, + criteria: &F, + seed: Option, + ) -> Result { + let indices = (0..data.cardinality()).collect::>(); + let root = C::par_new(data, metric, &indices, 0, seed)?; + let diameter = root.radius().double(); + + let mut levels = vec![vec![(root, 0, 0)]]; + loop { + let mut last_level = levels + .last_mut() + .unwrap_or_else(|| unreachable!("We inserted the root.")) + .into_par_iter() + .filter_map(|(node, a, b)| if criteria(node) { Some((node, a, b)) } else { None }) + .collect::>(); + + let children = last_level + .par_iter_mut() + .map(|(node, _, _)| { + node.par_partition_once(data, metric, seed) + .into_iter() + .map(|child| *child) + .collect::>() + }) + .collect::>(); + + if children.is_empty() { + break; + } + + let mut next_level = Vec::new(); + for ((_, a, b), children) in last_level.into_iter().zip(children) { + *a = next_level.len(); + *b = next_level.len() + children.len(); + next_level.extend(children.into_iter().map(|child| (child, 0, 0))); + } + levels.push(next_level); + } + + Ok(Self { levels, diameter }) + } +} diff --git a/crates/abd-clam/src/lib.rs b/crates/abd-clam/src/lib.rs index e9c47047f..f037e5cd5 100644 --- a/crates/abd-clam/src/lib.rs +++ b/crates/abd-clam/src/lib.rs @@ -18,16 +18,19 @@ pub mod cakes; mod core; -pub mod mbed; +pub mod pancakes; pub mod utils; -pub use crate::core::{ - adapter, cluster, dataset, partition, BalancedBall, Ball, Cluster, Dataset, FlatVec, Metric, MetricSpace, Partition, - Permutable, LFD, -}; +pub use core::{cluster, dataset, metric, Ball, Cluster, Dataset, FlatVec, Metric, SizedHeap, Tree, LFD}; #[cfg(feature = "chaoda")] pub mod chaoda; +#[cfg(feature = "mbed")] +pub mod mbed; + +#[cfg(feature = "msa")] +pub mod msa; + /// The current version of the crate. 
-pub const VERSION: &str = "0.31.0"; +pub const VERSION: &str = "0.32.0"; diff --git a/crates/abd-clam/src/mbed/physics/mass.rs b/crates/abd-clam/src/mbed/physics/mass.rs index eb1978826..45fd87f04 100644 --- a/crates/abd-clam/src/mbed/physics/mass.rs +++ b/crates/abd-clam/src/mbed/physics/mass.rs @@ -3,7 +3,7 @@ use distances::Number; use rand::prelude::*; -use crate::{adapter::Adapter, cakes::OffBall, chaoda::Vertex, Cluster, Dataset}; +use crate::{cakes::PermutedBall, chaoda::Vertex, Cluster}; /// A `Mass` in the mass-spring system for dimensionality reduction. /// @@ -77,13 +77,12 @@ impl Mass { /// This assigns the `position` and `velocity` of the `Mass` to be the zero /// vector, and the `mass` to be the cardinality of the `Cluster`. #[must_use] - pub fn from_vertex(c: &Vertex>) -> Self + pub fn from_vertex(c: &Vertex>) -> Self where - U: Number, - D: Dataset, - C: Cluster, + T: Number, + C: Cluster, { - Self::new(c.arg_center(), c.source().offset(), c.cardinality()) + Self::new(c.arg_center(), c.source.offset(), c.cardinality()) } /// Returns a hash-key for the `Mass`. diff --git a/crates/abd-clam/src/mbed/physics/spring.rs b/crates/abd-clam/src/mbed/physics/spring.rs index 750f76aad..9da8dadb6 100644 --- a/crates/abd-clam/src/mbed/physics/spring.rs +++ b/crates/abd-clam/src/mbed/physics/spring.rs @@ -32,7 +32,7 @@ pub struct Spring<'a, const DIM: usize> { impl<'a, const DIM: usize> Spring<'a, DIM> { /// Create a new `Spring`. - pub fn new(a: &'a Mass, b: &'a Mass, k: f32, l0: U) -> Self { + pub fn new(a: &'a Mass, b: &'a Mass, k: f32, l0: T) -> Self { let mut s = Self { a, b, @@ -99,16 +99,16 @@ impl<'a, const DIM: usize> Spring<'a, DIM> { } } -impl<'a, const DIM: usize> core::hash::Hash for Spring<'a, DIM> { +impl core::hash::Hash for Spring<'_, DIM> { fn hash(&self, state: &mut H) { self.hash_key().hash(state); } } -impl<'a, const DIM: usize> PartialEq for Spring<'a, DIM> { +impl PartialEq for Spring<'_, DIM> { fn eq(&self, other: &Self) -> bool { self.hash_key() == other.hash_key() } } -impl<'a, const DIM: usize> Eq for Spring<'a, DIM> {} +impl Eq for Spring<'_, DIM> {} diff --git a/crates/abd-clam/src/mbed/physics/system.rs b/crates/abd-clam/src/mbed/physics/system.rs index 4b2b543bb..c2709564d 100644 --- a/crates/abd-clam/src/mbed/physics/system.rs +++ b/crates/abd-clam/src/mbed/physics/system.rs @@ -7,10 +7,9 @@ use rand::prelude::*; use rayon::prelude::*; use crate::{ - adapter::Adapter, - cakes::OffBall, + cakes::PermutedBall, chaoda::{Graph, Vertex}, - Cluster, Dataset, FlatVec, Metric, + Cluster, FlatVec, }; use super::{Mass, Spring}; @@ -31,13 +30,12 @@ pub struct System<'a, const DIM: usize> { } /// Get the hash key of a `Vertex` for use in the `System`. -fn c_hash_key(c: &Vertex>) -> (usize, usize) +fn c_hash_key(c: &Vertex>) -> (usize, usize) where - U: Number, - D: Dataset, - C: Cluster, + T: Number, + C: Cluster, { - (c.source().offset(), c.cardinality()) + (c.source.offset(), c.cardinality()) } impl<'a, const DIM: usize> System<'a, DIM> { @@ -58,11 +56,10 @@ impl<'a, const DIM: usize> System<'a, DIM> { /// - `D`: The dataset. /// - `C`: The type of the `Cluster`s in the `Graph`. #[must_use] - pub fn from_graph(g: &Graph>, beta: f32) -> Self + pub fn from_graph(g: &Graph>, beta: f32) -> Self where - U: Number, - D: Dataset, - C: Cluster, + T: Number, + C: Cluster, { let masses = g .iter_clusters() @@ -78,11 +75,10 @@ impl<'a, const DIM: usize> System<'a, DIM> { } /// Resets the `System`'s `Springs` to match the `Graph`. 
- pub fn reset_springs(&'a mut self, g: &Graph>, k: f32) + pub fn reset_springs(&'a mut self, g: &Graph>, k: f32) where - U: Number, - D: Dataset, - C: Cluster, + T: Number, + C: Cluster, { self.springs = g .iter_edges() @@ -236,7 +232,7 @@ impl<'a, const DIM: usize> System<'a, DIM> { /// represented by the `Mass`es. /// - The distance function is the Euclidean distance. #[must_use] - pub fn get_reduced_embedding(&self) -> FlatVec, f32, usize> { + pub fn get_reduced_embedding(&self) -> FlatVec, usize> { let masses = { let mut masses = self.masses.iter().collect::>(); masses.sort_by(|&(a, _), &(b, _)| a.cmp(b)); @@ -248,10 +244,10 @@ impl<'a, const DIM: usize> System<'a, DIM> { .flat_map(|&(&(_, c), m)| (0..c).map(move |_| m.position().to_vec()).collect::>()) .collect::>(); - let distance_fn = |x: &Vec, y: &Vec| distances::simd::euclidean_f32(x, y); - let metric = Metric::new(distance_fn, false); + // let distance_fn = |x: &Vec, y: &Vec| distances::simd::euclidean_f32(x, y); + // let metric = Metric::new(distance_fn, false); - FlatVec::new(positions, metric).unwrap_or_else(|e| unreachable!("Error creating FlatVec: {e}")) + FlatVec::new(positions).unwrap_or_else(|e| unreachable!("Error creating FlatVec: {e}")) } /// Simulate the `System` until reaching the target stability, saving the @@ -270,7 +266,7 @@ impl<'a, const DIM: usize> System<'a, DIM> { /// # Errors /// /// If there is an error writing the embeddings to disk. - #[cfg(feature = "ndarray-bindings")] + #[cfg(feature = "disk-io")] #[allow(clippy::too_many_arguments)] pub fn evolve_to_stability_with_saves>( &mut self, @@ -291,8 +287,8 @@ impl<'a, const DIM: usize> System<'a, DIM> { while stability < target && i < max_steps { if i % save_every == 0 { ftlog::debug!("{name}: Saving step {}", i + 1); - self.get_reduced_embedding() - .write_npy(&dir, &format!("{}.npy", i + 1))?; + let path = dir.as_ref().join(format!("{}.npy", i + 1)); + self.get_reduced_embedding().write_npy(&path)?; } ftlog::debug!("Step {i}, Stability: {stability:.6}"); @@ -318,7 +314,7 @@ impl<'a, const DIM: usize> System<'a, DIM> { /// # Errors /// /// If there is an error writing the embeddings to disk. - #[cfg(feature = "ndarray-bindings")] + #[cfg(feature = "disk-io")] pub fn evolve_with_saves>( &mut self, dt: f32, @@ -332,8 +328,8 @@ impl<'a, const DIM: usize> System<'a, DIM> { for i in 0..steps { if i % save_every == 0 { ftlog::debug!("{name}: Saving step {}/{steps}", i + 1); - self.get_reduced_embedding() - .write_npy(&dir, &format!("{}.npy", i + 1))?; + let path = dir.as_ref().join(format!("{}.npy", i + 1)); + self.get_reduced_embedding().write_npy(&path)?; } self.update_step(dt); } diff --git a/crates/abd-clam/src/msa/aligner/cost_matrix.rs b/crates/abd-clam/src/msa/aligner/cost_matrix.rs new file mode 100644 index 000000000..d6f3ea606 --- /dev/null +++ b/crates/abd-clam/src/msa/aligner/cost_matrix.rs @@ -0,0 +1,357 @@ +//! Substitution matrix for the Needleman-Wunsch aligner. + +use core::ops::Neg; + +use std::collections::HashSet; + +use distances::{number::Int, Number}; + +use super::super::NUM_CHARS; + +/// A substitution matrix for the Needleman-Wunsch aligner. +#[derive(Clone, Debug)] +#[cfg_attr( + feature = "disk-io", + derive(bitcode::Encode, bitcode::Decode, serde::Serialize, serde::Deserialize) +)] +pub struct CostMatrix { + /// The cost of substituting one character for another. + sub_matrix: Vec>, + /// The cost to open a gap. + gap_open: T, + /// The cost to extend a gap. 
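+    ///
+    /// With these two costs, a run of `k` gap characters costs
+    /// `gap_open + (k - 1) * gap_ext`: only the first character in the run
+    /// pays the opening cost.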
+ gap_ext: T, +} + +impl Default for CostMatrix { + fn default() -> Self { + Self::new(T::ONE, T::ONE, T::ONE) + } +} + +impl CostMatrix { + /// Create a new substitution matrix. + #[must_use] + pub fn new(default_sub: T, gap_open: T, gap_ext: T) -> Self { + // Initialize the substitution matrix. + let mut sub_matrix = [[default_sub; NUM_CHARS]; NUM_CHARS]; + + // Set the diagonal to zero. + sub_matrix.iter_mut().enumerate().for_each(|(i, row)| row[i] = T::ZERO); + + Self { + sub_matrix: sub_matrix.iter().map(|row| row.to_vec()).collect(), + gap_open, + gap_ext, + } + } + + /// Create a new substitution matrix with affine gap penalties. + /// + /// All substitution costs are set to 1. + #[must_use] + pub fn default_affine(gap_open: Option) -> Self { + let gap_open = gap_open.map_or_else(|| T::from(10), T::from); + Self::new(T::ONE, gap_open, T::ONE) + } + + /// Add a constant to all substitution costs. + #[must_use] + pub fn shift(mut self, shift: T) -> Self { + for i in 0..NUM_CHARS { + for j in 0..NUM_CHARS { + self.sub_matrix[i][j] += shift; + } + } + self + } + + /// Multiply all substitution costs by a constant. + #[must_use] + pub fn scale(mut self, scale: T) -> Self { + for i in 0..NUM_CHARS { + for j in 0..NUM_CHARS { + self.sub_matrix[i][j] *= scale; + } + } + self + } + + /// Set the cost of substituting one character for another. + /// + /// # Arguments + /// + /// * `a`: The old character to be substituted. + /// * `b`: The new character to substitute with. + /// * `cost`: The cost of the substitution. + #[must_use] + pub fn with_sub_cost(mut self, a: u8, b: u8, cost: T) -> Self { + self.sub_matrix[a as usize][b as usize] = cost; + self + } + + /// Set the cost of opening a gap. + #[must_use] + pub const fn with_gap_open(mut self, cost: T) -> Self { + self.gap_open = cost; + self + } + + /// Set the cost of extending a gap. + #[must_use] + pub const fn with_gap_ext(mut self, cost: T) -> Self { + self.gap_ext = cost; + self + } + + /// Get the cost of substituting one character for another. + /// + /// # Arguments + /// + /// * `a`: The old character to be substituted. + /// * `b`: The new character to substitute with. + pub fn sub_cost(&self, a: u8, b: u8) -> T { + self.sub_matrix[a as usize][b as usize] + } + + /// Get the cost of opening a gap. + pub const fn gap_open_cost(&self) -> T { + self.gap_open + } + + /// Get the cost of extending a gap. + pub const fn gap_ext_cost(&self) -> T { + self.gap_ext + } +} + +impl> CostMatrix { + /// Linearly increase all costs in the matrix so that the minimum cost is + /// zero and all non-zero costs are positive. + #[must_use] + pub fn normalize(self) -> Self { + let shift = self + .sub_matrix + .iter() + .flatten() + .fold(T::MAX, |a, &b| if a < b { a } else { b }); + + self.shift(-shift) + } +} + +impl> Neg for CostMatrix { + type Output = Self; + + fn neg(mut self) -> Self::Output { + for i in 0..NUM_CHARS { + for j in 0..NUM_CHARS { + self.sub_matrix[i][j] = -self.sub_matrix[i][j]; + } + } + self.gap_open = -self.gap_open; + self.gap_ext = -self.gap_ext; + self + } +} + +impl> CostMatrix { + /// A substitution matrix for the Needleman-Wunsch aligner using the + /// extendedIUPAC alphabet for nucleotides. + /// + /// See [here](https://www.bioinformatics.org/sms/iupac.html) for an + /// explanation of the IUPAC codes. + /// + /// # Arguments + /// + /// * `gap_open`: The factor by which it is more expensive to open a gap + /// than to extend an existing gap. This defaults to 10. 
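+    ///
+    /// As a worked example of the cost rule below: `A` against `R` (which
+    /// codes for `A` or `G`) has 1 matching pair out of 2 possible, so its
+    /// cost is `1 - 1/2 = 1/2` before the costs are scaled up to integers.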
+ pub fn extended_iupac(gap_open: Option) -> Self { + let gap_open = gap_open.unwrap_or(10); + + // For each pair of IUPAC characters, the cost is 1 - n / m, where m is + // the number possible pairs of nucleotides that can be represented by + // the IUPAC characters, and n is the number of matching pairs. + #[rustfmt::skip] + let costs = vec![ + ('A', 'R', 1, 2), ('C', 'Y', 1, 2), ('G', 'R', 1, 2), ('T', 'Y', 1, 2), + ('A', 'W', 1, 2), ('C', 'S', 1, 2), ('G', 'S', 1, 2), ('T', 'W', 1, 2), + ('A', 'M', 1, 2), ('C', 'M', 1, 2), ('G', 'K', 1, 2), ('T', 'K', 1, 2), + ('A', 'D', 1, 3), ('C', 'B', 1, 3), ('G', 'B', 1, 3), ('T', 'B', 1, 3), + ('A', 'H', 1, 3), ('C', 'H', 1, 3), ('G', 'D', 1, 3), ('T', 'D', 1, 3), + ('A', 'V', 1, 3), ('C', 'V', 1, 3), ('G', 'V', 1, 3), ('T', 'H', 1, 3), + ('A', 'N', 1, 4), ('C', 'N', 1, 4), ('G', 'N', 1, 4), ('T', 'N', 1, 4), + + ('R', 'A', 1, 2), ('Y', 'C', 1, 2), ('S', 'G', 1, 2), ('W', 'A', 1, 2), ('K', 'G', 1, 2), ('M', 'A', 1, 2), + ('R', 'G', 1, 2), ('Y', 'T', 1, 2), ('S', 'C', 1, 2), ('W', 'T', 1, 2), ('K', 'T', 1, 2), ('M', 'C', 1, 2), + ('R', 'S', 1, 4), ('Y', 'S', 1, 4), ('S', 'R', 1, 4), ('W', 'R', 1, 4), ('K', 'R', 1, 4), ('M', 'R', 1, 4), + ('R', 'W', 1, 4), ('Y', 'W', 1, 4), ('S', 'Y', 1, 4), ('W', 'Y', 1, 4), ('K', 'Y', 1, 4), ('M', 'Y', 1, 4), + ('R', 'K', 1, 4), ('Y', 'K', 1, 4), ('S', 'K', 1, 4), ('W', 'K', 1, 4), ('K', 'S', 1, 4), ('M', 'S', 1, 4), + ('R', 'M', 1, 4), ('Y', 'M', 1, 4), ('S', 'M', 1, 4), ('W', 'M', 1, 4), ('K', 'W', 1, 4), ('M', 'W', 1, 4), + ('R', 'B', 1, 6), ('Y', 'B', 2, 6), ('S', 'B', 2, 6), ('W', 'B', 1, 6), ('K', 'B', 2, 6), ('M', 'B', 1, 6), + ('R', 'D', 2, 6), ('Y', 'D', 1, 6), ('S', 'D', 1, 6), ('W', 'D', 2, 6), ('K', 'D', 2, 6), ('M', 'D', 1, 6), + ('R', 'H', 1, 6), ('Y', 'H', 2, 6), ('S', 'H', 1, 6), ('W', 'H', 2, 6), ('K', 'H', 1, 6), ('M', 'H', 2, 6), + ('R', 'V', 2, 6), ('Y', 'V', 1, 6), ('S', 'V', 2, 6), ('W', 'V', 1, 6), ('K', 'V', 1, 6), ('M', 'V', 2, 6), + ('R', 'N', 2, 8), ('Y', 'N', 2, 8), ('S', 'N', 2, 8), ('W', 'N', 2, 8), ('K', 'N', 1, 8), ('M', 'N', 2, 8), + ('R', 'R', 1, 2), ('Y', 'Y', 1, 2), ('S', 'S', 1, 2), ('W', 'W', 1, 2), ('K', 'K', 1, 2), ('M', 'M', 1, 2), + + ('B', 'C', 1, 3), ('D', 'A', 1, 3), ('H', 'A', 1, 3), ('V', 'A', 1, 3), ('N', 'A', 1, 4), + ('B', 'G', 1, 3), ('D', 'G', 1, 3), ('H', 'C', 1, 3), ('V', 'C', 1, 3), ('N', 'C', 1, 4), + ('B', 'T', 1, 3), ('D', 'T', 1, 3), ('H', 'T', 1, 3), ('V', 'D', 1, 3), ('N', 'G', 1, 4), + ('B', 'R', 1, 6), ('D', 'R', 2, 6), ('H', 'R', 1, 6), ('V', 'R', 2, 6), ('N', 'T', 1, 4), + ('B', 'Y', 2, 6), ('D', 'Y', 1, 6), ('H', 'Y', 2, 6), ('V', 'Y', 1, 6), ('N', 'R', 1, 4), + ('B', 'S', 2, 6), ('D', 'S', 1, 6), ('H', 'S', 1, 6), ('V', 'S', 2, 6), ('N', 'Y', 1, 4), + ('B', 'W', 1, 6), ('D', 'W', 2, 6), ('H', 'W', 2, 6), ('V', 'W', 1, 6), ('N', 'S', 1, 4), + ('B', 'K', 2, 6), ('D', 'K', 2, 6), ('H', 'K', 1, 6), ('V', 'K', 1, 6), ('N', 'W', 1, 4), + ('B', 'M', 1, 6), ('D', 'M', 1, 6), ('H', 'M', 2, 6), ('V', 'M', 2, 6), ('N', 'K', 1, 4), + ('B', 'D', 2, 9), ('D', 'B', 2, 9), ('H', 'B', 2, 9), ('V', 'B', 2, 9), ('N', 'M', 1, 4), + ('B', 'H', 2, 9), ('D', 'H', 2, 9), ('H', 'D', 2, 9), ('V', 'D', 2, 9), ('N', 'B', 1, 4), + ('B', 'V', 2, 9), ('D', 'V', 2, 9), ('H', 'V', 2, 9), ('V', 'H', 2, 9), ('N', 'D', 1, 4), + ('B', 'N', 3, 12), ('D', 'N', 3, 12), ('H', 'N', 3, 12), ('V', 'N', 3, 12), ('N', 'H', 1, 4), + ('B', 'B', 1, 3), ('D', 'D', 1, 3), ('H', 'H', 1, 3), ('V', 'V', 1, 3), ('N', 'V', 1, 4), + ('N', 'N', 1, 4), + ]; + + // Calculate the least common multiple of 
the denominators so we can + // scale the costs to integers. + let lcm = costs + .iter() + .map(|&(_, _, _, m)| m) + .collect::>() + .into_iter() + .fold(1, |a, b| a.lcm(&b)); + + // T and U are interchangeable. + let t_to_u = costs + .iter() + .filter(|&&(a, _, _, _)| a == 'T') + .map(|&(_, b, n, m)| ('U', b, n, m)) + .chain( + costs + .iter() + .filter(|&&(_, b, _, _)| b == 'T') + .map(|&(a, _, n, m)| (a, 'U', n, m)), + ) + .collect::>(); + + // The initial matrix with the default costs, except for gaps which are + // interchangeable. + let matrix = Self::default() + .with_sub_cost(b'-', b'.', T::ZERO) + .with_sub_cost(b'.', b'-', T::ZERO) + .scale(T::from(lcm)); + + // Add all costs to the matrix. + costs + .into_iter() + .chain(t_to_u) + // Scale the costs to integers. + .map(|(a, b, n, m)| (a, b, T::from(n * (lcm / m)))) + .flat_map(|(a, b, cost)| { + // Add the costs for the upper and lower case versions of the + // characters. + [ + (a, b, cost), + (a.to_ascii_lowercase(), b, cost), + (a, b.to_ascii_lowercase(), cost), + (a.to_ascii_lowercase(), b.to_ascii_lowercase(), cost), + ] + }) + // Cast the characters to bytes. + .map(|(a, b, cost)| (a as u8, b as u8, cost)) + // Add the costs to the substitution matrix. + .fold(matrix, |matrix, (a, b, cost)| matrix.with_sub_cost(a, b, cost)) + // Add affine gap penalties. + .with_gap_open(T::from(lcm * gap_open)) + .with_gap_ext(T::from(lcm)) + } + + /// The BLOSUM62 substitution matrix for proteins. + /// + /// See [here](https://en.wikipedia.org/wiki/BLOSUM) for more information. + /// + /// # Arguments + /// + /// * `gap_open`: The factor by which it is more expensive to open a gap + /// than to extend an existing gap. This defaults to 10. + #[must_use] + pub fn blosum62(gap_open: Option) -> Self { + let gap_open = gap_open.unwrap_or(10); + + #[rustfmt::skip] + let costs = [ + vec![ 9], // C + vec![-1, 4], // S + vec![-1, 1, 5], // T + vec![ 0, 1, 0, 4], // A + vec![-3, 0, -2, 0, 6], // G + vec![-3, -1, -1, -1, -2, 7], // P + vec![-3, 0, -1, -2, -1, -1, 6], // D + vec![-4, 0, -1, -1, -2, -1, 2, 5], // E + vec![-3, 0, -1, -1, -2, -1, 0, 2, 5], // Q + vec![-3, 1, 0, -2, 0, -2, 1, 0, 0, 6], // N + vec![-3, -1, -2, -2, -2, -2, 1, 0, 0, 1, 8], // H + vec![-3, -1, -1, -1, -2, -2, -2, 0, 1, 0, 0, 5], // R + vec![-3, 0, -1, -1, -2, -1, -1, 1, 1, 0, -1, 2, 5], // K + vec![-1, -1, -1, -1, -3, -2, -3, -2, 0, -2, -2, -1, -1, 5], // M + vec![-1, -2, -1, -1, -4, -3, -3, -3, -3, -3, -3, -3, -3, 1, 4], // I + vec![-1, -2, -1, -1, -4, -3, -4, -3, -2, -3, -3, -2, -2, 2, 2, 4], // L + vec![-1, -2, 0, 0, -3, -2, -3, -2, -2, -3, -3, -3, -2, 1, 3, 1, 4], // V + vec![-2, -3, -2, -3, -2, -4, -4, -3, -2, -4, -2, -3, -3, -1, -3, -2, -3, 11], // W + vec![-2, -2, -2, -2, -3, -3, -3, -2, -1, -2, 2, -2, -2, -1, -1, -1, -1, 2, 7], // Y + vec![-2, -2, -2, -2, -3, -4, -3, -3, -3, -3, -1, -3, -3, 0, 0, 0, -1, 1, 3, 6], // F + ]; + + // Calculate the maximum difference between any two substitution costs. + let max_delta = { + let (min, max) = costs.iter().flatten().fold((i32::MAX, i32::MIN), |(min, max), &cost| { + (Ord::min(min, cost), Ord::max(max, cost)) + }); + ::from(max.abs_diff(min)) + }; + + // The amino acid codes. + let codes = "CSTAGPDEQNHRKMILVWYF"; + + // The initial matrix with the default costs, except for gaps which are + // interchangeable. + let matrix = Self::default() + .with_sub_cost(b'-', b'.', T::ZERO) + .with_sub_cost(b'.', b'-', T::ZERO) + .scale(T::from(max_delta)); + + // Flatten the costs into a vector of (a, b, cost) tuples. 
+ codes + .chars() + .zip(costs.iter()) + .flat_map(|(a, costs)| { + codes + .chars() + .zip(costs.iter()) + .map(move |(b, &cost)| (a, b, T::from(cost))) + }) + .flat_map(|(a, b, cost)| { + // Add the costs for the upper and lower case versions of the + // characters. + [ + (a, b, cost), + (a.to_ascii_lowercase(), b, cost), + (a, b.to_ascii_lowercase(), cost), + (a.to_ascii_lowercase(), b.to_ascii_lowercase(), cost), + ] + }) + // Convert the characters to bytes. + .map(|(a, b, cost)| (a as u8, b as u8, cost)) + // And combine them into a matrix. + .fold(matrix, |matrix, (a, b, cost)| { + matrix.with_sub_cost(a, b, cost).with_sub_cost(b, a, cost) + }) + // Convert the matrix into a form that can be used to minimize the + // edit distances. + .neg() + .normalize() + // Add affine gap penalties. + .with_gap_open(T::from(max_delta * gap_open)) + .with_gap_ext(T::from(max_delta)) + } +} diff --git a/crates/abd-clam/src/msa/aligner/mod.rs b/crates/abd-clam/src/msa/aligner/mod.rs new file mode 100644 index 000000000..a9cfb5632 --- /dev/null +++ b/crates/abd-clam/src/msa/aligner/mod.rs @@ -0,0 +1,252 @@ +//! Needleman-Wunsch algorithm for global sequence alignment. + +use distances::Number; + +mod cost_matrix; +pub mod ops; + +pub use cost_matrix::CostMatrix; +use ops::{Direction, Edit, Edits}; + +/// A table of edit distances between prefixes of two sequences. +type NwTable = Vec>; + +/// A Needleman-Wunsch aligner. +/// +/// This works with any sequence of bytes, and also provides helpers for working +/// with strings. +#[derive(Clone, Debug)] +#[cfg_attr( + feature = "disk-io", + derive(bitcode::Encode, bitcode::Decode, serde::Serialize, serde::Deserialize) +)] +pub struct Aligner { + /// The cost matrix for the alignment. + matrix: CostMatrix, + /// The gap character. + gap: u8, +} + +impl Aligner { + /// Create a new Needleman-Wunsch aligner that minimizes the cost. + pub fn new(matrix: &CostMatrix, gap: u8) -> Self { + Self { + matrix: matrix.clone(), + gap, + } + } + + /// Get the gap character. + #[must_use] + pub const fn gap(&self) -> u8 { + self.gap + } + + /// Compute the minimized edit distance between two sequences' DP table. + pub fn distance(&self, dp_table: &NwTable) -> T { + dp_table.last().and_then(|row| row.last()).map_or(T::ZERO, |&(d, _)| d) + } + + /// Compute the dynamic programming table for the Needleman-Wunsch algorithm. + /// + /// The DP table is a 2D array of edit distances between prefixes of the two + /// sequences. The value at position `(i, j)` is the edit distance between + /// the first `i` characters of the first sequence and the first `j` + /// characters of the second sequence. + /// + /// This implementation will minimize the edit distance. + /// + /// # Arguments + /// + /// * `x` - The first sequence. + /// * `y` - The second sequence. + /// + /// # Returns + /// + /// The DP table. + pub fn dp_table>(&self, x: &S, y: &S) -> NwTable { + let (x, y) = (x.as_ref(), y.as_ref()); + + // Initialize the DP table. + let mut table = vec![vec![(T::ZERO, Direction::Diagonal); x.len() + 1]; y.len() + 1]; + + // Initialize the first row to the cost of inserting characters from the + // first sequence. + for i in 1..table[0].len() { + let cost = table[0][i - 1].0 + self.matrix.gap_ext_cost(); + table[0][i] = (cost, Direction::Left); + } + + // Initialize the first column to the cost of inserting characters from + // the second sequence. 
+ for j in 1..table.len() { + let cost = table[j - 1][0].0 + self.matrix.gap_ext_cost(); + table[j][0] = (cost, Direction::Up); + } + + // Fill in the DP table. + // On iteration (i, j), we will fill in the cell at (i + 1, j + 1). + for (i, &yc) in y.iter().enumerate() { + for (j, &xc) in x.iter().enumerate() { + // Compute the costs of the three possible operations. + let diag_cost = table[i][j].0 + self.matrix.sub_cost(xc, yc); + + // The cost of inserting a character depends on the previous + // operation. + let up_cost = table[i][j + 1].0 + + match table[i][j + 1].1 { + Direction::Up => self.matrix.gap_ext_cost(), + _ => self.matrix.gap_open_cost(), + }; + let left_cost = table[i + 1][j].0 + + match table[i + 1][j].1 { + Direction::Left => self.matrix.gap_ext_cost(), + _ => self.matrix.gap_open_cost(), + }; + + // Choose the operation with the minimum cost. If there is a tie, + // prefer the diagonal operation so that the aligned sequences + // are as short as possible. + table[i + 1][j + 1] = if diag_cost <= up_cost && diag_cost <= left_cost { + (diag_cost, Direction::Diagonal) + } else if up_cost <= left_cost { + (up_cost, Direction::Up) + } else { + (left_cost, Direction::Left) + }; + } + } + + table + } + + /// Align two sequences using the Needleman-Wunsch algorithm. + /// + /// # Arguments + /// + /// * `x` - The first sequence. + /// * `y` - The second sequence. + /// + /// # Returns + /// + /// The alignment distance and the aligned sequences as bytes. + pub fn align>(&self, x: &S, y: &S, table: &NwTable) -> [Vec; 2] { + let (x, y) = (x.as_ref(), y.as_ref()); + let [mut row_i, mut col_i] = [y.len(), x.len()]; + let [mut x_aligned, mut y_aligned] = [ + Vec::with_capacity(x.len() + y.len()), + Vec::with_capacity(x.len() + y.len()), + ]; + + while row_i > 0 || col_i > 0 { + match table[row_i][col_i].1 { + Direction::Diagonal => { + x_aligned.push(x[col_i - 1]); + y_aligned.push(y[row_i - 1]); + row_i -= 1; + col_i -= 1; + } + Direction::Up => { + x_aligned.push(self.gap); + y_aligned.push(y[row_i - 1]); + row_i -= 1; + } + Direction::Left => { + x_aligned.push(x[col_i - 1]); + y_aligned.push(self.gap); + col_i -= 1; + } + } + } + + x_aligned.reverse(); + y_aligned.reverse(); + + [x_aligned, y_aligned] + } + + /// Align two strings using the Needleman-Wunsch algorithm. + pub fn align_str>(&self, x: &S, y: &S, table: &NwTable) -> [String; 2] { + let [x_aligned, y_aligned] = self.align(&x.as_ref(), &y.as_ref(), table); + [ + String::from_utf8(x_aligned).unwrap_or_else(|e| unreachable!("We only added gaps: {e}")), + String::from_utf8(y_aligned).unwrap_or_else(|e| unreachable!("We only added gaps: {e}")), + ] + } + + /// Returns the `Edits` needed to align two sequences. + /// + /// Both sequences will need their respective `Edits` applied to them before + /// they are in alignment. + /// + /// # Arguments + /// + /// * `x` - The first sequence. + /// * `y` - The second sequence. + /// + /// # Returns + /// + /// The `Edits` needed to align the two sequences. + pub fn edits>(&self, x: &S, y: &S, table: &NwTable) -> [Edits; 2] { + let [x_aligned, y_aligned] = self.align(x, y, table); + [ + aligned_x_to_y(&x_aligned, &y_aligned), + aligned_x_to_y(&y_aligned, &x_aligned), + ] + } + + /// Returns the indices where gaps need to be inserted to align two + /// sequences. 
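+    ///
+    /// For example (with the default costs), aligning `x = "ACT"` with
+    /// `y = "AT"` as `ACT` over `A-T` needs no gaps in `x` and one gap at
+    /// index 1 in `y`.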
+ pub fn alignment_gaps>(&self, x: &S, y: &S, table: &NwTable) -> [Vec; 2] { + let (x, y) = (x.as_ref(), y.as_ref()); + + let [mut row_i, mut col_i] = [y.len(), x.len()]; + let [mut x_gaps, mut y_gaps] = [Vec::new(), Vec::new()]; + + while row_i > 0 || col_i > 0 { + match table[row_i][col_i].1 { + Direction::Diagonal => { + row_i -= 1; + col_i -= 1; + } + Direction::Up => { + x_gaps.push(col_i); + row_i -= 1; + } + Direction::Left => { + y_gaps.push(row_i); + col_i -= 1; + } + } + } + + x_gaps.reverse(); + y_gaps.reverse(); + + [x_gaps, y_gaps] + } +} + +/// A helper function to create a sequence of `Edits` needed to align one +/// sequence to another. +fn aligned_x_to_y>(x: &S, y: &S) -> Edits { + let (_, edits) = x + .as_ref() + .iter() + .zip(y.as_ref().iter()) + .enumerate() + .filter(|(_, (&xc, &yc))| xc != yc) + .fold((0, Vec::new()), |(mut modifier, mut edits), (i, (&xc, &yc))| { + let i = i - modifier; + if xc == b'-' { + edits.push((i, Edit::Ins(yc))); + } else if yc == b'-' { + edits.push((i, Edit::Del)); + modifier += 1; + } else { + edits.push((i, Edit::Sub(yc))); + } + (modifier, edits) + }); + Edits::from(edits) +} diff --git a/crates/abd-clam/src/msa/aligner/ops.rs b/crates/abd-clam/src/msa/aligner/ops.rs new file mode 100644 index 000000000..01e7f594a --- /dev/null +++ b/crates/abd-clam/src/msa/aligner/ops.rs @@ -0,0 +1,63 @@ +//! Alignment operations for Needleman-Wunsch algorithm. + +/// The direction of the edit operation in the DP table. +#[derive(Clone, Eq, PartialEq, Debug)] +pub enum Direction { + /// Diagonal (Up and Left) for a match or substitution. + Diagonal, + /// Up for a gap in the first sequence. + Up, + /// Left for a gap in the second sequence. + Left, +} + +/// The type of edit operation. +pub enum Edit { + /// Substitution of one character for another. + Sub(u8), + /// Insertion of a character. + Ins(u8), + /// Deletion of a character. + Del, +} + +impl core::fmt::Debug for Edit { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Sub(c) => f.debug_tuple("Sub").field(&(*c as char)).finish(), + Self::Ins(c) => f.debug_tuple("Ins").field(&(*c as char)).finish(), + Self::Del => write!(f, "Del"), + } + } +} + +/// The sequence of edits needed to turn one unaligned sequence into another. +#[derive(Debug)] +pub struct Edits(Vec<(usize, Edit)>); + +impl From> for Edits { + fn from(edits: Vec<(usize, Edit)>) -> Self { + Self(edits) + } +} + +impl AsRef<[(usize, Edit)]> for Edits { + fn as_ref(&self) -> &[(usize, Edit)] { + &self.0 + } +} + +impl IntoIterator for Edits { + type Item = (usize, Edit); + type IntoIter = std::vec::IntoIter<(usize, Edit)>; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +impl FromIterator<(usize, Edit)> for Edits { + fn from_iter>(iter: I) -> Self { + Self(iter.into_iter().collect()) + } +} diff --git a/crates/abd-clam/src/msa/dataset/columnar.rs b/crates/abd-clam/src/msa/dataset/columnar.rs new file mode 100644 index 000000000..eaf9af823 --- /dev/null +++ b/crates/abd-clam/src/msa/dataset/columnar.rs @@ -0,0 +1,425 @@ +//! Recursively build up the MSA using the CLAM tree. + +use core::ops::Index; + +use std::string::FromUtf8Error; + +use distances::Number; +use rayon::prelude::*; + +use crate::{cakes::PermutedBall, cluster::ParCluster, dataset::ParDataset, Cluster, Dataset, FlatVec}; + +use super::super::Aligner; + +/// A multiple sequence alignment (MSA) builder. +pub struct Columnar { + /// The Needleman-Wunsch aligner. 
+ aligner: Aligner, + /// The columns of the partial MSA. + columns: Vec>, +} + +impl Index for Columnar { + type Output = Vec; + + fn index(&self, index: usize) -> &Self::Output { + &self.columns[index] + } +} + +impl Columnar { + /// Create a new MSA builder. + #[must_use] + pub fn new(aligner: &Aligner) -> Self { + Self { + aligner: aligner.clone(), + columns: Vec::new(), + } + } + + /// Get the gap character. + #[must_use] + pub const fn gap(&self) -> u8 { + self.aligner.gap() + } + + /// Add a binary tree of `Cluster`s to the MSA. + #[must_use] + pub fn with_binary_tree(self, c: &PermutedBall, data: &D) -> Self + where + I: AsRef<[u8]>, + D: Dataset, + C: Cluster, + { + if c.children().is_empty() { + self.with_cluster(c, data) + } else { + if c.children().len() != 2 { + unreachable!("Binary tree has more than two children."); + } + let aligner = self.aligner; + let left = c.children()[0]; + let right = c.children()[1]; + + let l_msa = Self::new(&aligner).with_binary_tree(left, data); + let r_msa = Self::new(&aligner).with_binary_tree(right, data); + + let l_center = left + .iter_indices() + .position(|i| i == left.arg_center()) + .unwrap_or_else(|| unreachable!("Left center not found")); + let r_center = right + .iter_indices() + .position(|i| i == right.arg_center()) + .unwrap_or_else(|| unreachable!("Right center not found")); + + l_msa.merge(l_center, r_msa, r_center) + } + } + + /// Add a tree of `Cluster`s to the MSA. + #[must_use] + pub fn with_tree(self, c: &PermutedBall, data: &D) -> Self + where + I: AsRef<[u8]>, + D: Dataset, + C: Cluster, + { + if c.children().is_empty() { + self.with_cluster(c, data) + } else { + let children = c.children(); + let (&first, rest) = children.split_first().unwrap_or_else(|| unreachable!("No children")); + + let f_center = first + .iter_indices() + .position(|i| i == first.arg_center()) + .unwrap_or_else(|| unreachable!("First center not found")); + let first = Self::new(&self.aligner).with_tree(first, data); + + let (_, merged) = rest + .iter() + .map(|&o| { + let o_center = o + .iter_indices() + .position(|i| i == o.arg_center()) + .unwrap_or_else(|| unreachable!("Other center not found")); + (o_center, Self::new(&self.aligner).with_tree(o, data)) + }) + .fold((f_center, first), |(a_center, acc), (o_center, o)| { + (a_center, acc.merge(a_center, o, o_center)) + }); + + merged + } + } + + /// Replaces all sequences in the MSA with the given sequence. + /// + /// # Arguments + /// + /// * `sequence` - The sequence to add. + #[must_use] + pub fn with_sequence>(mut self, sequence: &I) -> Self { + self.columns = sequence.as_ref().iter().map(|&c| vec![c]).collect(); + self + } + + /// Adds sequences from a `Cluster` to the MSA. + #[must_use] + pub fn with_cluster(self, c: &C, data: &D) -> Self + where + I: AsRef<[u8]>, + D: Dataset, + C: Cluster, + { + ftlog::trace!( + "Adding cluster to MSA. Depth: {}, Cardinality: {}", + c.depth(), + c.cardinality() + ); + let indices = c.indices(); + let (&first, rest) = indices.split_first().unwrap_or_else(|| unreachable!("No indices")); + let first = Self::new(&self.aligner).with_sequence(data.get(first)); + rest.iter() + .map(|&i| data.get(i)) + .map(|s| Self::new(&self.aligner).with_sequence(s)) + .fold(first, |acc, s| acc.merge(0, s, 0)) + } + + /// The number of sequences in the MSA. + pub fn len(&self) -> usize { + self.columns.first().map_or(0, Vec::len) + } + + /// The number of columns in the MSA. + /// + /// If the MSA is empty, this will return 0. 
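Both tree walks above reduce to the same primitive: build an MSA for each child recursively, then `merge` the results after aligning the child centers. The column-major layout is what makes the final concatenation cheap, since joining two already-aligned MSAs of equal width is a per-column append. A minimal standalone sketch of that invariant:

```rust
// Column j holds the j-th character of every sequence, so appending one
// MSA to another is a per-column `Vec::append` once widths agree.
fn merge_columns(mut a: Vec<Vec<u8>>, mut b: Vec<Vec<u8>>) -> Vec<Vec<u8>> {
    assert_eq!(a.len(), b.len()); // both MSAs must already share a width
    for (col_a, col_b) in a.iter_mut().zip(b.iter_mut()) {
        col_a.append(col_b);
    }
    a
}

fn main() {
    // Two already-aligned single-sequence MSAs of width 3.
    let a = vec![vec![b'A'], vec![b'C'], vec![b'T']];
    let b = vec![vec![b'A'], vec![b'G'], vec![b'T']];
    let merged = merge_columns(a, b);
    assert_eq!(merged[1], vec![b'C', b'G']); // column 1 now holds both rows
}
```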
+ #[must_use] + pub fn width(&self) -> usize { + self.columns.len() + } + + /// Whether the MSA is empty. + pub fn is_empty(&self) -> bool { + self.columns.is_empty() || self.columns.iter().all(Vec::is_empty) + } + + /// Get the columns of the MSA. + #[must_use] + pub fn columns(&self) -> &[Vec] { + &self.columns + } + + /// Get the sequence at the given index. + #[must_use] + pub fn get_sequence(&self, index: usize) -> Vec { + self.columns.iter().map(|col| col[index]).collect() + } + + /// Get the sequence at the given index. + /// + /// This is a convenience method that converts the sequence to a `String`. + /// + /// # Arguments + /// + /// * `index` - The index of the sequence to get. + /// + /// # Errors + /// + /// If the sequence is not valid UTF-8. + pub fn get_sequence_str(&self, index: usize) -> Result { + String::from_utf8(self.get_sequence(index)) + } + + /// Merge two MSAs. + #[must_use] + pub fn merge(mut self, s_center: usize, mut other: Self, o_center: usize) -> Self { + ftlog::trace!( + "Merging MSAs with cardinalities: {} and {}, and centers {s_center} and {o_center}", + self.len(), + other.len() + ); + let s_center = self.get_sequence(s_center); + let o_center = other.get_sequence(o_center); + + let table = self.aligner.dp_table(&s_center, &o_center); + let [s_to_o, o_to_s] = self.aligner.alignment_gaps(&s_center, &o_center, &table); + + for i in s_to_o { + self.add_gap(i).unwrap_or_else(|e| unreachable!("{e}")); + } + + for i in o_to_s { + other.add_gap(i).unwrap_or_else(|e| unreachable!("{e}")); + } + + let aligner = self.aligner; + let columns = self + .columns + .into_iter() + .zip(other.columns) + .map(|(mut x, mut y)| { + x.append(&mut y); + x + }) + .collect(); + + Self { aligner, columns } + } + + /// Add a gap column to the MSA. + /// + /// # Arguments + /// + /// - index: The index at which to add the gap column. + /// + /// # Errors + /// + /// - If the MSA is empty. + /// - If the index is greater than the number of columns. + pub fn add_gap(&mut self, index: usize) -> Result<(), String> { + if self.columns.is_empty() { + Err("MSA is empty.".to_string()) + } else if index > self.width() { + Err(format!( + "Index is greater than the width of the MSA: {index} > {}", + self.width() + )) + } else { + let gap_col = vec![self.gap(); self.columns[0].len()]; + self.columns.insert(index, gap_col); + Ok(()) + } + } + + /// Extract the multiple sequence alignment. + #[must_use] + pub fn extract_msa(&self) -> Vec> { + if self.is_empty() { + Vec::new() + } else { + (0..self.len()).map(|i| self.get_sequence(i)).collect() + } + } + + /// Extract the multiple sequence alignment over `String`s. + /// + /// # Errors + /// + /// - If any of the sequences are not valid UTF-8. + pub fn extract_msa_strings(&self) -> Result, FromUtf8Error> { + self.extract_msa().into_iter().map(String::from_utf8).collect() + } + + /// Extract the columns as a `FlatVec`. + #[must_use] + pub fn to_flat_vec_columns(&self) -> FlatVec, usize> { + FlatVec::new(self.columns.clone()) + .unwrap_or_else(|e| unreachable!("{e}")) + .with_dim_lower_bound(self.len()) + .with_dim_upper_bound(self.len()) + .with_name("ColWiseMSA") + } + + /// Extract the rows as a `FlatVec`. 
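`add_gap` keeps the column-major matrix rectangular by inserting a gap column sized to the current cardinality, and rows are read back by indexing every column at the same position, as `get_sequence` does above. A self-contained illustration:

```rust
fn main() {
    // A width-3 MSA holding two rows, "ACT" and "AGT", in column-major form.
    let mut cols = vec![vec![b'A', b'A'], vec![b'C', b'G'], vec![b'T', b'T']];
    // `add_gap(1)` inserts a gap column sized to the current cardinality.
    let height = cols[0].len();
    cols.insert(1, vec![b'-'; height]);
    // Row 0 is now "A-CT": rows are read off by indexing every column.
    let row0: Vec<u8> = cols.iter().map(|c| c[0]).collect();
    assert_eq!(row0, b"A-CT".to_vec());
}
```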
+ #[must_use] + pub fn to_flat_vec_rows(&self) -> FlatVec, usize> { + FlatVec::new(self.extract_msa()) + .unwrap_or_else(|e| unreachable!("{e}")) + .with_dim_lower_bound(self.width()) + .with_dim_upper_bound(self.width()) + .with_name("RowWiseMSA") + } +} + +impl Columnar { + /// Parallel version of [`Columnar::with_binary_tree`](crate::msa::dataset::columnar::Columnar::with_binary_tree). + #[must_use] + pub fn par_with_binary_tree(self, c: &PermutedBall, data: &D) -> Self + where + I: AsRef<[u8]> + Send + Sync, + D: ParDataset, + C: ParCluster, + { + if c.children().is_empty() { + self.with_cluster(c, data) + } else { + if c.children().len() != 2 { + unreachable!("Binary tree has more than two children."); + } + let aligner = self.aligner; + let left = c.children()[0]; + let right = c.children()[1]; + + let (l_msa, r_msa) = rayon::join( + || Self::new(&aligner).par_with_binary_tree(left, data), + || Self::new(&aligner).par_with_binary_tree(right, data), + ); + + let l_center = left + .iter_indices() + .position(|i| i == left.arg_center()) + .unwrap_or_else(|| unreachable!("Left center not found")); + let r_center = right + .iter_indices() + .position(|i| i == right.arg_center()) + .unwrap_or_else(|| unreachable!("Right center not found")); + + l_msa.par_merge(l_center, r_msa, r_center) + } + } + + /// Parallel version of [`Columnar::with_tree`](crate::msa::dataset::columnar::Columnar::with_tree). + #[must_use] + pub fn par_with_tree(self, c: &PermutedBall, data: &D) -> Self + where + I: AsRef<[u8]> + Send + Sync, + D: ParDataset, + C: ParCluster, + { + if c.children().is_empty() { + self.with_cluster(c, data) + } else { + let children = c.children(); + let (&first, rest) = children.split_first().unwrap_or_else(|| unreachable!("No children")); + + let f_center = first + .iter_indices() + .position(|i| i == first.arg_center()) + .unwrap_or_else(|| unreachable!("First center not found")); + let first = Self::new(&self.aligner).with_tree(first, data); + + let (_, merged) = rest + .par_iter() + .map(|&o| { + let o_center = o + .iter_indices() + .position(|i| i == o.arg_center()) + .unwrap_or_else(|| unreachable!("Other center not found")); + (o_center, Self::new(&self.aligner).with_tree(o, data)) + }) + .collect::>() + .into_iter() + .fold((f_center, first), |(a_center, acc), (o_center, o)| { + (a_center, acc.par_merge(a_center, o, o_center)) + }); + + merged + } + } + + /// Parallel version of [`Columnar::merge`](crate::msa::dataset::columnar::Columnar::merge). + #[must_use] + pub fn par_merge(mut self, s_center: usize, mut other: Self, o_center: usize) -> Self { + ftlog::trace!( + "Parallel Merging MSAs with cardinalities: {} and {}, and centers {s_center} and {o_center}", + self.len(), + other.len() + ); + let s_center = self.get_sequence(s_center); + let o_center = other.get_sequence(o_center); + let table = self.aligner.dp_table(&s_center, &o_center); + let [s_to_o, o_to_s] = self.aligner.alignment_gaps(&s_center, &o_center, &table); + + for i in s_to_o { + self.add_gap(i).unwrap_or_else(|e| unreachable!("{e}")); + } + + for i in o_to_s { + other.add_gap(i).unwrap_or_else(|e| unreachable!("{e}")); + } + + let aligner = self.aligner; + let columns = self + .columns + .into_par_iter() + .zip(other.columns) + .map(|(mut x, mut y)| { + x.append(&mut y); + x + }) + .collect(); + + Self { aligner, columns } + } + + /// Parallel version of [`Columnar::extract_msa`](crate::msa::dataset::columnar::Columnar::extract_msa). 
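`par_with_binary_tree` gets its parallelism from `rayon::join`, which runs the two closures concurrently, one per subtree, and returns both results. The same divide-and-conquer shape in miniature:

```rust
use rayon::join;

fn par_sum(v: &[u64]) -> u64 {
    if v.len() <= 4 {
        return v.iter().sum();
    }
    let (a, b) = v.split_at(v.len() / 2);
    // Same shape as `par_with_binary_tree`: recurse into both halves
    // concurrently, then combine the two results.
    let (x, y) = join(|| par_sum(a), || par_sum(b));
    x + y
}

fn main() {
    let v: Vec<u64> = (0..100).collect();
    assert_eq!(par_sum(&v), 4950);
}
```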
+ #[must_use] + pub fn par_extract_msa(&self) -> Vec> { + if self.is_empty() { + Vec::new() + } else { + (0..self.len()).into_par_iter().map(|i| self.get_sequence(i)).collect() + } + } + + /// Parallel version of [`Columnar::extract_msa_strings`](crate::msa::dataset::columnar::Columnar::extract_msa_strings). + /// + /// # Errors + /// + /// See [`Columnar::extract_msa_strings`](crate::msa::dataset::columnar::Columnar::extract_msa_strings). + pub fn par_extract_msa_strings(&self) -> Result, FromUtf8Error> { + self.extract_msa().into_par_iter().map(String::from_utf8).collect() + } +} diff --git a/crates/abd-clam/src/msa/dataset/mod.rs b/crates/abd-clam/src/msa/dataset/mod.rs new file mode 100644 index 000000000..b20dd4cca --- /dev/null +++ b/crates/abd-clam/src/msa/dataset/mod.rs @@ -0,0 +1,7 @@ +//! `Dataset` extensions that can be used to build MSAs. + +mod columnar; +mod msa; + +pub use columnar::Columnar; +pub use msa::MSA; diff --git a/crates/abd-clam/src/msa/dataset/msa.rs b/crates/abd-clam/src/msa/dataset/msa.rs new file mode 100644 index 000000000..0a02fb0bf --- /dev/null +++ b/crates/abd-clam/src/msa/dataset/msa.rs @@ -0,0 +1,629 @@ +//! A `Dataset` containing a multiple sequence alignment (MSA). + +use distances::Number; +use rayon::prelude::*; + +use crate::{ + dataset::{AssociatesMetadata, AssociatesMetadataMut, ParDataset, Permutable}, + msa::NUM_CHARS, + utils::{self, LOG2_THRESH, SQRT_THRESH}, + Dataset, FlatVec, +}; + +use super::super::Aligner; + +/// A `Dataset` containing a multiple sequence alignment (MSA). +#[derive(Clone)] +#[cfg_attr( + feature = "disk-io", + derive(bitcode::Encode, bitcode::Decode, serde::Serialize, serde::Deserialize) +)] +pub struct MSA, T: Number, Me> { + /// The Needleman-Wunsch aligner. + aligner: Aligner, + /// The data of the MSA. + data: FlatVec, + /// The name of the MSA. + name: String, +} + +impl, T: Number, Me> Dataset for MSA { + fn name(&self) -> &str { + &self.name + } + + fn with_name(self, name: &str) -> Self { + Self { + name: name.to_string(), + ..self + } + } + + fn cardinality(&self) -> usize { + self.data.cardinality() + } + + fn dimensionality_hint(&self) -> (usize, Option) { + self.data.dimensionality_hint() + } + + fn get(&self, index: usize) -> &I { + self.data.get(index) + } +} + +impl + Send + Sync, T: Number, Me: Send + Sync> ParDataset for MSA {} + +impl, T: Number, Me> Permutable for MSA { + fn permutation(&self) -> Vec { + self.data.permutation() + } + + fn set_permutation(&mut self, permutation: &[usize]) { + self.data.set_permutation(permutation); + } + + fn swap_two(&mut self, i: usize, j: usize) { + self.data.swap_two(i, j); + } +} + +impl, T: Number, Me> AssociatesMetadata for MSA { + fn metadata(&self) -> &[Me] { + self.data.metadata() + } + + fn metadata_at(&self, index: usize) -> &Me { + self.data.metadata_at(index) + } +} + +impl, T: Number, Me, Met: Clone> AssociatesMetadataMut> for MSA { + fn metadata_mut(&mut self) -> &mut [Me] { + as AssociatesMetadataMut>>::metadata_mut(&mut self.data) + } + + fn metadata_at_mut(&mut self, index: usize) -> &mut Me { + as AssociatesMetadataMut>>::metadata_at_mut(&mut self.data, index) + } + + fn with_metadata(self, metadata: &[Met]) -> Result, String> { + self.data.with_metadata(metadata).map(|data| MSA { + aligner: self.aligner, + data, + name: self.name, + }) + } + + fn transform_metadata Met>(self, f: F) -> MSA { + MSA { + aligner: self.aligner, + data: self.data.transform_metadata(f), + name: self.name, + } + } +} + +impl, T: Number, Me> MSA { + /// Creates a new MSA. 
+ /// + /// # Arguments + /// + /// * `aligner` - The Needleman-Wunsch aligner. + /// * `data` - The data of the MSA. + /// + /// # Errors + /// + /// - If any sequence in the MSA is empty. + /// - If the sequences in the MSA have different lengths. + pub fn new(aligner: &Aligner, data: FlatVec) -> Result { + let (min_len, max_len) = data + .items() + .iter() + .map(|item| item.as_ref().len()) + .fold((usize::MAX, 0), |(min, max), len| { + (Ord::min(min, len), Ord::max(max, len)) + }); + + if min_len == 0 { + Err("Empty sequences are not allowed in an MSA.".to_string()) + } else if min_len == max_len { + let name = format!("MSA({})", data.name()); + Ok(Self { + aligner: aligner.clone(), + data, + name, + }) + } else { + Err("Sequences in an MSA must have the same length.".to_string()) + } + } + + /// Returns the Needleman-Wunsch aligner. + pub const fn aligner(&self) -> &Aligner { + &self.aligner + } + + /// Returns the data of the MSA. + pub const fn data(&self) -> &FlatVec { + &self.data + } + + /// The gap character in the MSA. + #[must_use] + pub const fn gap(&self) -> u8 { + self.aligner.gap() + } + + /// Returns the width of the MSA. + #[must_use] + pub fn width(&self) -> usize { + self.dimensionality_hint().0 + } + + /// Swaps between the row/col major order of the MSA. + /// + /// This will convert a row-major MSA to a col-major MSA and vice versa. + #[must_use] + pub fn change_major(&self) -> MSA, T, usize> { + let rows = self.data.items().iter().map(I::as_ref).collect::>(); + let cols = (0..self.width()) + .map(|i| rows.iter().map(|row| row[i]).collect::>()) + .collect::>(); + let name = format!("Columnar({})", self.name); + let data = FlatVec::new(cols).unwrap_or_else(|e| unreachable!("Failed to create a FlatVec: {}", e)); + MSA { + aligner: self.aligner.clone(), + data, + name, + } + } + + /// Scores each pair of columns in the MSA, applying a penalty for gaps and + /// mismatches. + /// + /// This should only be used with col-major MSA and will give nonsensical + /// results with row-major MSA. + #[must_use] + pub fn scoring_columns(&self, gap_char: u8, gap_penalty: usize, mismatch_penalty: usize) -> f32 { + let score = self + .data + .items() + .iter() + .map(AsRef::as_ref) + .map(|c| sc_inner(c, gap_char, gap_penalty, mismatch_penalty)) + .sum::(); + score.as_f32() / utils::n_pairs(self.cardinality()).as_f32() + } + + /// Calculates the mean and maximum `p-distance`s of all pairwise + /// alignments in the MSA. + #[must_use] + pub fn p_distance_stats(&self, gap_char: u8) -> (f32, f32) { + let p_distances = self.p_distances(gap_char); + let n_pairs = p_distances.len(); + let (sum, max) = p_distances + .into_iter() + .fold((0.0, 0.0), |(sum, max), dist| (sum + dist, f32::max(max, dist))); + let avg = sum / n_pairs.as_f32(); + (avg, max) + } + + /// Same as `p_distance_stats`, but only estimates the score for a subset of + /// the pairwise alignments. + #[must_use] + pub fn p_distance_stats_subsample(&self, gap_char: u8) -> (f32, f32) { + let p_distances = self.p_distances_subsample(gap_char); + let n_pairs = p_distances.len(); + let (sum, max) = p_distances + .into_iter() + .fold((0.0, 0.0), |(sum, max), dist| (sum + dist, f32::max(max, dist))); + let avg = sum / n_pairs.as_f32(); + (avg, max) + } + + /// Calculates the `p-distance` of each pairwise alignment in the MSA. 
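`change_major` is a plain transpose: row-major items become per-position columns, which is the shape the column-oriented quality metrics expect. A standalone sketch of the same operation:

```rust
fn main() {
    // Two rows of an aligned MSA, width 4.
    let rows: Vec<&[u8]> = vec![b"AC-T".as_slice(), b"ACGT".as_slice()];
    let width = rows[0].len();
    // Transpose rows into columns, exactly as `change_major` does.
    let cols: Vec<Vec<u8>> = (0..width)
        .map(|j| rows.iter().map(|r| r[j]).collect())
        .collect();
    assert_eq!(cols[2], vec![b'-', b'G']);
}
```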
+ fn p_distances(&self, gap_char: u8) -> Vec { + let scorer = |s1: &[u8], s2: &[u8]| pd_inner(s1, s2, gap_char); + self.apply_pairwise(&self.indices().collect::>(), scorer) + .collect() + } + + /// Same as `p_distances`, but only estimates the score for a subset of the + /// pairwise alignments. + fn p_distances_subsample(&self, gap_char: u8) -> Vec { + let indices = utils::choose_samples(&self.indices().collect::>(), SQRT_THRESH, LOG2_THRESH); + let scorer = |s1: &[u8], s2: &[u8]| pd_inner(s1, s2, gap_char); + self.apply_pairwise(&indices, scorer).collect() + } + + /// Scores the MSA using the distortion of the Levenshtein edit distance + /// and the Hamming distance between each pair of sequences. + #[must_use] + pub fn distance_distortion(&self, gap_char: u8) -> f32 { + let score = self.sum_of_pairs(&self.indices().collect::>(), |s1, s2| dd_inner(s1, s2, gap_char)); + score.as_f32() / utils::n_pairs(self.cardinality()).as_f32() + } + + /// Same as `distance_distortion`, but only estimates the score for a subset + /// of the pairwise alignments. + #[must_use] + pub fn distance_distortion_subsample(&self, gap_char: u8) -> f32 { + let indices = utils::choose_samples(&self.indices().collect::>(), SQRT_THRESH / 4, LOG2_THRESH / 4); + let score = self.sum_of_pairs(&indices, |s1, s2| dd_inner(s1, s2, gap_char)); + score.as_f32() / utils::n_pairs(indices.len()).as_f32() + } + + /// Scores each pairwise alignment in the MSA, applying a penalty for gaps + /// and mismatches. + /// + /// # Arguments + /// + /// * `gap_penalty` - The penalty for a gap. + /// * `mismatch_penalty` - The penalty for a mismatch. + /// + /// # Returns + /// + /// The sum of the penalties for all pairwise alignments divided by the + /// number of pairwise alignments. + #[must_use] + pub fn scoring_pairwise(&self, gap_char: u8, gap_penalty: usize, mismatch_penalty: usize) -> f32 { + let scorer = |s1: &[u8], s2: &[u8]| sp_inner(s1, s2, gap_char, gap_penalty, mismatch_penalty); + let score = self.sum_of_pairs(&self.indices().collect::>(), scorer); + score.as_f32() / utils::n_pairs(self.cardinality()).as_f32() + } + + /// Same as `scoring_pairwise`, but only estimates the score for a subset of + /// the pairwise alignments. + #[must_use] + pub fn scoring_pairwise_subsample(&self, gap_char: u8, gap_penalty: usize, mismatch_penalty: usize) -> f32 { + let indices = utils::choose_samples(&self.indices().collect::>(), SQRT_THRESH, LOG2_THRESH); + let scorer = |s1: &[u8], s2: &[u8]| sp_inner(s1, s2, gap_char, gap_penalty, mismatch_penalty); + let score = self.sum_of_pairs(&indices, scorer); + score.as_f32() / utils::n_pairs(indices.len()).as_f32() + } + + /// Scores each pairwise alignment in the MSA, applying penalties for + /// opening a gap, extending a gap, and mismatches. + /// + /// # Arguments + /// + /// * `gap_open_penalty` - The penalty for opening a gap. + /// * `gap_ext_penalty` - The penalty for extending a gap. + /// * `mismatch_penalty` - The penalty for a mismatch. + /// + /// # Returns + /// + /// The sum of the penalties for all pairwise alignments divided by the + /// number of pairwise alignments. 
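The sum-of-pairs quantities above are easy to verify by hand. A standalone mirror of the pairwise scorer, with gap-only columns skipped inline rather than via the `remove_gap_only_cols` pass defined further down:

```rust
fn sp(s1: &[u8], s2: &[u8], gap: u8, gap_penalty: usize, mismatch_penalty: usize) -> usize {
    s1.iter().zip(s2).fold(0, |acc, (&a, &b)| {
        if a == gap && b == gap {
            acc // gap-only columns are skipped
        } else if a == gap || b == gap {
            acc + gap_penalty // one-sided gap
        } else if a != b {
            acc + mismatch_penalty // mismatch
        } else {
            acc
        }
    })
}

fn main() {
    // "AC-T" vs "ACGT": one one-sided gap column and no mismatches.
    assert_eq!(sp(b"AC-T", b"ACGT", b'-', 2, 1), 2);
}
```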
+ #[must_use] + pub fn weighted_scoring_pairwise( + &self, + gap_char: u8, + gap_open_penalty: usize, + gap_ext_penalty: usize, + mismatch_penalty: usize, + ) -> f32 { + let scorer = + |s1: &[u8], s2: &[u8]| wsp_inner(s1, s2, gap_char, gap_open_penalty, gap_ext_penalty, mismatch_penalty); + let score = self.sum_of_pairs(&self.indices().collect::>(), scorer); + score.as_f32() / utils::n_pairs(self.cardinality()).as_f32() + } + + /// Same as `weighted_scoring_pairwise`, but only estimates the score for a subset of + /// the pairwise alignments. + #[must_use] + pub fn weighted_scoring_pairwise_subsample( + &self, + gap_char: u8, + gap_open_penalty: usize, + gap_ext_penalty: usize, + mismatch_penalty: usize, + ) -> f32 { + let indices = utils::choose_samples(&self.indices().collect::>(), SQRT_THRESH / 4, LOG2_THRESH / 4); + let scorer = + |s1: &[u8], s2: &[u8]| wsp_inner(s1, s2, gap_char, gap_open_penalty, gap_ext_penalty, mismatch_penalty); + let score = self.sum_of_pairs(&indices, scorer); + score.as_f32() / utils::n_pairs(indices.len()).as_f32() + } + + /// Applies a pairwise scorer to all pairs of sequences in the MSA. + fn apply_pairwise<'a, F, G: Number>(&'a self, indices: &'a [usize], scorer: F) -> impl Iterator + 'a + where + F: (Fn(&[u8], &[u8]) -> G) + 'a, + { + indices + .iter() + .enumerate() + .flat_map(move |(i, &s1)| indices.iter().skip(i + 1).map(move |&s2| (s1, s2))) + .map(|(s1, s2)| (self.get(s1).as_ref(), self.get(s2).as_ref())) + .map(move |(s1, s2)| scorer(s1, s2)) + } + + /// Calculate the sum of the pairwise scores for a given scorer. + fn sum_of_pairs(&self, indices: &[usize], scorer: F) -> G + where + F: Fn(&[u8], &[u8]) -> G, + { + self.apply_pairwise(indices, scorer).sum() + } +} + +impl + Send + Sync, T: Number, Me: Send + Sync> MSA { + /// Parallel version of [`MSA::change_major`](crate::msa::dataset::msa::MSA::change_major). + #[must_use] + pub fn par_change_major(&self) -> MSA, T, usize> { + let rows = self.data.items().par_iter().map(I::as_ref).collect::>(); + let cols = (0..self.width()) + .into_par_iter() + .map(|i| rows.iter().map(|row| row[i]).collect::>()) + .collect::>(); + let name = format!("Columnar({})", self.name); + let data = FlatVec::new(cols).unwrap_or_else(|e| unreachable!("Failed to create a FlatVec: {}", e)); + MSA { + aligner: self.aligner.clone(), + data, + name, + } + } + + /// Parallel version of [`MSA::scoring_columns`](crate::msa::dataset::msa::MSA::scoring_columns). + #[must_use] + pub fn par_scoring_columns(&self, gap_char: u8, gap_penalty: usize, mismatch_penalty: usize) -> f32 { + let num_seqs = self.get(0).as_ref().len(); + let score = self + .data + .items() + .par_iter() + .map(AsRef::as_ref) + .map(|c| sc_inner(c, gap_char, gap_penalty, mismatch_penalty)) + .sum::(); + score.as_f32() / utils::n_pairs(num_seqs).as_f32() + } + + /// Parallel version of [`MSA::p_distance_stats`](crate::msa::dataset::msa::MSA::p_distance_stats). + #[must_use] + pub fn par_p_distance_stats(&self, gap_char: u8) -> (f32, f32) { + let p_dists = self.par_p_distances(gap_char); + let n_pairs = p_dists.len(); + let (sum, max) = p_dists + .into_iter() + .fold((0.0, 0.0), |(sum, max), dist| (sum + dist, f32::max(max, dist))); + let avg = sum / n_pairs.as_f32(); + (avg, max) + } + + /// Parallel version of [`MSA::p_distance_stats_subsample`](crate::msa::dataset::msa::MSA::p_distance_stats_subsample). 
+ #[must_use] + pub fn par_p_distance_stats_subsample(&self, gap_char: u8) -> (f32, f32) { + let p_dists = self.par_p_distances_subsample(gap_char); + let n_pairs = p_dists.len(); + let (sum, max) = p_dists + .into_iter() + .fold((0.0, 0.0), |(sum, max), dist| (sum + dist, f32::max(max, dist))); + let avg = sum / n_pairs.as_f32(); + (avg, max) + } + + /// Parallel version of [`MSA::p_distances`](crate::msa::dataset::msa::MSA::p_distances). + fn par_p_distances(&self, gap_char: u8) -> Vec { + let scorer = |s1: &[u8], s2: &[u8]| pd_inner(s1, s2, gap_char); + self.par_apply_pairwise(&self.indices().collect::>(), scorer) + .collect() + } + + /// Parallel version of [`MSA::p_distances_subsample`](crate::msa::dataset::msa::MSA::p_distances_subsample). + fn par_p_distances_subsample(&self, gap_char: u8) -> Vec { + let indices = utils::choose_samples(&self.indices().collect::>(), SQRT_THRESH, LOG2_THRESH); + let scorer = |s1: &[u8], s2: &[u8]| pd_inner(s1, s2, gap_char); + self.par_apply_pairwise(&indices, scorer).collect() + } + + /// Parallel version of [`MSA::distance_distortion`](crate::msa::dataset::msa::MSA::distance_distortion). + #[must_use] + pub fn par_distance_distortion(&self, gap_char: u8) -> f32 { + let score = self.par_sum_of_pairs(&self.indices().collect::>(), |s1, s2| dd_inner(s1, s2, gap_char)); + score.as_f32() / utils::n_pairs(self.cardinality()).as_f32() + } + + /// Parallel version of [`MSA::distance_distortion_subsample`](crate::msa::dataset::msa::MSA::distance_distortion_subsample). + #[must_use] + pub fn par_distance_distortion_subsample(&self, gap_char: u8) -> f32 { + let indices = utils::choose_samples(&self.indices().collect::>(), SQRT_THRESH / 8, LOG2_THRESH / 8); + let score = self.par_sum_of_pairs(&indices, |s1, s2| dd_inner(s1, s2, gap_char)); + score.as_f32() / utils::n_pairs(indices.len()).as_f32() + } + + /// Parallel version of [`MSA::scoring_pairwise`](crate::msa::dataset::msa::MSA::scoring_pairwise). + #[must_use] + pub fn par_scoring_pairwise(&self, gap_char: u8, gap_penalty: usize, mismatch_penalty: usize) -> f32 { + let scorer = |s1: &[u8], s2: &[u8]| sp_inner(s1, s2, gap_char, gap_penalty, mismatch_penalty); + let score = self.par_sum_of_pairs(&self.indices().collect::>(), scorer); + score.as_f32() / utils::n_pairs(self.cardinality()).as_f32() + } + + /// Parallel version of [`MSA::scoring_pairwise_subsample`](crate::msa::dataset::msa::MSA::scoring_pairwise_subsample). + #[must_use] + pub fn par_scoring_pairwise_subsample(&self, gap_char: u8, gap_penalty: usize, mismatch_penalty: usize) -> f32 { + let indices = utils::choose_samples(&self.indices().collect::>(), SQRT_THRESH, LOG2_THRESH); + let scorer = |s1: &[u8], s2: &[u8]| sp_inner(s1, s2, gap_char, gap_penalty, mismatch_penalty); + let score = self.par_sum_of_pairs(&indices, scorer); + score.as_f32() / utils::n_pairs(indices.len()).as_f32() + } + + /// Parallel version of [`MSA::weighted_scoring_pairwise`](crate::msa::dataset::msa::MSA::weighted_scoring_pairwise). 
+ #[must_use] + pub fn par_weighted_scoring_pairwise( + &self, + gap_char: u8, + gap_open_penalty: usize, + gap_ext_penalty: usize, + mismatch_penalty: usize, + ) -> f32 { + let scorer = + |s1: &[u8], s2: &[u8]| wsp_inner(s1, s2, gap_char, gap_open_penalty, gap_ext_penalty, mismatch_penalty); + let score = self.par_sum_of_pairs(&self.indices().collect::>(), scorer); + score.as_f32() / utils::n_pairs(self.cardinality()).as_f32() + } + + /// Parallel version of [`MSA::weighted_scoring_pairwise_subsample`](crate::msa::dataset::msa::MSA::weighted_scoring_pairwise_subsample). + #[must_use] + pub fn par_weighted_scoring_pairwise_subsample( + &self, + gap_char: u8, + gap_open_penalty: usize, + gap_ext_penalty: usize, + mismatch_penalty: usize, + ) -> f32 { + let indices = utils::choose_samples(&self.indices().collect::>(), SQRT_THRESH, LOG2_THRESH); + let scorer = + |s1: &[u8], s2: &[u8]| wsp_inner(s1, s2, gap_char, gap_open_penalty, gap_ext_penalty, mismatch_penalty); + let score = self.par_sum_of_pairs(&indices, scorer); + score.as_f32() / utils::n_pairs(indices.len()).as_f32() + } + + /// Parallel version of [`MSA::apply_pairwise`](crate::msa::dataset::msa::MSA::apply_pairwise). + fn par_apply_pairwise<'a, F, G: Number>( + &'a self, + indices: &'a [usize], + scorer: F, + ) -> impl ParallelIterator + 'a + where + F: (Fn(&[u8], &[u8]) -> G) + Send + Sync + 'a, + { + indices + .par_iter() + .enumerate() + .flat_map(move |(i, &s1)| indices.par_iter().skip(i + 1).map(move |&s2| (s1, s2))) + .map(|(s1, s2)| (self.get(s1).as_ref(), self.get(s2).as_ref())) + .map(move |(s1, s2)| scorer(s1, s2)) + } + + /// Parallel version of [`MSA::sum_of_pairs`](crate::msa::dataset::msa::MSA::sum_of_pairs). + fn par_sum_of_pairs(&self, indices: &[usize], scorer: F) -> G + where + F: (Fn(&[u8], &[u8]) -> G) + Send + Sync, + { + self.par_apply_pairwise(indices, scorer).sum() + } +} + +/// Scores a single pair of columns in the MSA, applying a penalty for gaps and +/// mismatches. +fn sc_inner(col: &[u8], gap_char: u8, gap_penalty: usize, mismatch_penalty: usize) -> usize { + // Create a frequency count of the characters in the column. + let freqs = col.iter().fold([0; NUM_CHARS], |mut freqs, &c| { + freqs[c as usize] += 1; + freqs + }); + + // Start scoring the column. + let mut score = 0; + + // Calculate the number of pairs of characters of which one is a gap and + // apply the gap penalty. + let num_gaps = freqs[gap_char as usize]; + score += num_gaps * (col.len() - num_gaps) * gap_penalty / 2; + + // Get the frequencies of non-gap characters with non-zero frequency. + let freqs = freqs + .into_iter() + .enumerate() + .filter(|&(i, f)| (f > 0) && (i != gap_char as usize)) + .map(|(_, f)| f) + .collect::>(); + + // For each combinatorial pair, add mismatch penalties. + freqs + .iter() + .enumerate() + .flat_map(|(i, &f1)| freqs.iter().skip(i + 1).map(move |&f2| (f1, f2))) + .fold(score, |score, (f1, f2)| score + f1 * f2 * mismatch_penalty) +} + +/// Removes gap-only columns from two aligned sequences. +fn remove_gap_only_cols(s1: &[u8], s2: &[u8], gap_char: u8) -> (Vec, Vec) { + s1.iter() + .zip(s2.iter()) + .filter(|(&a, &b)| !(a == gap_char && b == gap_char)) + .unzip() +} + +/// Scores a single pairwise alignment in the MSA, applying a penalty for +/// gaps and mismatches. 
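`sc_inner` (above) scores one column from its character frequencies: each gap/non-gap pair pays half the gap penalty, and each pair of distinct non-gap residues pays the mismatch penalty. A standalone mirror with worked numbers:

```rust
fn column_score(col: &[u8], gap: u8, gap_penalty: usize, mismatch_penalty: usize) -> usize {
    // Frequency count of the characters in the column.
    let mut freqs = [0usize; 256];
    for &c in col {
        freqs[c as usize] += 1;
    }
    // Gap/non-gap pairs pay half the gap penalty each.
    let gaps = freqs[gap as usize];
    let mut score = gaps * (col.len() - gaps) * gap_penalty / 2;
    // Every pair of distinct non-gap characters pays the mismatch penalty.
    let f: Vec<usize> = (0..256)
        .filter(|&i| i != gap as usize && freqs[i] > 0)
        .map(|i| freqs[i])
        .collect();
    for i in 0..f.len() {
        for j in (i + 1)..f.len() {
            score += f[i] * f[j] * mismatch_penalty;
        }
    }
    score
}

fn main() {
    // "AAG-": 1 gap x 3 non-gaps x penalty 2 / 2 = 3, plus 2x1 A/G pairs
    // at penalty 1 = 2, for a total of 5.
    assert_eq!(column_score(b"AAG-", b'-', 2, 1), 5);
}
```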
+fn sp_inner(s1: &[u8], s2: &[u8], gap_char: u8, gap_penalty: usize, mismatch_penalty: usize) -> usize {
+    let (s1, s2) = remove_gap_only_cols(s1, s2, gap_char);
+    s1.iter().zip(s2.iter()).fold(0, |score, (&a, &b)| {
+        if a == gap_char || b == gap_char {
+            score + gap_penalty
+        } else if a != b {
+            score + mismatch_penalty
+        } else {
+            score
+        }
+    })
+}
+
+/// Scores a single pairwise alignment in the MSA, applying a penalty for
+/// opening a gap, extending a gap, and mismatches.
+fn wsp_inner(
+    s1: &[u8],
+    s2: &[u8],
+    gap_char: u8,
+    gap_open_penalty: usize,
+    gap_ext_penalty: usize,
+    mismatch_penalty: usize,
+) -> usize {
+    let (s1, s2) = remove_gap_only_cols(s1, s2, gap_char);
+
+    let start = if s1[0] == gap_char || s2[0] == gap_char {
+        gap_open_penalty
+    } else if s1[0] != s2[0] {
+        mismatch_penalty
+    } else {
+        0
+    };
+
+    // Walk both sequences in consecutive pairs so that a gap can be
+    // classified as an opening (previous character was not a gap) or an
+    // extension (previous character was a gap).
+    s1.iter()
+        .zip(s1.iter().skip(1))
+        .zip(s2.iter().zip(s2.iter().skip(1)))
+        .fold(start, |score, ((&a1, &a2), (&b1, &b2))| {
+            if (a2 == gap_char && a1 != gap_char) || (b2 == gap_char && b1 != gap_char) {
+                score + gap_open_penalty
+            } else if a2 == gap_char || b2 == gap_char {
+                score + gap_ext_penalty
+            } else if a2 != b2 {
+                score + mismatch_penalty
+            } else {
+                score
+            }
+        })
+}
+
+/// Measures the distortion of the Levenshtein edit distance between the
+/// unaligned sequences and the Hamming distance between the aligned sequences.
+fn dd_inner(s1: &[u8], s2: &[u8], gap_char: u8) -> f32 {
+    let (s1, s2) = remove_gap_only_cols(s1, s2, gap_char);
+    let ham = s1.iter().zip(s2.iter()).filter(|(&a, &b)| a != b).count();
+
+    let s1 = s1.iter().filter(|&&c| c != gap_char).copied().collect::<Vec<_>>();
+    let s2 = s2.iter().filter(|&&c| c != gap_char).copied().collect::<Vec<_>>();
+    let lev = stringzilla::sz::edit_distance(s1, s2);
+
+    if lev == 0 {
+        1.0
+    } else {
+        ham.as_f32() / lev.as_f32()
+    }
+}
+
+/// Calculates the p-distance of a pair of sequences.
+fn pd_inner(s1: &[u8], s2: &[u8], gap_char: u8) -> f32 {
+    let (s1, s2) = remove_gap_only_cols(s1, s2, gap_char);
+    let num_mismatches = s1
+        .iter()
+        .zip(s2.iter())
+        .filter(|(&a, &b)| a != gap_char && b != gap_char && a != b)
+        .count();
+    num_mismatches.as_f32() / s1.len().as_f32()
+}
diff --git a/crates/abd-clam/src/msa/mod.rs b/crates/abd-clam/src/msa/mod.rs
new file mode 100644
index 000000000..6cc5b6f70
--- /dev/null
+++ b/crates/abd-clam/src/msa/mod.rs
@@ -0,0 +1,12 @@
+//! Multiple Sequence Alignment with CLAM
+
+mod aligner;
+mod dataset;
+mod sequence;
+
+pub use aligner::{ops, Aligner, CostMatrix};
+pub use dataset::{Columnar, MSA};
+pub use sequence::Sequence;
+
+/// The number of characters.
+pub(crate) const NUM_CHARS: usize = 1 + (u8::MAX as usize);
diff --git a/crates/abd-clam/src/msa/sequence.rs b/crates/abd-clam/src/msa/sequence.rs
new file mode 100644
index 000000000..d54242213
--- /dev/null
+++ b/crates/abd-clam/src/msa/sequence.rs
@@ -0,0 +1,116 @@
+//! A wrapper around a `String` to use in multiple sequence alignment.
+
+use distances::Number;
+
+use super::Aligner;
+
+/// A wrapper around a `String` to use in multiple sequence alignment.
+#[derive(Clone)]
+pub struct Sequence<'a, T: Number> {
+    /// The wrapped string.
+    seq: String,
+    /// The aligner to use among the sequences.
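A hand-checkable example of `pd_inner`'s definition: mismatches over alignment width, ignoring positions where either sequence has a gap. This mirror skips the gap-only-column pass since the inputs contain none:

```rust
fn main() {
    let (s1, s2) = (b"AC-T", b"AGGT");
    // Count mismatches at positions where neither sequence has a gap.
    let mismatches = s1
        .iter()
        .zip(s2.iter())
        .filter(|(&a, &b)| a != b'-' && b != b'-' && a != b)
        .count();
    // One mismatch (C vs G) over an alignment of width 4.
    let p = mismatches as f32 / s1.len() as f32;
    assert!((p - 0.25).abs() < f32::EPSILON);
}
```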
+ aligner: Option<&'a Aligner>, +} + +#[cfg(feature = "disk-io")] +impl bitcode::Encode for Sequence<'_, T> { + const ENCODE_MAX: usize = String::ENCODE_MAX; + const ENCODE_MIN: usize = String::ENCODE_MIN; + + fn encode( + &self, + encoding: impl bitcode::encoding::Encoding, + writer: &mut impl bitcode::write::Write, + ) -> bitcode::Result<()> { + self.seq.encode(encoding, writer) + } +} + +#[cfg(feature = "disk-io")] +impl bitcode::Decode for Sequence<'_, T> { + const DECODE_MAX: usize = String::DECODE_MAX; + const DECODE_MIN: usize = String::DECODE_MIN; + + fn decode( + encoding: impl bitcode::encoding::Encoding, + reader: &mut impl bitcode::read::Read, + ) -> bitcode::Result { + let seq = String::decode(encoding, reader)?; + Ok(Self { seq, aligner: None }) + } +} + +impl<'a, T: Number> Sequence<'a, T> { + /// Creates a new `Sequence` from a string. + /// + /// # Arguments + /// + /// * `seq`: The string to wrap. + /// * `aligner`: The aligner to use among the sequences. + /// + /// # Returns + /// + /// The wrapped string. + #[must_use] + pub const fn new(seq: String, aligner: Option<&'a Aligner>) -> Self { + Self { seq, aligner } + } + + /// Returns the length of the sequence. + #[must_use] + pub fn len(&self) -> usize { + self.seq.len() + } + + /// Whether the sequence is empty. + #[must_use] + pub fn is_empty(&self) -> bool { + self.seq.is_empty() + } + + /// Returns the aligner for the sequence. + #[must_use] + pub const fn aligner(&self) -> Option<&'a Aligner> { + self.aligner + } + + /// Sets the aligner for the sequence. + #[must_use] + pub const fn with_aligner(mut self, aligner: &'a Aligner) -> Self { + self.aligner = Some(aligner); + self + } + + /// Returns the sequence + #[must_use] + pub fn seq(&self) -> &str { + &self.seq + } +} + +impl core::fmt::Debug for Sequence<'_, T> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.write_str(&self.seq) + } +} + +impl PartialEq for Sequence<'_, T> { + fn eq(&self, other: &Self) -> bool { + self.seq == other.seq + } +} + +impl Eq for Sequence<'_, T> {} + +impl AsRef for Sequence<'_, T> { + fn as_ref(&self) -> &str { + &self.seq + } +} + +impl AsRef<[u8]> for Sequence<'_, T> { + fn as_ref(&self) -> &[u8] { + self.seq.as_bytes() + } +} diff --git a/crates/abd-clam/src/pancakes/cluster/mod.rs b/crates/abd-clam/src/pancakes/cluster/mod.rs new file mode 100644 index 000000000..2468b6e32 --- /dev/null +++ b/crates/abd-clam/src/pancakes/cluster/mod.rs @@ -0,0 +1,5 @@ +//! `Cluster` extensions for `PanCAKES`. + +mod squishy_ball; + +pub use squishy_ball::{SquishCosts, SquishyBall}; diff --git a/crates/abd-clam/src/pancakes/cluster/squishy_ball.rs b/crates/abd-clam/src/pancakes/cluster/squishy_ball.rs new file mode 100644 index 000000000..8a64905d9 --- /dev/null +++ b/crates/abd-clam/src/pancakes/cluster/squishy_ball.rs @@ -0,0 +1,451 @@ +//! An adaptation of `Ball` that allows for compression of the dataset and +//! search in the compressed space. 
+ +use distances::Number; +use rayon::prelude::*; + +use crate::{ + cakes::PermutedBall, + cluster::{ + adapter::{Adapter, BallAdapter, ParAdapter, ParBallAdapter, ParParams, Params}, + ParCluster, + }, + dataset::{ParDataset, Permutable}, + metric::ParMetric, + Ball, Cluster, Dataset, Metric, +}; + +use super::super::{CodecData, Compressible, Decodable, Decompressible, Encodable, ParCompressible, ParDecompressible}; + +#[cfg(feature = "disk-io")] +use std::io::{Read, Write}; + +#[cfg(feature = "disk-io")] +use flate2::{read::GzDecoder, write::GzEncoder, Compression}; + +/// A `Cluster` for use in compressive search. +#[derive(Clone)] +#[cfg_attr( + feature = "disk-io", + derive(bitcode::Encode, bitcode::Decode, serde::Serialize, serde::Deserialize) +)] +#[cfg_attr(feature = "disk-io", bitcode(recursive))] +pub struct SquishyBall> { + /// The `Cluster` type that the `OffsetBall` is based on. + source: PermutedBall, + /// Parameters for the `OffsetBall`. + costs: SquishCosts, + /// The children of the `Cluster`. + children: Vec>, +} + +impl + core::fmt::Debug> core::fmt::Debug for SquishyBall { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SquishyBall") + .field("source", &self.source) + .field("recursive_cost", &self.costs.recursive) + .field("unitary_cost", &self.costs.unitary) + .field("minimum_cost", &self.costs.minimum) + .field("children", &!self.children.is_empty()) + .finish() + } +} + +impl> SquishyBall { + /// Get the unitary cost of the `SquishyBall`. + pub const fn unitary_cost(&self) -> T { + self.costs.unitary + } + + /// Get the recursive cost of the `SquishyBall`. + pub const fn recursive_cost(&self) -> T { + self.costs.recursive + } + + /// Gets the offset of the cluster's indices in its dataset. + pub const fn offset(&self) -> usize { + self.source.offset() + } + + /// Trims the tree by removing empty children of clusters whose unitary cost + /// is greater than the recursive cost. + pub fn trim(&mut self, min_depth: usize) { + if !self.children.is_empty() { + if (self.costs.unitary <= self.costs.recursive) && (self.depth() >= min_depth) { + self.children.clear(); + } else { + self.children.iter_mut().for_each(|c| c.trim(min_depth)); + } + } + } + + /// Sets the costs for the tree. + pub fn set_costs, M: Metric>(&mut self, data: &D, metric: &M) { + self.set_unitary_cost(data, metric); + if self.children.is_empty() { + self.costs.recursive = T::ZERO; + } else { + self.children.iter_mut().for_each(|c| c.set_costs(data, metric)); + self.set_recursive_cost(data, metric); + } + self.set_min_cost(); + } + + /// Calculates the unitary cost of the `Cluster`. + fn set_unitary_cost, M: Metric>(&mut self, data: &D, metric: &M) { + self.costs.unitary = self + .source + .iter_indices() + .map(|i| data.one_to_one(i, self.arg_center(), metric)) + .sum(); + } + + /// Calculates the recursive cost of the `Cluster`. + fn set_recursive_cost, M: Metric>(&mut self, data: &D, metric: &M) { + if self.children.is_empty() { + self.costs.recursive = T::ZERO; + } else { + let children = self.children(); + let child_costs = children.iter().map(|c| c.costs.minimum).sum::(); + let child_centers = children.iter().map(|c| c.arg_center()).collect::>(); + self.costs.recursive = child_costs + + data + .one_to_many(self.arg_center(), &child_centers, metric) + .map(|(_, d)| d) + .sum::(); + } + } + + /// Sets the minimum cost of the `Cluster`. 
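The trimming rule in `trim` (above) compares the two costs set here: a subtree collapses into a leaf when encoding every item flat against this cluster's center (the unitary cost) is no more expensive than recursing into children, and only once the cluster is at least `min_depth` deep. A toy illustration of that decision:

```rust
// `min_depth` guards against collapsing the tree too close to the root.
fn should_collapse(unitary: u32, recursive: u32, depth: usize, min_depth: usize) -> bool {
    unitary <= recursive && depth >= min_depth
}

fn main() {
    assert!(should_collapse(90, 100, 5, 4)); // flat encoding is cheaper
    assert!(!should_collapse(90, 100, 2, 4)); // too shallow to trim
}
```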
+ fn set_min_cost(&mut self) { + self.costs.minimum = if self.costs.recursive < self.costs.unitary { + self.costs.recursive + } else { + self.costs.unitary + }; + } +} + +impl> SquishyBall { + /// Sets the costs for the tree. + pub fn par_set_costs, M: ParMetric>(&mut self, data: &D, metric: &M) { + self.par_set_unitary_cost(data, metric); + if self.children.is_empty() { + self.costs.recursive = T::ZERO; + } else { + self.children.par_iter_mut().for_each(|c| c.par_set_costs(data, metric)); + self.par_set_recursive_cost(data, metric); + } + self.set_min_cost(); + } + + /// Calculates the unitary cost of the `Cluster`. + fn par_set_unitary_cost, M: ParMetric>(&mut self, data: &D, metric: &M) { + self.costs.unitary = self + .par_indices() + .map(|i| data.par_one_to_one(i, self.arg_center(), metric)) + .sum(); + } + + /// Calculates the recursive cost of the `Cluster`. + fn par_set_recursive_cost, M: ParMetric>(&mut self, data: &D, metric: &M) { + if self.children.is_empty() { + self.costs.recursive = T::ZERO; + } else { + let children = self.children(); + let child_costs = children.iter().map(|c| c.costs.minimum).sum::(); + let child_centers = children.iter().map(|c| c.arg_center()).collect::>(); + self.costs.recursive = child_costs + + data + .par_one_to_many(self.arg_center(), &child_centers, metric) + .map(|(_, d)| d) + .sum(); + } + } +} + +impl> Cluster for SquishyBall { + fn depth(&self) -> usize { + self.source.depth() + } + + fn cardinality(&self) -> usize { + self.source.cardinality() + } + + fn arg_center(&self) -> usize { + self.source.arg_center() + } + + fn set_arg_center(&mut self, arg_center: usize) { + self.source.set_arg_center(arg_center); + } + + fn radius(&self) -> T { + self.source.radius() + } + + fn arg_radial(&self) -> usize { + self.source.arg_radial() + } + + fn set_arg_radial(&mut self, arg_radial: usize) { + self.source.set_arg_radial(arg_radial); + } + + fn lfd(&self) -> f32 { + self.source.lfd() + } + + fn contains(&self, index: usize) -> bool { + self.source.contains(index) + } + + fn indices(&self) -> Vec { + self.source.indices() + } + + fn set_indices(&mut self, indices: &[usize]) { + self.source.set_indices(indices); + } + + fn extents(&self) -> &[(usize, T)] { + &self.source.extents()[..1] + } + + fn extents_mut(&mut self) -> &mut [(usize, T)] { + &mut self.source.extents_mut()[..1] + } + + fn add_extent(&mut self, idx: usize, extent: T) { + self.source.add_extent(idx, extent); + } + + fn take_extents(&mut self) -> Vec<(usize, T)> { + self.source.take_extents() + } + + fn children(&self) -> Vec<&Self> { + self.children.iter().map(AsRef::as_ref).collect() + } + + fn children_mut(&mut self) -> Vec<&mut Self> { + self.children.iter_mut().map(AsMut::as_mut).collect() + } + + fn set_children(&mut self, children: Vec>) { + self.children = children; + } + + fn take_children(&mut self) -> Vec> { + core::mem::take(&mut self.children) + } + + fn is_descendant_of(&self, other: &Self) -> bool { + self.source.is_descendant_of(&other.source) + } +} + +impl> ParCluster for SquishyBall { + fn par_indices(&self) -> impl ParallelIterator { + self.source.par_indices() + } +} + +/// Parameters for the `OffsetBall`. +#[derive(Debug, Default, Copy, Clone)] +#[cfg_attr( + feature = "disk-io", + derive(bitcode::Encode, bitcode::Decode, serde::Serialize, serde::Deserialize) +)] +pub struct SquishCosts { + /// Expected memory cost of recursive compression. + recursive: T, + /// Expected memory cost of unitary compression. 
+ unitary: T, + /// The minimum expected memory cost of compression. + minimum: T, +} + +impl, S: Cluster> Params for SquishCosts { + fn child_params>(&self, children: &[S], _: &D, _: &M) -> Vec { + children.iter().map(|_| Self::default()).collect() + } +} + +impl, S: ParCluster> ParParams + for SquishCosts +{ + fn par_child_params>(&self, children: &[S], data: &D, metric: &M) -> Vec { + self.child_params(children, data, metric) + } +} + +impl + Permutable> + BallAdapter, SquishCosts> for SquishyBall> +{ + fn from_ball_tree>(ball: Ball, data: D, metric: &M) -> (Self, CodecData) { + let (off_ball, data) = PermutedBall::from_ball_tree(ball, data, metric); + let mut root = + , _, _>>::adapt_tree_iterative(off_ball, None, &data, metric); + root.set_costs(&data, metric); + root.trim(4); + let data = CodecData::from_compressible(&data, &root); + (root, data) + } +} + +impl + Permutable> + ParBallAdapter, SquishCosts> for SquishyBall> +{ + fn par_from_ball_tree>(ball: Ball, data: D, metric: &M) -> (Self, CodecData) { + let (off_ball, data) = PermutedBall::par_from_ball_tree(ball, data, metric); + let mut root = , _, _>>::par_adapt_tree_iterative( + off_ball, None, &data, metric, + ); + root.par_set_costs(&data, metric); + root.trim(4); + let data = CodecData::par_from_compressible(&data, &root); + (root, data) + } +} + +impl, Dec: Decompressible, S: Cluster> + Adapter, SquishCosts> for SquishyBall +{ + fn new_adapted>( + source: PermutedBall, + children: Vec>, + params: SquishCosts, + _: &Co, + _: &M, + ) -> Self { + Self { + source, + costs: params, + children, + } + } + + fn post_traversal(&mut self) {} + + fn source(&self) -> &PermutedBall { + &self.source + } + + fn source_mut(&mut self) -> &mut PermutedBall { + &mut self.source + } + + fn take_source(self) -> PermutedBall { + self.source + } + + fn params(&self) -> &SquishCosts { + &self.costs + } +} + +impl< + I: Encodable + Decodable + Send + Sync, + T: Number, + Co: ParCompressible, + Dec: ParDecompressible, + S: ParCluster, + > ParAdapter, SquishCosts> for SquishyBall +{ + fn par_new_adapted>( + source: PermutedBall, + children: Vec>, + params: SquishCosts, + _: &Co, + _: &M, + ) -> Self { + Self { + source, + costs: params, + children, + } + } +} + +impl> PartialEq for SquishyBall { + fn eq(&self, other: &Self) -> bool { + self.source == other.source + } +} + +impl> Eq for SquishyBall {} + +impl> PartialOrd for SquishyBall { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl> Ord for SquishyBall { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.source.cmp(&other.source) + } +} + +impl> core::hash::Hash for SquishyBall { + fn hash(&self, state: &mut H) { + self.source.hash(state); + } +} + +#[cfg(feature = "disk-io")] +impl> crate::cluster::Csv for SquishyBall { + fn header(&self) -> Vec { + let mut header = self.source.header(); + header.extend(vec![ + "recursive_cost".to_string(), + "unitary_cost".to_string(), + "minimum_cost".to_string(), + ]); + header + } + + fn row(&self) -> Vec { + let mut row = self.source.row(); + row.pop(); + row.extend(vec![ + self.children.is_empty().to_string(), + self.costs.recursive.to_string(), + self.costs.unitary.to_string(), + self.costs.minimum.to_string(), + ]); + row + } +} + +#[cfg(feature = "disk-io")] +impl> crate::cluster::ParCsv for SquishyBall {} + +#[cfg(feature = "disk-io")] +impl> crate::cluster::ClusterIO for SquishyBall { + fn write_to>(&self, path: &P) -> Result<(), String> + where + Self: bitcode::Encode, + { + let bytes = 
bitcode::encode(self).map_err(|e| e.to_string())?; + let mut encoder = GzEncoder::new(Vec::new(), Compression::default()); + encoder.write_all(&bytes).map_err(|e| e.to_string())?; + let bytes = encoder.finish().map_err(|e| e.to_string())?; + std::fs::write(path, bytes).map_err(|e| e.to_string()) + } + + fn read_from>(path: &P) -> Result + where + Self: bitcode::Decode, + { + let mut bytes = Vec::new(); + let mut decoder = GzDecoder::new(std::fs::File::open(path).map_err(|e| e.to_string())?); + decoder.read_to_end(&mut bytes).map_err(|e| e.to_string())?; + bitcode::decode(&bytes).map_err(|e| e.to_string()) + } +} + +#[cfg(feature = "disk-io")] +impl> crate::cluster::ParClusterIO for SquishyBall {} diff --git a/crates/abd-clam/src/pancakes/dataset/codec_data.rs b/crates/abd-clam/src/pancakes/dataset/codec_data.rs new file mode 100644 index 000000000..df427b618 --- /dev/null +++ b/crates/abd-clam/src/pancakes/dataset/codec_data.rs @@ -0,0 +1,531 @@ +//! An implementation of the `Compression` and `Decompression` traits on a +//! `Dataset`. + +use std::collections::HashMap; + +use distances::Number; +use rayon::prelude::*; + +use crate::{ + cluster::ParCluster, + dataset::{AssociatesMetadata, AssociatesMetadataMut, ParDataset, Permutable}, + Cluster, Dataset, FlatVec, +}; + +use super::{ + super::SquishyBall, + compression::{Compressible, Encodable, ParCompressible}, + decompression::{Decodable, Decompressible, ParDecompressible}, +}; + +#[cfg(feature = "disk-io")] +use std::io::{Read, Write}; + +#[cfg(feature = "disk-io")] +use flate2::{read::GzDecoder, write::GzEncoder, Compression}; + +/// A compressed dataset, that can be partially decompressed for search and +/// other applications. +/// +/// # Type Parameters +/// +/// - `I`: The type of the items in the dataset. +/// - `Me`: The type of the metadata associated with the items. +#[derive(Clone)] +#[cfg_attr( + feature = "disk-io", + derive(bitcode::Encode, bitcode::Decode, serde::Serialize, serde::Deserialize) +)] +pub struct CodecData { + /// The cardinality of the dataset. + cardinality: usize, + /// A hint for the dimensionality of the dataset. + dimensionality_hint: (usize, Option), + /// The metadata associated with the items. + pub(crate) metadata: Vec, + /// The permutation of the original dataset. + permutation: Vec, + /// The name of the dataset. + name: String, + /// The centers of the clusters in the dataset. + center_map: HashMap, + /// The byte-slices representing the leaf clusters. + leaf_bytes: Vec<(usize, Box<[u8]>)>, +} + +impl CodecData { + /// Creates a `CodecData` from a compressible dataset and a `SquishyBall` tree. + pub fn from_compressible, S: Cluster>(data: &D, root: &SquishyBall) -> Self { + let center_map = root + .subtree() + .into_iter() + .map(Cluster::arg_center) + .map(|i| (i, data.get(i).clone())) + .collect(); + + let leaf_bytes = data + .encode_leaves(root) + .into_iter() + .map(|(leaf, bytes)| (leaf.offset(), bytes)) + .collect(); + + let cardinality = data.cardinality(); + let dimensionality_hint = data.dimensionality_hint(); + + Self { + cardinality, + dimensionality_hint, + metadata: (0..cardinality).collect(), + permutation: (0..cardinality).collect(), + name: format!("CodecData({})", data.name()), + center_map, + leaf_bytes, + } + } +} + +impl CodecData { + /// Parallel version of [`CodecData::from_compressible`](crate::pancakes::dataset::CodecData::from_compressible). 
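Putting the pieces together: `from_ball_tree` and `par_from_ball_tree` (earlier in this patch) produce the trimmed `SquishyBall` together with a `CodecData`, and `to_flat_vec` (below) decompresses it losslessly. A commented sketch, with the tree, dataset, metric, and module paths left as assumptions:

```rust
// Sketch only; every identifier below appears in this patch, but the
// surrounding setup (tree, data, metric) and module paths are assumed.
// let (root, codec) = SquishyBall::par_from_ball_tree(ball, data, &metric);
// let codec = codec.with_metadata(&names)?; // attach per-item metadata
// codec.par_write_to(&path)?;               // gzip + bitcode, on disk
// let restored = codec.par_to_flat_vec()?;  // lossless decompression
```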
+ pub fn par_from_compressible, S: ParCluster>( + data: &D, + root: &SquishyBall, + ) -> Self { + let center_map = root + .subtree() + .into_iter() + .map(Cluster::arg_center) + .map(|i| (i, data.get(i).clone())) + .collect(); + + let leaf_bytes = data + .par_encode_leaves(root) + .into_iter() + .map(|(leaf, bytes)| (leaf.offset(), bytes)) + .collect(); + + let cardinality = data.cardinality(); + let dimensionality_hint = data.dimensionality_hint(); + + Self { + cardinality, + dimensionality_hint, + metadata: (0..cardinality).collect(), + permutation: (0..cardinality).collect(), + name: format!("CodecData({})", data.name()), + center_map, + leaf_bytes, + } + } +} + +impl CodecData { + /// Changes the permutation of the dataset without changing the order of the + /// items. + #[must_use] + pub fn with_permutation(self, permutation: &[usize]) -> Self { + Self { + cardinality: self.cardinality, + dimensionality_hint: self.dimensionality_hint, + metadata: self.metadata, + permutation: permutation.to_vec(), + name: self.name, + center_map: self.center_map, + leaf_bytes: self.leaf_bytes, + } + } + + /// Returns the permutation of the original dataset. + #[must_use] + pub fn permutation(&self) -> &[usize] { + &self.permutation + } + + /// Returns the center map of the dataset. + /// + /// This is a map from the index of the center (in the decompressed version) + /// to the center itself. + #[must_use] + pub const fn center_map(&self) -> &HashMap { + &self.center_map + } + + /// Returns the leaf bytes of the dataset. + /// + /// This is an array of tuples of the offset of the leaf cluster and the + /// compressed bytes of the items in the leaf cluster. + #[must_use] + pub fn leaf_bytes(&self) -> &[(usize, Box<[u8]>)] { + &self.leaf_bytes + } + + /// Transforms the centers of the dataset using the given function. + pub fn transform_centers It>(self, transformer: F) -> CodecData { + let center_map = self + .center_map + .into_iter() + .map(|(i, center)| (i, transformer(center))) + .collect(); + + CodecData { + cardinality: self.cardinality, + dimensionality_hint: self.dimensionality_hint, + metadata: self.metadata, + permutation: self.permutation, + name: self.name, + center_map, + leaf_bytes: self.leaf_bytes, + } + } +} + +impl CodecData { + /// Decompresses the dataset into a `FlatVec`. + /// + /// # Errors + /// + /// - If the `FlatVec` cannot be created from the decompressed items. + pub fn to_flat_vec(&self) -> Result, String> { + let items = self + .leaf_bytes + .iter() + .flat_map(|(_, bytes)| self.decode_leaf(bytes.as_ref())) + .collect::>(); + + let (min_dim, max_dim) = self.dimensionality_hint; + + let data = FlatVec::new(items)? + .with_name(&self.name) + .with_metadata(&self.metadata)? + .with_permutation(&self.permutation) + .with_dim_lower_bound(min_dim); + + let data = if let Some(max_dim) = max_dim { + data.with_dim_upper_bound(max_dim) + } else { + data + }; + + Ok(data) + } +} + +impl CodecData { + /// Parallel version of [`CodecData::to_flat_vec`](crate::pancakes::dataset::CodecData::to_flat_vec). + /// + /// # Errors + /// + /// See [`CodecData::to_flat_vec`](crate::pancakes::dataset::CodecData::to_flat_vec). + pub fn par_to_flat_vec(&self) -> Result, String> { + let items = self + .leaf_bytes + .par_iter() + .flat_map(|(_, bytes)| self.decode_leaf(bytes.as_ref())) + .collect::>(); + + let (min_dim, max_dim) = self.dimensionality_hint; + + let data = FlatVec::new(items)? + .with_name(&self.name) + .with_metadata(&self.metadata)? 
+ .with_permutation(&self.permutation) + .with_dim_lower_bound(min_dim); + + let data = if let Some(max_dim) = max_dim { + data.with_dim_upper_bound(max_dim) + } else { + data + }; + + Ok(data) + } +} + +impl Dataset for CodecData { + fn name(&self) -> &str { + &self.name + } + + fn with_name(mut self, name: &str) -> Self { + self.name = format!("CodecData({name})"); + self + } + + fn cardinality(&self) -> usize { + self.cardinality + } + + fn dimensionality_hint(&self) -> (usize, Option) { + self.dimensionality_hint + } + + #[allow(clippy::panic)] + fn get(&self, index: usize) -> &I { + self.center_map.get(&index).map_or_else( + || panic!("For CodecData, the `get` method may only be used for cluster centers."), + |center| center, + ) + } +} + +impl ParDataset for CodecData {} + +impl AssociatesMetadata for CodecData { + fn metadata(&self) -> &[Me] { + &self.metadata + } + + fn metadata_at(&self, index: usize) -> &Me { + &self.metadata[index] + } +} + +impl AssociatesMetadataMut> for CodecData { + fn metadata_mut(&mut self) -> &mut [Me] { + &mut self.metadata + } + + fn metadata_at_mut(&mut self, index: usize) -> &mut Me { + &mut self.metadata[index] + } + + fn with_metadata(self, metadata: &[Met]) -> Result, String> { + if metadata.len() == self.cardinality { + let mut metadata = metadata.to_vec(); + metadata.permute(&self.permutation); + Ok(CodecData { + cardinality: self.cardinality, + dimensionality_hint: self.dimensionality_hint, + metadata, + permutation: self.permutation, + name: self.name, + center_map: self.center_map, + leaf_bytes: self.leaf_bytes, + }) + } else { + Err(format!( + "The length of the metadata vector ({}) does not match the cardinality of the dataset ({}).", + metadata.len(), + self.cardinality + )) + } + } + + fn transform_metadata Met>(self, f: F) -> CodecData { + let metadata = self.metadata.iter().map(f).collect(); + CodecData { + cardinality: self.cardinality, + dimensionality_hint: self.dimensionality_hint, + metadata, + permutation: self.permutation, + name: self.name, + center_map: self.center_map, + leaf_bytes: self.leaf_bytes, + } + } +} + +#[cfg(feature = "disk-io")] +/// Encodes using bitcode and compresses using Gzip. +/// +/// # Errors +/// +/// - If the item cannot be encoded. +/// - If the encoded bytes cannot be compressed. +fn encode_and_compress(item: T) -> Result, String> { + let buf = bitcode::encode(&item).map_err(|e| e.to_string())?; + let mut encoder = GzEncoder::new(Vec::new(), Compression::default()); + encoder.write_all(&buf).map_err(|e| e.to_string())?; + encoder.finish().map_err(|e| e.to_string()) +} + +#[cfg(feature = "disk-io")] +/// Decompresses using Gzip and decodes using bitcode. +/// +/// # Errors +/// +/// - If the bytes cannot be decompressed. +/// - If the decompressed bytes cannot be decoded. 
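`encode_and_compress` and `decompress_and_decode` are inverse pipelines: bitcode then gzip one way, gzip then bitcode the other. The gzip halves round-trip like this; the bitcode step is omitted so the sketch stays dependency-light:

```rust
use std::io::{Read, Write};

use flate2::{read::GzDecoder, write::GzEncoder, Compression};

fn main() -> Result<(), String> {
    let payload = b"0123456789".repeat(100);
    // Compress, as in `encode_and_compress` (minus the bitcode step).
    let mut enc = GzEncoder::new(Vec::new(), Compression::default());
    enc.write_all(&payload).map_err(|e| e.to_string())?;
    let packed = enc.finish().map_err(|e| e.to_string())?;
    // Decompress, as in `decompress_and_decode`.
    let mut dec = GzDecoder::new(packed.as_slice());
    let mut unpacked = Vec::new();
    dec.read_to_end(&mut unpacked).map_err(|e| e.to_string())?;
    assert_eq!(unpacked, payload);
    Ok(())
}
```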
+fn decompress_and_decode(bytes: &[u8]) -> Result { + let mut decoder = GzDecoder::new(bytes); + let mut buf = Vec::new(); + decoder.read_to_end(&mut buf).map_err(|e| e.to_string())?; + bitcode::decode(&buf).map_err(|e| e.to_string()) +} + +#[cfg(feature = "disk-io")] +impl + crate::dataset::DatasetIO for CodecData +{ + fn write_to>(&self, path: &P) -> Result<(), String> { + let metadata_bytes = encode_and_compress(&self.metadata)?; + + let permutation_bytes = encode_and_compress(&self.permutation)?; + + let center_map = self + .center_map + .iter() + .map(|(&i, p)| (i, p.as_bytes())) + .collect::>(); + let center_map_bytes = encode_and_compress(center_map)?; + + let leaf_bytes = encode_and_compress(&self.leaf_bytes)?; + + let members = ( + self.cardinality, + self.dimensionality_hint, + metadata_bytes, + permutation_bytes, + self.name.clone(), + center_map_bytes, + leaf_bytes, + ); + let bytes = bitcode::encode(&members).map_err(|e| e.to_string())?; + std::fs::write(path, &bytes).map_err(|e| e.to_string()) + } + + fn read_from>(path: &P) -> Result { + let bytes = std::fs::read(path).map_err(|e| e.to_string())?; + + #[allow(clippy::type_complexity)] + let (cardinality, dimensionality_hint, metadata_bytes, permutation_bytes, name, center_map_bytes, leaf_bytes): ( + usize, + (usize, Option), + Vec, + Vec, + String, + Vec, + Vec, + ) = bitcode::decode(&bytes).map_err(|e| e.to_string())?; + + let metadata: Vec = decompress_and_decode(&metadata_bytes)?; + let permutation: Vec = decompress_and_decode(&permutation_bytes)?; + + let center_map: Vec<(usize, Box<[u8]>)> = decompress_and_decode(¢er_map_bytes)?; + let center_map = center_map + .into_iter() + .map(|(i, bytes)| (i, I::from_bytes(&bytes))) + .collect(); + + let leaf_bytes = decompress_and_decode(&leaf_bytes)?; + + Ok(Self { + cardinality, + dimensionality_hint, + metadata, + permutation, + name, + center_map, + leaf_bytes, + }) + } +} + +#[cfg(feature = "disk-io")] +impl< + I: Decodable + bitcode::Encode + bitcode::Decode + Send + Sync, + Me: bitcode::Encode + bitcode::Decode + Send + Sync, + > crate::dataset::ParDatasetIO for CodecData +{ + fn par_write_to>(&self, path: &P) -> Result<(), String> { + let center_map = self + .center_map + .par_iter() + .map(|(&i, p)| (i, p.as_bytes())) + .collect::>(); + + let ((metadata_bytes, center_map_bytes), (permutation_bytes, leaf_bytes)) = rayon::join( + || { + rayon::join( + || encode_and_compress(&self.metadata), + || encode_and_compress(center_map), + ) + }, + || { + rayon::join( + || encode_and_compress(&self.permutation), + || encode_and_compress(&self.leaf_bytes), + ) + }, + ); + + let (metadata_bytes, center_map_bytes, permutation_bytes, leaf_bytes) = + (metadata_bytes?, center_map_bytes?, permutation_bytes?, leaf_bytes?); + + let members = ( + self.cardinality, + self.dimensionality_hint, + metadata_bytes, + permutation_bytes, + self.name.clone(), + center_map_bytes, + leaf_bytes, + ); + + let bytes = bitcode::encode(&members).map_err(|e| e.to_string())?; + std::fs::write(path, &bytes).map_err(|e| e.to_string()) + } + + fn par_read_from>(path: &P) -> Result { + let bytes = std::fs::read(path).map_err(|e| e.to_string())?; + + #[allow(clippy::type_complexity)] + let (cardinality, dimensionality_hint, metadata_bytes, permutation_bytes, name, center_map_bytes, leaf_bytes): ( + usize, + (usize, Option), + Vec, + Vec, + String, + Vec, + Vec, + ) = bitcode::decode(&bytes).map_err(|e| e.to_string())?; + + #[allow(clippy::type_complexity)] + let ((metadata_bytes, center_map_bytes), (permutation, 
leaf_bytes)): (
+            (Result<Vec<Me>, String>, Result<Vec<(usize, Box<[u8]>)>, String>),
+            (Result<Vec<usize>, String>, Result<Vec<(usize, Box<[u8]>)>, String>),
+        ) = rayon::join(
+            || {
+                rayon::join(
+                    || decompress_and_decode(&metadata_bytes),
+                    || decompress_and_decode(&center_map_bytes),
+                )
+            },
+            || {
+                rayon::join(
+                    || decompress_and_decode(&permutation_bytes),
+                    || decompress_and_decode(&leaf_bytes),
+                )
+            },
+        );
+
+        let (metadata, center_map_bytes, permutation, leaf_bytes) =
+            (metadata_bytes?, center_map_bytes?, permutation?, leaf_bytes?);
+
+        let center_map = center_map_bytes
+            .into_par_iter()
+            .map(|(i, bytes)| (i, I::from_bytes(&bytes)))
+            .collect();
+
+        Ok(Self {
+            cardinality,
+            dimensionality_hint,
+            metadata,
+            permutation,
+            name,
+            center_map,
+            leaf_bytes,
+        })
+    }
+}
+
+impl<I: Decodable, Me> Decompressible<I> for CodecData<I, Me> {
+    fn centers(&self) -> &HashMap<usize, I> {
+        &self.center_map
+    }
+
+    fn leaf_bytes(&self) -> &[(usize, Box<[u8]>)] {
+        &self.leaf_bytes
+    }
+}
+
+impl<I: Decodable + Send + Sync, Me: Send + Sync> ParDecompressible<I> for CodecData<I, Me> {}
diff --git a/crates/abd-clam/src/cakes/codec/compression.rs b/crates/abd-clam/src/pancakes/dataset/compression.rs
similarity index 61%
rename from crates/abd-clam/src/cakes/codec/compression.rs
rename to crates/abd-clam/src/pancakes/dataset/compression.rs
index fc02ef8d7..ce2e44f0c 100644
--- a/crates/abd-clam/src/cakes/codec/compression.rs
+++ b/crates/abd-clam/src/pancakes/dataset/compression.rs
@@ -5,8 +5,10 @@
 use rayon::prelude::*;
 
 use crate::{cluster::ParCluster, dataset::ParDataset, Cluster, Dataset};
 
-/// A trait that defines how a value can be encoded in terms of a reference.
-pub trait Encodable: Clone {
+/// For items that can be encoded into a byte array or in terms of a reference.
+///
+/// We provide a blanket implementation for all types that implement `Number`.
+pub trait Encodable {
     /// Converts the value to a byte array.
     fn as_bytes(&self) -> Box<[u8]>;
 
@@ -14,17 +16,14 @@ pub trait Encodable: Clone {
     fn encode(&self, reference: &Self) -> Box<[u8]>;
 }
 
-/// A trait that defines how a dataset can be compressed.
-pub trait Compressible<I: Encodable, U: Number>: Dataset<I, U> + Sized {
-    /// Encodes all the instances of leaf clusters in terms of their centers.
+/// Given `Encodable` items, a dataset can be compressed.
+pub trait Compressible<I: Encodable>: Dataset<I> {
+    /// Encodes all the items of leaf clusters in terms of their centers.
     ///
     /// # Returns
     ///
-    /// - A vector of byte arrays, each containing the encoded instances of a leaf cluster.
-    fn encode_leaves<'a, D: Dataset<I, U>, C: Cluster<I, U, D>>(&self, root: &'a C) -> Vec<(&'a C, Box<[u8]>)>
-    where
-        U: 'a,
-    {
+    /// - A vector of byte arrays, each containing the encoded items of a leaf cluster.
+    fn encode_leaves<'a, T: Number + 'a, C: Cluster<T>>(&self, root: &'a C) -> Vec<(&'a C, Box<[u8]>)> {
         root.leaves()
             .into_iter()
             .map(|leaf| {
@@ -43,20 +42,10 @@ pub trait Compressible<I: Encodable, U: Number>: Dataset<I, U> + Sized {
             }
         }
 
-/// A trait that defines how a dataset can be compressed.
-pub trait ParCompressible<I: Encodable + Send + Sync, U: Number>: Compressible<I, U> + ParDataset<I, U> + Sized {
-    /// Encodes all the instances of leaf clusters in terms of their centers.
-    ///
-    /// # Returns
-    ///
-    /// - A flattened vector of encoded instances.
-    /// - A vector of offsets that indicate the start of the instances for each
-    ///   leaf cluster in the flattened vector.
-    /// - A vector of cumulative cardinalities of leaves.
-    fn par_encode_leaves<'a, D: ParDataset<I, U>, C: ParCluster<I, U, D>>(&self, root: &'a C) -> Vec<(&'a C, Box<[u8]>)>
-    where
-        U: 'a,
-    {
+/// Parallel version of [`Compressible`](crate::pancakes::dataset::compression::Compressible).
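For items that are not `Number`s, `Encodable` is implemented by hand. As a minimal sketch of the idea (not part of this patch): the `ByteSeq` type and its XOR delta scheme below are invented purely for illustration, and inherent methods stand in for the trait methods since the real trait lives in this crate.

```rust
// `ByteSeq` is a hypothetical item type, invented for illustration.
struct ByteSeq(Vec<u8>);

impl ByteSeq {
    // Stand-alone encoding: just the raw bytes.
    fn as_bytes(&self) -> Box<[u8]> {
        self.0.clone().into_boxed_slice()
    }

    // Reference-based encoding: XOR each byte against the reference
    // (zero past its end), so shared spans become runs of zeros that
    // gzip squeezes well. A real codec would also record lengths.
    fn encode(&self, reference: &Self) -> Box<[u8]> {
        self.0
            .iter()
            .zip(reference.0.iter().chain(std::iter::repeat(&0u8)))
            .map(|(a, b)| a ^ b)
            .collect()
    }
}
```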
+pub trait ParCompressible: Compressible + ParDataset { + /// Parallel version of [`Compressible::encode_leaves`](crate::pancakes::dataset::compression::Compressible::encode_leaves). + fn par_encode_leaves<'a, T: Number + 'a, C: ParCluster>(&self, root: &'a C) -> Vec<(&'a C, Box<[u8]>)> { root.leaves() .into_par_iter() .map(|leaf| { @@ -74,3 +63,13 @@ pub trait ParCompressible: Compressible Encodable for T { + fn as_bytes(&self) -> Box<[u8]> { + self.to_le_bytes().into_boxed_slice() + } + + fn encode(&self, _: &Self) -> Box<[u8]> { + self.as_bytes() + } +} diff --git a/crates/abd-clam/src/cakes/codec/decompression.rs b/crates/abd-clam/src/pancakes/dataset/decompression.rs similarity index 51% rename from crates/abd-clam/src/cakes/codec/decompression.rs rename to crates/abd-clam/src/pancakes/dataset/decompression.rs index 11034d912..3670dd48d 100644 --- a/crates/abd-clam/src/cakes/codec/decompression.rs +++ b/crates/abd-clam/src/pancakes/dataset/decompression.rs @@ -2,12 +2,14 @@ use std::collections::HashMap; -use distances::Number; - use crate::{dataset::ParDataset, Dataset}; -/// A trait that defines how a value can be decoded in terms of a reference. -pub trait Decodable { +use super::Encodable; + +/// For items that can be decoded from a byte array or in terms of a reference. +/// +/// We provide a blanket implementation for all types that implement `Number`. +pub trait Decodable: Encodable { /// Decodes the value from a byte array. fn from_bytes(bytes: &[u8]) -> Self; @@ -15,8 +17,8 @@ pub trait Decodable { fn decode(reference: &Self, bytes: &[u8]) -> Self; } -/// A trait that defines how a dataset can be decompressed. -pub trait Decompressible: Dataset + Sized { +/// Given `Decodable` items, a compressed dataset can be decompressed. +pub trait Decompressible: Dataset { /// Returns the centers of the clusters in the tree associated with this /// dataset. fn centers(&self) -> &HashMap; @@ -25,9 +27,9 @@ pub trait Decompressible: Dataset + Sized { /// bytes. fn leaf_bytes(&self) -> &[(usize, Box<[u8]>)]; - /// Decodes all the instances of a leaf cluster in terms of its center. + /// Decodes all the items of a leaf cluster in terms of its center. fn decode_leaf(&self, bytes: &[u8]) -> Vec { - let mut instances = Vec::new(); + let mut items = Vec::new(); let mut offset = 0; let arg_center = crate::utils::read_number::(bytes, &mut offset); @@ -37,18 +39,28 @@ pub trait Decompressible: Dataset + Sized { for _ in 0..cardinality { let encoding = crate::utils::read_encoding(bytes, &mut offset); - let instance = I::decode(center, &encoding); - instances.push(instance); + let item = I::decode(center, &encoding); + items.push(item); } - instances + items } } -/// Parallel version of the `Decompressible` trait. -pub trait ParDecompressible: Decompressible + ParDataset { - /// Parallel version of the `decode_leaf` method. +/// Parallel version of [`Decompressible`](crate::pancakes::dataset::decompression::Decompressible). +pub trait ParDecompressible: Decompressible + ParDataset { + /// Parallel version of [`Decompressible::decode_leaf`](crate::pancakes::dataset::decompression::Decompressible::decode_leaf). 
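The blanket impls here (and the matching `Decodable` one below) pair `to_le_bytes` with `from_le_bytes`, so numeric items round-trip losslessly and the reference item is ignored. A quick sanity check of that behavior, using only `std`:

```rust
fn main() {
    let x: f32 = 3.5;
    let bytes = x.to_le_bytes(); // what `as_bytes` boxes up for a `Number`
    let y = f32::from_le_bytes(bytes); // what `from_bytes` reads back
    assert_eq!(x.to_bits(), y.to_bits()); // lossless round-trip
}
```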
fn par_decode_leaf(&self, bytes: &[u8]) -> Vec { self.decode_leaf(bytes) } } + +impl Decodable for T { + fn from_bytes(bytes: &[u8]) -> Self { + Self::from_le_bytes(bytes) + } + + fn decode(_: &Self, bytes: &[u8]) -> Self { + Self::from_bytes(bytes) + } +} diff --git a/crates/abd-clam/src/pancakes/dataset/mod.rs b/crates/abd-clam/src/pancakes/dataset/mod.rs new file mode 100644 index 000000000..48776b8da --- /dev/null +++ b/crates/abd-clam/src/pancakes/dataset/mod.rs @@ -0,0 +1,112 @@ +//! Datasets for `PanCAKES`. + +use std::collections::HashMap; + +use distances::Number; +use rayon::prelude::*; + +use crate::{ + cakes::{HintedDataset, ParHintedDataset, ParSearchable, Searchable}, + cluster::ParCluster, + metric::ParMetric, + Cluster, Dataset, FlatVec, Metric, +}; + +use super::SquishyBall; + +mod codec_data; +mod compression; +mod decompression; + +pub use codec_data::CodecData; +pub use compression::{Compressible, Encodable, ParCompressible}; +pub use decompression::{Decodable, Decompressible, ParDecompressible}; + +impl Compressible for FlatVec {} +impl ParCompressible for FlatVec {} + +impl, M: Metric, Me> Searchable, M> + for CodecData +{ + fn query_to_center(&self, metric: &M, query: &I, cluster: &SquishyBall) -> T { + metric.distance(query, self.get(cluster.arg_center())) + } + + fn query_to_all(&self, metric: &M, query: &I, cluster: &SquishyBall) -> impl Iterator { + let leaf_bytes = self.leaf_bytes(); + + cluster + .leaves() + .into_iter() + .map(SquishyBall::offset) + .map(|o| { + leaf_bytes + .iter() + .position(|(off, _)| *off == o) + .unwrap_or_else(|| unreachable!("Offset not found in leaf offsets: {o}, {:?}", self.leaf_bytes())) + }) + .map(|pos| &leaf_bytes[pos]) + .flat_map(|(o, bytes)| { + self.decode_leaf(bytes) + .into_iter() + .enumerate() + .map(|(i, p)| (i + *o, p)) + }) + .map(|(i, p)| (i, metric.distance(query, &p))) + } +} + +impl, M: ParMetric, Me: Send + Sync> + ParSearchable, M> for CodecData +{ + fn par_query_to_center(&self, metric: &M, query: &I, cluster: &SquishyBall) -> T { + metric.par_distance(query, self.get(cluster.arg_center())) + } + + fn par_query_to_all( + &self, + metric: &M, + query: &I, + cluster: &SquishyBall, + ) -> impl rayon::prelude::ParallelIterator { + let leaf_bytes = self.leaf_bytes(); + + cluster + .leaves() + .into_par_iter() + .map(SquishyBall::offset) + .map(|o| { + leaf_bytes + .iter() + .position(|(off, _)| *off == o) + .unwrap_or_else(|| unreachable!("Offset not found in leaf offsets: {o}, {:?}", self.leaf_bytes())) + }) + .map(|pos| &leaf_bytes[pos]) + .flat_map(|(o, bytes)| { + self.decode_leaf(bytes) + .into_par_iter() + .enumerate() + .map(|(i, p)| (i + *o, p)) + }) + .map(|(i, p)| (i, metric.par_distance(query, &p))) + } +} + +#[allow(clippy::implicit_hasher)] +impl, M: Metric, Me> HintedDataset, M> + for CodecData)> +{ + fn hints_for(&self, i: usize) -> &HashMap { + &self.metadata[i].1 + } + + fn hints_for_mut(&mut self, i: usize) -> &mut HashMap { + &mut self.metadata[i].1 + } +} + +#[allow(clippy::implicit_hasher)] +impl, M: ParMetric, Me: Send + Sync> + ParHintedDataset, M> for CodecData)> +{ +} diff --git a/crates/abd-clam/src/pancakes/mod.rs b/crates/abd-clam/src/pancakes/mod.rs new file mode 100644 index 000000000..eb5819678 --- /dev/null +++ b/crates/abd-clam/src/pancakes/mod.rs @@ -0,0 +1,8 @@ +//! 
Compression and Decompression with CLAM
+
+mod cluster;
+mod dataset;
+mod sequence;
+
+pub use cluster::{SquishCosts, SquishyBall};
+pub use dataset::{CodecData, Compressible, Decodable, Decompressible, Encodable, ParCompressible, ParDecompressible};
diff --git a/crates/abd-clam/src/pancakes/sequence.rs b/crates/abd-clam/src/pancakes/sequence.rs
new file mode 100644
index 000000000..3d8da23fd
--- /dev/null
+++ b/crates/abd-clam/src/pancakes/sequence.rs
@@ -0,0 +1,178 @@
+//! Implementing `Encodable` and `Decodable` for `Sequence`.
+
+use distances::Number;
+
+use crate::msa::{
+    ops::{Edit, Edits},
+    Sequence,
+};
+
+use super::{Decodable, Encodable};
+
+/// This uses the Needleman-Wunsch algorithm to encode strings.
+impl<T: Number> Encodable for Sequence<'_, T> {
+    fn as_bytes(&self) -> Box<[u8]> {
+        self.seq().as_bytes().to_vec().into_boxed_slice()
+    }
+
+    fn encode(&self, reference: &Self) -> Box<[u8]> {
+        self.aligner().map_or_else(
+            || self.as_bytes(),
+            |aligner| {
+                let table = aligner.dp_table(self, reference);
+                let [s_to_r, r_to_s] = aligner.edits(self, reference, &table);
+
+                let s_check = apply_edits(self.seq(), &s_to_r);
+                assert_eq!(s_check, reference.seq(), "From {}, s_to_r: {s_to_r:?}", self.seq());
+
+                let r_check = apply_edits(reference.seq(), &r_to_s);
+                assert_eq!(r_check, self.seq(), "From {}, r_to_s: {r_to_s:?}", reference.seq());
+
+                serialize_edits(&r_to_s)
+            },
+        )
+    }
+}
+
+/// This uses the Needleman-Wunsch algorithm to decode strings.
+impl<T: Number> Decodable for Sequence<'_, T> {
+    fn from_bytes(bytes: &[u8]) -> Self {
+        let seq =
+            String::from_utf8(bytes.to_vec()).unwrap_or_else(|e| unreachable!("Could not cast back to string: {e:?}"));
+        Self::new(seq, None)
+    }
+
+    fn decode(reference: &Self, bytes: &[u8]) -> Self {
+        let edits = deserialize_edits(bytes);
+        let seq = apply_edits(reference.seq(), &edits);
+        Self::new(seq, reference.aligner())
+    }
+}
+
+/// Applies a set of edits to a reference (unaligned) string to get a target (unaligned) string.
+///
+/// # Arguments
+///
+/// * `x`: The unaligned reference string.
+/// * `edits`: The edits to apply to the reference string.
+///
+/// # Returns
+///
+/// The unaligned target string.
+#[must_use]
+pub fn apply_edits(x: &str, edits: &Edits) -> String {
+    let mut x: Vec<u8> = x.as_bytes().to_vec();
+
+    for (i, edit) in edits.as_ref() {
+        match edit {
+            Edit::Sub(c) => {
+                x[*i] = *c;
+            }
+            Edit::Ins(c) => {
+                x.insert(*i, *c);
+            }
+            Edit::Del => {
+                x.remove(*i);
+            }
+        }
+    }
+
+    String::from_utf8(x).unwrap_or_else(|e| unreachable!("Could not cast back to string: {e:?}"))
+}
+
+/// Serializes a vector of edit operations into a byte array.
+fn serialize_edits(edits: &Edits) -> Box<[u8]> {
+    let bytes = edits.as_ref().iter().flat_map(edit_to_bin).collect::<Vec<_>>();
+    bytes.into_boxed_slice()
+}
+
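To make the edit semantics above concrete: a tiny worked example of applying `Sub`, `Ins`, and `Del` in order, using a plain `Vec<u8>` stand-in rather than the crate's `Sequence` type:

```rust
fn main() {
    let mut x = b"ACTG".to_vec();
    x[1] = b'G';       // Sub(b'G') at index 1 -> "AGTG"
    x.insert(2, b'N'); // Ins(b'N') at index 2 -> "AGNTG"
    x.remove(4);       // Del at index 4       -> "AGNT"
    assert_eq!(x, b"AGNT");
}
```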
+/// Encodes an edit operation into a byte array.
+///
+/// A `Del` edit is encoded as `10` followed by the index of the edit in 14 bits.
+/// An `Ins` edit is encoded as `01` followed by the index of the edit in 14 bits and the character in 8 bits.
+/// A `Sub` edit is encoded as `11` followed by the index of the edit in 14 bits and the character in 8 bits.
+///
+/// # Arguments
+///
+/// * `edit`: The edit operation.
+///
+/// # Returns
+///
+/// A byte array encoding the edit operation.
+#[allow(clippy::cast_possible_truncation)]
+fn edit_to_bin((i, edit): &(usize, Edit)) -> Vec<u8> {
+    let mask_idx = 0b00_111111;
+    let mask_del = 0b10_000000;
+    let mask_ins = 0b01_000000;
+    let mask_sub = 0b11_000000;
+
+    // First 2 bits for the type of edit, 14 bits for the index.
+    let mut bytes = (*i as u16).to_be_bytes().to_vec();
+    bytes[0] &= mask_idx;
+
+    match edit {
+        Edit::Del => {
+            bytes[0] |= mask_del;
+        }
+        Edit::Ins(c) => {
+            bytes[0] |= mask_ins;
+            // 8 bits for the character.
+            bytes.push(*c);
+        }
+        Edit::Sub(c) => {
+            bytes[0] |= mask_sub;
+            // 8 bits for the character.
+            bytes.push(*c);
+        }
+    }
+    bytes
+}
+
+/// Deserializes a byte array into a vector of edit operations.
+///
+/// A `Del` edit is encoded as `10` followed by the index of the edit in 14 bits.
+/// An `Ins` edit is encoded as `01` followed by the index of the edit in 14 bits and the character in 8 bits.
+/// A `Sub` edit is encoded as `11` followed by the index of the edit in 14 bits and the character in 8 bits.
+///
+/// # Arguments
+///
+/// * `bytes`: The byte array encoding the edit operations.
+///
+/// # Panics
+///
+/// * If the byte array is not a valid encoding of edit operations.
+/// * If the edit type is not recognized.
+///
+/// # Returns
+///
+/// A vector of edit operations.
+fn deserialize_edits(bytes: &[u8]) -> Edits {
+    let mut edits = Vec::new();
+    let mut offset = 0;
+    let mask_idx = 0b00_111111;
+
+    while offset < bytes.len() {
+        let edit_bits = bytes[offset] & !mask_idx;
+        let i = u16::from_be_bytes([bytes[offset] & mask_idx, bytes[offset + 1]]) as usize;
+        let edit = match edit_bits {
+            0b10_000000 => {
+                offset += 2;
+                Edit::Del
+            }
+            0b01_000000 => {
+                let c = bytes[offset + 2];
+                offset += 3;
+                Edit::Ins(c)
+            }
+            0b11_000000 => {
+                let c = bytes[offset + 2];
+                offset += 3;
+                Edit::Sub(c)
+            }
+            _ => unreachable!("Invalid edit type: {edit_bits:b}."),
+        };
+        edits.push((i, edit));
+    }
+
+    Edits::from(edits)
+}
diff --git a/crates/abd-clam/src/utils.rs b/crates/abd-clam/src/utils.rs
index b0710c1df..13acffb32 100644
--- a/crates/abd-clam/src/utils.rs
+++ b/crates/abd-clam/src/utils.rs
@@ -3,17 +3,78 @@
 use core::cmp::Ordering;
 
 use distances::{number::Float, Number};
+use rand::prelude::*;
+
+/// The square root threshold for sub-sampling.
+pub(crate) const SQRT_THRESH: usize = 1000;
+/// The logarithmic threshold for sub-sampling.
+pub(crate) const LOG2_THRESH: usize = 100_000;
+
+/// Reads the `MAX_RECURSION_DEPTH` environment variable to determine the
+/// stride for iterative partition and adaptation.
+#[must_use]
+pub fn max_recursion_depth() -> usize {
+    std::env::var("MAX_RECURSION_DEPTH")
+        .ok()
+        .and_then(|s| s.parse().ok())
+        .unwrap_or(128)
+}
+
+/// Return the number of samples to take from the given population size so as to
+/// achieve linear time complexity for geometric median estimation.
+///
+/// The number of samples is calculated as follows:
+///
+/// - The first `sqrt_thresh` samples are taken from the population.
+/// - From the next `log2_thresh` samples, `n` samples are taken, where `n`
+///   is the square root of the population size minus `sqrt_thresh`.
+/// - From the remaining samples, `n` samples are taken, where `n` is the
+///   logarithm base 2 of the population size minus (`log2_thresh` plus
+///   `sqrt_thresh`).
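A worked instance of the schedule described above (before the implementation below), using the default thresholds `SQRT_THRESH = 1000` and `LOG2_THRESH = 100_000`, and assuming the `as_usize` conversion truncates: a population of 10,000 falls in the middle band.

```rust
fn main() {
    let (population, sqrt_thresh) = (10_000_f64, 1_000_f64);
    // Middle band: keep the first 1_000 samples, then the square root of the rest.
    let n = sqrt_thresh + (population - sqrt_thresh).sqrt();
    assert_eq!(n as usize, 1_094); // sqrt(9_000) ~ 94.87, truncated
}
```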
+#[must_use] +pub fn num_samples(population_size: usize, sqrt_thresh: usize, log2_thresh: usize) -> usize { + if population_size < sqrt_thresh { + population_size + } else { + sqrt_thresh + + if population_size < sqrt_thresh + log2_thresh { + (population_size - sqrt_thresh).as_f64().sqrt() + } else { + log2_thresh.as_f64().sqrt() + (population_size - sqrt_thresh - log2_thresh).as_f64().log2() + } + .as_usize() + } +} + +/// Choose a subset of the given items using the given thresholds. +/// +/// See the `num_samples` function for more information on how the number of +/// samples is calculated. +pub fn choose_samples(indices: &[T], sqrt_thresh: usize, log2_thresh: usize) -> Vec { + let mut indices = indices.to_vec(); + let n = crate::utils::num_samples(indices.len(), sqrt_thresh, log2_thresh); + indices.shuffle(&mut rand::thread_rng()); + indices.truncate(n); + indices +} + +/// Returns the number of distinct pairs that can be formed from `n` elements +/// without repetition. +#[must_use] +pub const fn n_pairs(n: usize) -> usize { + n * (n - 1) / 2 +} /// Return the index and value of the minimum value in the given slice of values. /// /// NAN values are ordered as greater than all other values. /// /// This will return `None` if the given slice is empty. -pub fn arg_min(values: &[T]) -> Option<(usize, T)> { +pub fn arg_min(values: &[T]) -> Option<(usize, T)> { values .iter() .enumerate() - .min_by(|&(_, l), &(_, r)| l.partial_cmp(r).unwrap_or(Ordering::Greater)) + .min_by(|&(_, l), &(_, r)| l.total_cmp(r)) .map(|(i, v)| (i, *v)) } @@ -22,11 +83,11 @@ pub fn arg_min(values: &[T]) -> Option<(usize, T)> { /// NAN values are ordered as smaller than all other values. /// /// This will return `None` if the given slice is empty. -pub fn arg_max(values: &[T]) -> Option<(usize, T)> { +pub fn arg_max(values: &[T]) -> Option<(usize, T)> { values .iter() .enumerate() - .max_by(|&(_, l), &(_, r)| l.partial_cmp(r).unwrap_or(Ordering::Less)) + .max_by(|&(_, l), &(_, r)| l.total_cmp(r)) .map(|(i, v)| (i, *v)) } @@ -85,29 +146,6 @@ pub(crate) fn normalize_1d(values: &[F], mean: F, sd: F) -> Vec { .collect() } -/// Compute the local fractal dimension of the given distances using the given radius. -/// -/// The local fractal dimension is computed as the log2 of the ratio of the number of -/// distances less than or equal to half the radius to the total number of distances. -/// -/// # Arguments -/// -/// * `radius` - The radius used to compute the distances. -/// * `distances` - The distances to compute the local fractal dimension of. -pub(crate) fn compute_lfd(radius: T, distances: &[T]) -> F { - if radius == T::ZERO { - F::ONE - } else { - let r_2 = F::from(radius) / F::from(2); - let half_count = distances.iter().filter(|&&d| F::from(d) <= r_2).count(); - if half_count > 0 { - (F::from(distances.len()) / F::from(half_count)).log2() - } else { - F::ONE - } - } -} - /// Compute the next exponential moving average of the given ratio and parent EMA. /// /// The EMA is computed as `alpha * ratio + (1 - alpha) * parent_ema`, where `alpha` @@ -209,25 +247,18 @@ pub fn calc_row_sds(values: &[Vec; 6]) -> [F; 6] { /// /// * `data` - The data to partition. 
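The switch from `partial_cmp` to `total_cmp` in `arg_min` and `arg_max` above removes the NaN special-casing: `total_cmp` is a total order that sorts NaN after positive infinity (and negative NaN before negative infinity). A quick `std`-only illustration:

```rust
fn main() {
    let mut v = [f64::NAN, 1.0, f64::INFINITY, -1.0];
    v.sort_by(f64::total_cmp);
    assert_eq!(&v[..3], &[-1.0, 1.0, f64::INFINITY]);
    assert!(v[3].is_nan()); // NaN sorts last
}
```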
fn partition(data: &[T]) -> Option<(Vec, T, Vec)> { - if data.is_empty() { - None - } else { - let (pivot_slice, tail) = data.split_at(1); - let pivot = pivot_slice[0]; - let (left, right) = tail.iter().fold((vec![], vec![]), |mut splits, next| { - { - let (ref mut left, ref mut right) = &mut splits; - if next < &pivot { - left.push(*next); - } else { - right.push(*next); - } + data.split_first().map(|(&pivot, tail)| { + let (left, right) = tail.iter().fold((vec![], vec![]), |(mut left, mut right), &next| { + if next < pivot { + left.push(next); + } else { + right.push(next); } - splits + (left, right) }); - Some((left, pivot, right)) - } + (left, pivot, right) + }) } /// A helper function for the median function below. @@ -335,9 +366,9 @@ pub fn un_flatten(data: Vec, sizes: &[usize]) -> Result>, Strin } /// Read a `Number` from a byte slice and increment the offset. -pub fn read_number(bytes: &[u8], offset: &mut usize) -> U { - let num_bytes = U::NUM_BYTES; - let value = U::from_le_bytes( +pub fn read_number(bytes: &[u8], offset: &mut usize) -> T { + let num_bytes = T::NUM_BYTES; + let value = T::from_le_bytes( bytes[*offset..*offset + num_bytes] .try_into() .unwrap_or_else(|e| unreachable!("{e}")), @@ -353,186 +384,3 @@ pub fn read_encoding(bytes: &[u8], offset: &mut usize) -> Box<[u8]> { *offset += len; encoding.into_boxed_slice() } - -#[cfg(test)] -mod tests { - use rand::prelude::*; - use symagen::random_data; - - use super::*; - - #[test] - fn test_transpose() { - // Input data: 3 rows x 6 columns - let data: Vec<[f64; 6]> = vec![ - [2.0, 3.0, 5.0, 7.0, 11.0, 13.0], - [4.0, 3.0, 5.0, 9.0, 10.0, 15.0], - [6.0, 2.0, 8.0, 11.0, 9.0, 11.0], - ]; - - // Expected transposed data: 6 rows x 3 columns - let expected_transposed: [Vec; 6] = [ - vec![2.0, 4.0, 6.0], - vec![3.0, 3.0, 2.0], - vec![5.0, 5.0, 8.0], - vec![7.0, 9.0, 11.0], - vec![11.0, 10.0, 9.0], - vec![13.0, 15.0, 11.0], - ]; - - let transposed_data = rows_to_cols(&data); - - // Check if the transposed data matches the expected result - for i in 0..6 { - assert_eq!(transposed_data[i], expected_transposed[i]); - } - } - - #[test] - fn test_means() { - let all_ratios: Vec<[f64; 6]> = vec![ - [2.0, 4.0, 5.0, 6.0, 9.0, 15.0], - [3.0, 3.0, 6.0, 4.0, 7.0, 10.0], - [5.0, 5.0, 8.0, 8.0, 8.0, 1.0], - ]; - - let transposed = rows_to_cols(&all_ratios); - let means = calc_row_means(&transposed); - - let expected_means: [f64; 6] = [ - 3.333_333_333_333_333_5, - 4.0, - 6.333_333_333_333_334, - 6.0, - 8.0, - 8.666_666_666_666_668, - ]; - - means - .iter() - .zip(expected_means.iter()) - .for_each(|(&a, &b)| assert!(float_cmp::approx_eq!(f64, a, b, ulps = 2), "{a}, {b} not equal")); - } - - #[test] - fn test_sds() { - let all_ratios: Vec<[f64; 6]> = vec![ - [2.0, 4.0, 5.0, 6.0, 9.0, 15.0], - [3.0, 3.0, 6.0, 4.0, 7.0, 10.0], - [5.0, 5.0, 8.0, 8.0, 8.0, 1.0], - ]; - - let expected_standard_deviations: [f64; 6] = [ - 1.247_219_128_924_6, - 0.816_496_580_927_73, - 1.247_219_128_924_6, - 1.632_993_161_855_5, - 0.816_496_580_927_73, - 5.792_715_732_327_6, - ]; - let sds = calc_row_sds(&rows_to_cols(&all_ratios)); - - sds.iter() - .zip(expected_standard_deviations.iter()) - .for_each(|(&a, &b)| { - assert!( - float_cmp::approx_eq!(f64, a, b, epsilon = 0.000_000_03), - "{a}, {b} not equal" - ); - }); - } - - #[test] - fn test_mean_variance() { - // Some synthetic cases to test edge results - let mut test_cases: Vec> = vec![ - vec![0.0], - vec![0.0, 0.0], - vec![1.0], - vec![1.0, 2.0], - vec![0.0, 0.25, 0.25, 1.25, 1.5, 1.75, 2.75, 3.25], - ]; - - // Use 
cardinalities of 1, 2, 1000, 100_000 and then 1_000_000 - 10_000_000 in steps of 1_000_000 - let cardinalities = vec![1, 2, 1_000, 100_000] - .into_iter() - .chain((1..=10).map(|i| i * 1_000_000)) - .collect::>(); - - // Ranges for the values generated by SyMaGen - let ranges = vec![ - (-100_000., 0.), - (-10_000., 0.), - (-1_000., 0.), - (0., 1_000.), - (0., 10_000.), - (0., 100_000.), - // These ranges cause the test to fail due to floating point accuracy issues when the sign switches - //(-1_000., 1_000.), - //(-10_000., 10_000.), - //(-100_000., 100_000.) - ]; - - let dimensionality = 1; - let seed = 42; - - // Generate random data for each cardinality and min/max value where max_val > min_val - for (cardinality, (min_val, max_val)) in cardinalities.into_iter().zip(ranges.into_iter()) { - let data = random_data::random_tabular( - dimensionality, - cardinality, - min_val, - max_val, - &mut rand::rngs::StdRng::seed_from_u64(seed), - ) - .into_iter() - .flatten() - .collect::>(); - test_cases.push(data); - } - - let (actual_means, actual_variances): (Vec, Vec) = test_cases - .iter() - .map(|values| mean_variance::(values)) - .unzip(); - - // Calculate expected_means and expected_variances using - // statistical::mean and statistical::population_variance - let expected_means: Vec = test_cases.iter().map(|values| statistical::mean(values)).collect(); - let expected_variances: Vec = test_cases - .iter() - .zip(expected_means.iter()) - .map(|(values, &mean)| statistical::population_variance(values, Some(mean))) - .collect(); - - actual_means.iter().zip(expected_means.iter()).for_each(|(&a, &b)| { - assert!( - float_cmp::approx_eq!(f64, a, b, ulps = 2), - "Means not equal. Actual: {}. Expected: {}. Difference: {}.", - a, - b, - a - b - ); - }); - - actual_variances - .iter() - .zip(expected_variances.iter()) - .for_each(|(&a, &b)| { - assert!( - float_cmp::approx_eq!(f64, a, b, epsilon = 3e-3), - "Variances not equal. Actual: {}. Expected: {}. Difference: {}.", - a, - b, - a - b - ); - }); - } - - #[test] - fn test_standard_deviation() { - let data = [2., 4., 4., 4., 5., 5., 7., 9.]; - let std = standard_deviation::(&data); - assert!((std - 2.).abs() < 1e-6); - } -} diff --git a/crates/abd-clam/tests/ball.rs b/crates/abd-clam/tests/ball.rs new file mode 100644 index 000000000..ade758bbc --- /dev/null +++ b/crates/abd-clam/tests/ball.rs @@ -0,0 +1,229 @@ +//! Tests for the `Ball` struct. 
+ +use abd_clam::{ + cakes::PermutedBall, + cluster::{ + adapter::{BallAdapter, ParBallAdapter}, + BalancedBall, ParPartition, Partition, + }, + metric::{AbsoluteDifference, Manhattan}, + Ball, Cluster, Dataset, FlatVec, Metric, +}; +use distances::{number::Multiplication, Number}; + +mod common; + +#[test] +fn new() { + let data = common::data_gen::gen_tiny_data(); + let metric = Manhattan; + + let indices = (0..data.cardinality()).collect::>(); + let seed = Some(42); + + let root = Ball::new(&data, &metric, &indices, 0, seed).unwrap(); + common::cluster::test_new(&root, &data); + + let root = Ball::par_new(&data, &metric, &indices, 0, seed).unwrap(); + common::cluster::test_new(&root, &data); + + let root = BalancedBall::new(&data, &metric, &indices, 0, seed) + .unwrap() + .into_ball(); + common::cluster::test_new(&root, &data); + + let root = BalancedBall::par_new(&data, &metric, &indices, 0, seed) + .unwrap() + .into_ball(); + common::cluster::test_new(&root, &data); +} + +#[test] +fn tree() { + let data = common::data_gen::gen_tiny_data(); + let metric = Manhattan; + + let seed = Some(42); + let criteria = |c: &Ball<_>| c.depth() < 1; + + let root = Ball::new_tree(&data, &metric, &criteria, seed); + assert_eq!(root.indices().len(), data.cardinality()); + assert!(common::cluster::check_partition(&root)); + + let root = Ball::par_new_tree(&data, &metric, &criteria, seed); + assert_eq!(root.indices().len(), data.cardinality()); + assert!(common::cluster::check_partition(&root)); + + let criteria = |c: &BalancedBall<_>| c.depth() < 1; + let root = BalancedBall::new_tree(&data, &metric, &criteria, seed).into_ball(); + assert_eq!(root.indices().len(), data.cardinality()); + assert!(common::cluster::check_partition(&root)); + + let root = BalancedBall::par_new_tree(&data, &metric, &criteria, seed).into_ball(); + assert_eq!(root.indices().len(), data.cardinality()); + assert!(common::cluster::check_partition(&root)); +} + +#[test] +fn partition_further() { + let data = common::data_gen::gen_tiny_data(); + let metric = Manhattan; + + let seed = Some(42); + let criteria_one = |c: &Ball<_>| c.depth() < 1; + let criteria_two = |c: &Ball<_>| c.depth() < 2; + + let mut root = Ball::new_tree(&data, &metric, &criteria_one, seed); + for leaf in root.leaves() { + assert_eq!(leaf.depth(), 1); + } + root.partition_further(&data, &metric, &criteria_two, seed); + for leaf in root.leaves() { + assert_eq!(leaf.depth(), 2); + } + + let mut root = Ball::par_new_tree(&data, &metric, &criteria_one, seed); + for leaf in root.leaves() { + assert_eq!(leaf.depth(), 1); + } + root.par_partition_further(&data, &metric, &criteria_two, seed); + for leaf in root.leaves() { + assert_eq!(leaf.depth(), 2); + } + + let criteria_one = |c: &BalancedBall<_>| c.depth() < 1; + let criteria_two = |c: &BalancedBall<_>| c.depth() < 2; + + let mut root = BalancedBall::new_tree(&data, &metric, &criteria_one, seed); + for leaf in root.leaves() { + assert_eq!(leaf.depth(), 1); + } + root.partition_further(&data, &metric, &criteria_two, seed); + for leaf in root.leaves() { + assert_eq!(leaf.depth(), 2); + } + + let mut root = BalancedBall::par_new_tree(&data, &metric, &criteria_one, seed); + for leaf in root.leaves() { + assert_eq!(leaf.depth(), 1); + } + root.par_partition_further(&data, &metric, &criteria_two, seed); + for leaf in root.leaves() { + assert_eq!(leaf.depth(), 2); + } +} + +#[test] +fn tree_iterative() { + let data = common::data_gen::gen_pathological_line(); + let metric = AbsoluteDifference; + + let seed = Some(42); + 
let criteria = |c: &Ball<_>| c.cardinality() > 1; + + let indices = (0..data.cardinality()).collect::>(); + let mut root = Ball::new(&data, &metric, &indices, 0, seed).unwrap(); + + let depth_delta = abd_clam::utils::max_recursion_depth(); + let mut intermediate_depth = depth_delta; + let intermediate_criteria = |c: &Ball<_>| c.depth() < intermediate_depth && criteria(c); + root.partition(&data, &metric, &intermediate_criteria, seed); + + while root.leaves().into_iter().any(|l| !l.is_singleton()) { + intermediate_depth += depth_delta; + let intermediate_criteria = |c: &Ball<_>| c.depth() < intermediate_depth && criteria(c); + root.partition_further(&data, &metric, &intermediate_criteria, seed); + } + + assert!(!root.is_leaf()); + + let criteria = |c: &BalancedBall<_>| c.cardinality() > 1; + let mut root = BalancedBall::new(&data, &metric, &indices, 0, seed).unwrap(); + intermediate_depth = depth_delta; + let intermediate_criteria = |c: &BalancedBall<_>| c.depth() < intermediate_depth && criteria(c); + root.partition(&data, &metric, &intermediate_criteria, seed); + + while root.leaves().into_iter().any(|l| !l.is_singleton()) { + intermediate_depth += depth_delta; + let intermediate_criteria = |c: &BalancedBall<_>| c.depth() < intermediate_depth && criteria(c); + root.partition_further(&data, &metric, &intermediate_criteria, seed); + } + + assert!(!root.is_leaf()); +} + +#[test] +fn trim_and_graft() -> Result<(), String> { + let line = (0..1024).collect(); + let metric = AbsoluteDifference; + let data = FlatVec::new(line)?; + + let seed = Some(42); + let criteria = |c: &Ball<_>| c.cardinality() > 1; + let root = Ball::new_tree(&data, &metric, &criteria, seed); + + let target_depth = 4; + let mut grafted_root = root.clone(); + let children = grafted_root.trim_at_depth(target_depth); + + let leaves = grafted_root.leaves(); + assert_eq!(leaves.len(), 2.powi(target_depth.as_i32())); + assert_eq!(leaves.len(), children.len()); + + grafted_root.graft_at_depth(target_depth, children); + assert_eq!(grafted_root, root); + for (l, c) in root.subtree().into_iter().zip(grafted_root.subtree()) { + assert_eq!(l, c); + } + + let criteria = |c: &BalancedBall<_>| c.cardinality() > 1; + let root = BalancedBall::new_tree(&data, &metric, &criteria, seed); + + let target_depth = 4; + let mut grafted_root = root.clone(); + let children = grafted_root.trim_at_depth(target_depth); + + let leaves = grafted_root.leaves(); + assert_eq!(leaves.len(), 2.powi(target_depth.as_i32())); + assert_eq!(leaves.len(), children.len()); + + grafted_root.graft_at_depth(target_depth, children); + assert_eq!(grafted_root, root); + for (l, c) in root.subtree().into_iter().zip(grafted_root.subtree()) { + assert_eq!(l, c); + } + + Ok(()) +} + +#[test] +fn permutation() { + let data = common::data_gen::gen_tiny_data(); + let metric = Manhattan; + + let seed = Some(42); + let criteria = |c: &Ball<_>| c.depth() < 1; + + let ball = Ball::new_tree(&data, &metric, &criteria, seed); + + let (root, perm_data) = PermutedBall::from_ball_tree(ball.clone(), data.clone(), &metric); + assert!(check_permutation(&root, &perm_data, &metric)); + + let (root, perm_data) = PermutedBall::par_from_ball_tree(ball, data, &metric); + assert!(check_permutation(&root, &perm_data, &metric)); +} + +fn check_permutation, i32>>( + root: &PermutedBall>, + data: &FlatVec, usize>, + metric: &M, +) -> bool { + assert!(!root.children().is_empty()); + + for cluster in root.subtree() { + let radius = data.one_to_one(cluster.arg_center(), cluster.arg_radial(), metric); + 
assert_eq!(cluster.radius(), radius); + } + + true +} diff --git a/crates/abd-clam/tests/cakes.rs b/crates/abd-clam/tests/cakes.rs new file mode 100644 index 000000000..453c0a1ba --- /dev/null +++ b/crates/abd-clam/tests/cakes.rs @@ -0,0 +1,261 @@ +//! Tests of the CAKES algorithms. + +use std::collections::HashMap; + +use distances::Number; +use test_case::test_case; + +use abd_clam::{ + cakes::{self, HintedDataset, ParHintedDataset, PermutedBall}, + cluster::{ + adapter::{BallAdapter, ParBallAdapter}, + ParCluster, ParPartition, Partition, + }, + dataset::{AssociatesMetadataMut, Permutable}, + metric::{AbsoluteDifference, Euclidean, Hypotenuse, Levenshtein, Manhattan, ParMetric}, + Ball, Cluster, Dataset, FlatVec, +}; + +mod common; + +#[test] +fn line() { + let data = common::data_gen::gen_line_data(10).transform_metadata(|&i| (i, HashMap::::new())); + let metric = AbsoluteDifference; + let query = &0; + let criteria = |c: &Ball<_>| c.cardinality() > 1; + let seed = Some(42); + + let ball = Ball::new_tree(&data, &metric, &criteria, seed); + let par_ball = Ball::par_new_tree(&data, &metric, &criteria, seed); + + let data = data.with_hints_from(&metric, &par_ball, 2, 4); + let (perm_ball, perm_data) = PermutedBall::from_ball_tree(ball.clone(), data.clone(), &metric); + + for radius in 0..=4 { + let alg = cakes::RnnClustered(radius); + common::search::check_rnn(&ball, &data, &metric, query, radius, &alg); + common::search::check_rnn(&par_ball, &data, &metric, query, radius, &alg); + } + + for k in [1, 4, 8] { + let alg = cakes::KnnDepthFirst(k); + common::search::check_knn(&ball, &data, &metric, query, k, &alg); + common::search::check_knn(&par_ball, &data, &metric, query, k, &alg); + common::search::check_knn(&perm_ball, &perm_data, &metric, query, k, &alg); + + let alg = cakes::KnnBreadthFirst(k); + common::search::check_knn(&ball, &data, &metric, query, k, &alg); + common::search::check_knn(&par_ball, &data, &metric, query, k, &alg); + common::search::check_knn(&perm_ball, &perm_data, &metric, query, k, &alg); + + let alg = cakes::KnnRepeatedRnn(k, 2); + common::search::check_knn(&ball, &data, &metric, query, k, &alg); + common::search::check_knn(&par_ball, &data, &metric, query, k, &alg); + common::search::check_knn(&perm_ball, &perm_data, &metric, query, k, &alg); + + let alg = cakes::KnnHinted(k); + common::search::check_knn(&ball, &data, &metric, query, k, &alg); + common::search::check_knn(&par_ball, &data, &metric, query, k, &alg); + common::search::check_knn(&perm_ball, &perm_data, &metric, query, k, &alg); + } +} + +#[test] +fn grid() { + let data = common::data_gen::gen_grid_data(10).transform_metadata(|&i| (i, HashMap::new())); + let metric = Hypotenuse; + let query = &(0.0, 0.0); + let criteria = |c: &Ball| c.cardinality() > 1; + let seed = Some(42); + + let ball = Ball::new_tree(&data, &metric, &criteria, seed); + let par_ball = Ball::par_new_tree(&data, &metric, &criteria, seed); + let data = data.with_hints_from(&metric, &par_ball, 2.0, 4); + + let (perm_ball, perm_data) = PermutedBall::from_ball_tree(ball.clone(), data.clone(), &metric); + + for radius in 0..=4 { + let radius = radius.as_f32(); + let alg = cakes::RnnClustered(radius); + common::search::check_rnn(&ball, &data, &metric, query, radius, &alg); + common::search::check_rnn(&par_ball, &data, &metric, query, radius, &alg); + } + + for k in [1, 10] { + let alg = cakes::KnnDepthFirst(k); + common::search::check_knn(&ball, &data, &metric, query, k, &alg); + common::search::check_knn(&par_ball, &data, &metric, query, 
k, &alg); + common::search::check_knn(&perm_ball, &perm_data, &metric, query, k, &alg); + + let alg = cakes::KnnBreadthFirst(k); + common::search::check_knn(&ball, &data, &metric, query, k, &alg); + common::search::check_knn(&par_ball, &data, &metric, query, k, &alg); + common::search::check_knn(&perm_ball, &perm_data, &metric, query, k, &alg); + + let alg = cakes::KnnRepeatedRnn(k, 2.0); + common::search::check_knn(&ball, &data, &metric, query, k, &alg); + common::search::check_knn(&par_ball, &data, &metric, query, k, &alg); + common::search::check_knn(&perm_ball, &perm_data, &metric, query, k, &alg); + + let alg = cakes::KnnHinted(k); + common::search::check_knn(&ball, &data, &metric, query, k, &alg); + common::search::check_knn(&par_ball, &data, &metric, query, k, &alg); + common::search::check_knn(&perm_ball, &perm_data, &metric, query, k, &alg); + } +} + +#[test_case(1_000, 10)] +#[test_case(10_000, 10)] +#[test_case(1_000, 100)] +#[test_case(10_000, 100)] +fn vectors(car: usize, dim: usize) { + let seed = 42; + let data = common::data_gen::gen_random_data(car, dim, 10.0, seed) + .with_name("random-vectors") + .transform_metadata(|&i| (i, HashMap::new())); + let seed = Some(seed); + let query = vec![0.0; dim]; + let criteria = |c: &Ball<_>| c.cardinality() > 1; + + let radii = [0.01, 0.1]; + let ks = [1, 10]; + + let metrics: Vec, f64>>> = vec![Box::new(Euclidean), Box::new(Manhattan)]; + + for metric in &metrics { + let data = data.clone(); + + let ball = Ball::new_tree(&data, metric, &criteria, seed); + let par_ball = Ball::par_new_tree(&data, metric, &criteria, seed); + let (perm_ball, perm_data) = PermutedBall::par_from_ball_tree(par_ball.clone(), data.clone(), metric); + + let radii = radii + .iter() + .map(|&f| f * ball.radius().as_f32()) + .map(|r| r.as_f64()) + .collect::>(); + + let data = data.with_hints_from_tree(&ball, metric); + + build_and_check_search( + (&ball, &data), + &par_ball, + (&perm_ball, &perm_data), + metric, + &query, + &radii, + &ks, + ); + } +} + +#[test_case(16, 16, 2)] +#[test_case(32, 16, 3)] +fn strings(num_clumps: usize, clump_size: usize, clump_radius: u16) -> Result<(), String> { + let seed_length = 30; + let alphabet = "ACTGN".chars().collect::>(); + let seed_string = symagen::random_edits::generate_random_string(seed_length, &alphabet); + let penalties = distances::strings::Penalties::default(); + let inter_clump_distance_range = (clump_radius * 5, clump_radius * 7); + let len_delta = seed_length / 10; + let (metadata, data) = symagen::random_edits::generate_clumped_data( + &seed_string, + penalties, + &alphabet, + num_clumps, + clump_size, + clump_radius, + inter_clump_distance_range, + len_delta, + ) + .into_iter() + .unzip::<_, _, Vec<_>, Vec<_>>(); + + let data = FlatVec::new(data)? + .with_metadata(&metadata)? 
+ .with_name("random-strings") + .transform_metadata(|s| (s.clone(), HashMap::new())); + let query = &seed_string; + let seed = Some(42); + + let radii = [0.01, 0.1]; + let ks = [1, 10]; + + let metric = Levenshtein; + let criteria = |c: &Ball| c.cardinality() > 1; + + let ball = Ball::new_tree(&data, &metric, &criteria, seed); + let par_ball = Ball::par_new_tree(&data, &metric, &criteria, seed); + let (perm_ball, perm_data) = PermutedBall::par_from_ball_tree(par_ball.clone(), data.clone(), &metric); + + let radii = radii + .iter() + .map(|&f| f * ball.radius().as_f32()) + .map(|r| r.as_u32()) + .collect::>(); + + let data = data.with_hints_from_tree(&ball, &metric); + + build_and_check_search( + (&ball, &data), + &par_ball, + (&perm_ball, &perm_data), + &metric, + &query, + &radii, + &ks, + ); + + Ok(()) +} + +/// Build trees and check the search results. +fn build_and_check_search( + ball_data: (&C, &D), + par_ball: &C, + perm_ball_data: (&PermutedBall, &Pd), + metric: &M, + query: &I, + radii: &[T], + ks: &[usize], +) where + I: core::fmt::Debug + Send + Sync + Clone, + T: Number, + C: ParCluster, + M: ParMetric, + D: ParHintedDataset + Permutable + Clone, + Pd: ParHintedDataset, M> + Permutable + Clone, +{ + let (ball, data) = ball_data; + let (perm_ball, perm_data) = perm_ball_data; + + for &radius in radii { + let alg = cakes::RnnClustered(radius); + common::search::check_rnn(ball, data, metric, query, radius, &alg); + common::search::check_rnn(par_ball, data, metric, query, radius, &alg); + common::search::check_rnn(perm_ball, perm_data, metric, query, radius, &alg); + } + + for &k in ks { + let alg = cakes::KnnRepeatedRnn(k, T::ONE.double()); + common::search::check_knn(ball, data, metric, query, k, &alg); + common::search::check_knn(par_ball, data, metric, query, k, &alg); + common::search::check_knn(perm_ball, perm_data, metric, query, k, &alg); + + let alg = cakes::KnnBreadthFirst(k); + common::search::check_knn(ball, data, metric, query, k, &alg); + common::search::check_knn(par_ball, data, metric, query, k, &alg); + common::search::check_knn(perm_ball, perm_data, metric, query, k, &alg); + + let alg = cakes::KnnDepthFirst(k); + common::search::check_knn(ball, data, metric, query, k, &alg); + common::search::check_knn(par_ball, data, metric, query, k, &alg); + common::search::check_knn(perm_ball, perm_data, metric, query, k, &alg); + + let alg = cakes::KnnHinted(k); + common::search::check_knn(ball, data, metric, query, k, &alg); + common::search::check_knn(par_ball, data, metric, query, k, &alg); + common::search::check_knn(perm_ball, perm_data, metric, query, k, &alg); + } +} diff --git a/crates/abd-clam/tests/chaoda.rs b/crates/abd-clam/tests/chaoda.rs new file mode 100644 index 000000000..66154e875 --- /dev/null +++ b/crates/abd-clam/tests/chaoda.rs @@ -0,0 +1,19 @@ +//! Tests for the `chaoda` module. 
+ +use abd_clam::chaoda::roc_auc_score; +use distances::Number; + +#[test] +fn test_roc_auc_score() -> Result<(), String> { + let y_score = (0..100).step_by(10).map(|s| s.as_f32() / 100.0).collect::>(); + + let y_true = y_score.iter().map(|&s| s > 0.5).collect::>(); + let auc = roc_auc_score(&y_true, &y_score)?; + float_cmp::approx_eq!(f32, auc, 1.0); + + let y_true = y_true.into_iter().map(|t| !t).collect::>(); + let auc = roc_auc_score(&y_true, &y_score)?; + float_cmp::approx_eq!(f32, auc, 0.0); + + Ok(()) +} diff --git a/crates/abd-clam/tests/common/cluster.rs b/crates/abd-clam/tests/common/cluster.rs new file mode 100644 index 000000000..338df3eb8 --- /dev/null +++ b/crates/abd-clam/tests/common/cluster.rs @@ -0,0 +1,47 @@ +//! Checking properties of clusters. + +use abd_clam::{Ball, Cluster, Dataset, FlatVec}; + +pub fn test_new(root: &Ball, data: &FlatVec, usize>) { + let arg_r = root.arg_radial(); + let indices = (0..data.cardinality()).collect::>(); + + assert_eq!(arg_r, data.cardinality() - 1); + assert_eq!(root.depth(), 0); + assert_eq!(root.cardinality(), 5); + assert_eq!(root.arg_center(), 2); + assert_eq!(root.radius(), 12); + assert_eq!(root.arg_radial(), arg_r); + assert!(root.children().is_empty()); + assert_eq!(root.indices(), indices); + assert_eq!(root.extents().len(), 1); +} + +pub fn check_partition(root: &Ball) -> bool { + let indices = root.indices(); + + assert!(!root.children().is_empty()); + assert_eq!(indices, &[0, 1, 2, 4, 3]); + assert_eq!(root.extents().len(), 2); + + let children = root.children(); + assert_eq!(children.len(), 2); + for &c in &children { + assert_eq!(c.depth(), 1); + assert!(c.children().is_empty()); + } + + let (left, right) = (children[0], children[1]); + + assert_eq!(left.cardinality(), 3); + assert_eq!(left.arg_center(), 1); + assert_eq!(left.radius(), 4); + assert!([0, 2].contains(&left.arg_radial())); + + assert_eq!(right.cardinality(), 2); + assert_eq!(right.radius(), 8); + assert!([3, 4].contains(&right.arg_center())); + assert!([3, 4].contains(&right.arg_radial())); + + true +} diff --git a/crates/abd-clam/tests/common/data_gen.rs b/crates/abd-clam/tests/common/data_gen.rs new file mode 100644 index 000000000..e544c6c47 --- /dev/null +++ b/crates/abd-clam/tests/common/data_gen.rs @@ -0,0 +1,43 @@ +//! Data generation utilities for testing. 
+ +use abd_clam::FlatVec; +use distances::{number::Float, Number}; +use rand::prelude::*; + +pub fn gen_tiny_data() -> FlatVec, usize> { + let items = vec![vec![1, 2], vec![3, 4], vec![5, 6], vec![7, 8], vec![11, 12]]; + FlatVec::new_array(items).unwrap_or_else(|e| unreachable!("{e}")) +} + +pub fn gen_pathological_line() -> FlatVec { + let min_delta = 1e-12; + let mut delta = min_delta; + let mut line = vec![0_f64]; + + while line.len() < 900 { + let last = *line.last().unwrap_or_else(|| unreachable!()); + line.push(last + delta); + delta *= 2.0; + delta += min_delta; + } + + FlatVec::new(line).unwrap_or_else(|e| unreachable!("{e}")) +} + +pub fn gen_line_data(max: i32) -> FlatVec { + let data = (-max..=max).collect::>(); + FlatVec::new(data).unwrap_or_else(|e| unreachable!("{e}")) +} + +pub fn gen_grid_data(max: i32) -> FlatVec<(f32, f32), usize> { + let data = (-max..=max) + .flat_map(|x| (-max..=max).map(move |y| (x.as_f32(), y.as_f32()))) + .collect::>(); + FlatVec::new(data).unwrap_or_else(|e| unreachable!("{e}")) +} + +pub fn gen_random_data(car: usize, dim: usize, max: T, seed: u64) -> FlatVec, usize> { + let mut rng = StdRng::seed_from_u64(seed); + let data = symagen::random_data::random_tabular(car, dim, -max, max, &mut rng); + FlatVec::new(data).unwrap_or_else(|e| unreachable!("{e}")) +} diff --git a/crates/abd-clam/tests/common/mod.rs b/crates/abd-clam/tests/common/mod.rs new file mode 100644 index 000000000..f76aa0ef7 --- /dev/null +++ b/crates/abd-clam/tests/common/mod.rs @@ -0,0 +1,7 @@ +//! Common helpers for several tests. +#![allow(dead_code)] + +pub mod cluster; +pub mod data_gen; +pub mod search; +pub mod sequence; diff --git a/crates/abd-clam/tests/common/search.rs b/crates/abd-clam/tests/common/search.rs new file mode 100644 index 000000000..19df01385 --- /dev/null +++ b/crates/abd-clam/tests/common/search.rs @@ -0,0 +1,104 @@ +//! Common functions for testing search algorithms. + +use abd_clam::{ + cakes::{self, ParSearchAlgorithm, ParSearchable, SearchAlgorithm}, + cluster::ParCluster, + metric::ParMetric, +}; +use distances::Number; + +pub fn check_search_by_index(mut true_hits: Vec<(usize, T)>, mut pred_hits: Vec<(usize, T)>, name: &str) { + true_hits.sort_by_key(|(i, _)| *i); + pred_hits.sort_by_key(|(i, _)| *i); + + let rest = format!("\n{true_hits:?}\nvs\n{pred_hits:?}"); + assert_eq!(true_hits.len(), pred_hits.len(), "{name}: {rest}"); + + for ((i, p), (j, q)) in true_hits.into_iter().zip(pred_hits) { + let msg = format!("Failed {name} i: {i}, j: {j}, p: {p}, q: {q}"); + assert_eq!(i, j, "{msg} {rest}"); + assert!(p.abs_diff(q) <= T::EPSILON, "{msg} in {rest}."); + } +} + +pub fn check_search_by_distance(mut true_hits: Vec<(usize, T)>, mut pred_hits: Vec<(usize, T)>, name: &str) { + true_hits.sort_by(|(_, p), (_, q)| p.total_cmp(q)); + pred_hits.sort_by(|(_, p), (_, q)| p.total_cmp(q)); + + assert_eq!( + true_hits.len(), + pred_hits.len(), + "{name}: {true_hits:?} vs {pred_hits:?}" + ); + + for (i, (&(_, p), &(_, q))) in true_hits.iter().zip(pred_hits.iter()).enumerate() { + assert!( + p.abs_diff(q) <= T::EPSILON, + "Failed {name} i-th: {i}, p: {p}, q: {q} in {true_hits:?} vs {pred_hits:?}." 
+ ); + } +} + +pub fn check_rnn(root: &C, data: &D, metric: &M, query: &I, radius: T, alg: &A) +where + I: core::fmt::Debug + Send + Sync, + T: Number, + C: ParCluster, + M: ParMetric, + D: ParSearchable, + A: ParSearchAlgorithm, +{ + let c_name = std::any::type_name::(); + + let true_hits = cakes::RnnLinear(radius).search(data, metric, root, query); + + let pred_hits = alg.search(data, metric, root, query); + assert_eq!( + pred_hits.len(), + true_hits.len(), + "{} search on {c_name} failed: {pred_hits:?}", + alg.name() + ); + check_search_by_index(true_hits.clone(), pred_hits, alg.name()); + + let pred_hits = alg.par_search(data, metric, root, query); + let par_name = format!("Par{}", alg.name()); + assert_eq!( + pred_hits.len(), + true_hits.len(), + "{par_name} search on {c_name} failed: {pred_hits:?}" + ); + check_search_by_index(true_hits, pred_hits, &par_name); +} + +/// Check a k-NN search algorithm. +pub fn check_knn(root: &C, data: &D, metric: &M, query: &I, k: usize, alg: &A) +where + I: core::fmt::Debug + Send + Sync, + T: Number, + C: ParCluster, + M: ParMetric, + D: ParSearchable, + A: ParSearchAlgorithm, +{ + let c_name = std::any::type_name::(); + let true_hits = cakes::KnnLinear(k).search(data, metric, root, query); + + let pred_hits = alg.search(data, metric, root, query); + assert_eq!( + pred_hits.len(), + true_hits.len(), + "{} search on {c_name} failed: pred {pred_hits:?} vs true {true_hits:?}", + alg.name() + ); + check_search_by_distance(true_hits.clone(), pred_hits, alg.name()); + + let pred_hits = alg.par_search(data, metric, root, query); + let par_name = format!("Par{}", alg.name()); + assert_eq!( + pred_hits.len(), + true_hits.len(), + "{par_name} search on {c_name} failed: pred {pred_hits:?} vs true {true_hits:?}" + ); + check_search_by_distance(true_hits, pred_hits, &par_name); +} diff --git a/crates/abd-clam/tests/common/sequence.rs b/crates/abd-clam/tests/common/sequence.rs new file mode 100644 index 000000000..0cb49c351 --- /dev/null +++ b/crates/abd-clam/tests/common/sequence.rs @@ -0,0 +1 @@ +//! Encodable and decodable sequences. diff --git a/crates/abd-clam/tests/flat_vec.rs b/crates/abd-clam/tests/flat_vec.rs new file mode 100644 index 000000000..dc1c2f983 --- /dev/null +++ b/crates/abd-clam/tests/flat_vec.rs @@ -0,0 +1,159 @@ +//! Tests for the `FlatVec` struct. 
+ +use abd_clam::{ + dataset::{AssociatesMetadata, Permutable}, + Dataset, FlatVec, +}; + +#[test] +fn creation() -> Result<(), String> { + let items = vec![vec![1, 2], vec![3, 4], vec![5, 6]]; + + let dataset = FlatVec::new(items.clone())?; + assert_eq!(dataset.cardinality(), 3); + assert_eq!(dataset.dimensionality_hint(), (0, None)); + + let dataset = FlatVec::new_array(items)?; + assert_eq!(dataset.cardinality(), 3); + assert_eq!(dataset.dimensionality_hint(), (2, Some(2))); + + Ok(()) +} + +#[test] +fn ser_de() -> Result<(), String> { + type Fv = FlatVec, usize>; + + let items = vec![vec![1, 2], vec![3, 4], vec![5, 6]]; + let dataset: Fv = FlatVec::new_array(items)?; + + let serialized: Vec = bitcode::encode(&dataset).map_err(|e| e.to_string())?; + let deserialized: Fv = bitcode::decode(&serialized).map_err(|e| e.to_string())?; + + assert_eq!(dataset.cardinality(), deserialized.cardinality()); + assert_eq!(dataset.dimensionality_hint(), deserialized.dimensionality_hint()); + assert_eq!(dataset.permutation(), deserialized.permutation()); + assert_eq!(dataset.metadata(), deserialized.metadata()); + for i in 0..dataset.cardinality() { + assert_eq!(dataset.get(i), deserialized.get(i)); + } + + Ok(()) +} + +#[test] +fn permutations() -> Result<(), String> { + struct SwapTracker { + data: FlatVec, usize>, + count: usize, + } + + impl Dataset> for SwapTracker { + fn name(&self) -> &str { + "SwapTracker" + } + + fn with_name(mut self, name: &str) -> Self { + self.data = self.data.with_name(name); + self + } + + fn cardinality(&self) -> usize { + self.data.cardinality() + } + + fn dimensionality_hint(&self) -> (usize, Option) { + self.data.dimensionality_hint() + } + + fn get(&self, index: usize) -> &Vec { + self.data.get(index) + } + } + + impl Permutable for SwapTracker { + fn permutation(&self) -> Vec { + self.data.permutation() + } + + fn set_permutation(&mut self, permutation: &[usize]) { + self.data.set_permutation(permutation); + } + + fn swap_two(&mut self, i: usize, j: usize) { + self.data.swap_two(i, j); + self.count += 1; + } + } + + let items = vec![ + vec![1, 2], + vec![3, 4], + vec![5, 6], + vec![7, 8], + vec![9, 10], + vec![11, 12], + ]; + let data = FlatVec::new_array(items.clone())?; + let mut swap_tracker = SwapTracker { data, count: 0 }; + + swap_tracker.swap_two(0, 2); + assert_eq!(swap_tracker.permutation(), &[2, 1, 0, 3, 4, 5]); + assert_eq!(swap_tracker.count, 1); + for (i, &j) in swap_tracker.permutation().iter().enumerate() { + assert_eq!(swap_tracker.get(i), &items[j]); + } + + swap_tracker.swap_two(0, 4); + assert_eq!(swap_tracker.permutation(), &[4, 1, 0, 3, 2, 5]); + assert_eq!(swap_tracker.count, 2); + for (i, &j) in swap_tracker.permutation().iter().enumerate() { + assert_eq!(swap_tracker.get(i), &items[j]); + } + + let data = FlatVec::new_array(items.clone())?; + let mut data = SwapTracker { data, count: 0 }; + let permutation = vec![2, 1, 0, 5, 4, 3]; + data.permute(&permutation); + assert_eq!(data.permutation(), permutation); + assert_eq!(data.count, 2); + for (i, &j) in data.permutation().iter().enumerate() { + assert_eq!(data.get(i), &items[j]); + } + + Ok(()) +} + +#[cfg(feature = "disk-io")] +#[test] +fn npy_io() -> Result<(), String> { + let items = vec![vec![1, 2], vec![3, 4], vec![5, 6]]; + let dataset = FlatVec::new_array(items)?; + + let tmp_dir = tempdir::TempDir::new("testing").map_err(|e| e.to_string())?; + let path = tmp_dir.path().join("test.npy"); + dataset.write_npy(&path)?; + + let new_dataset = FlatVec::, _>::read_npy(&path)?; + 
assert_eq!(new_dataset.cardinality(), 3); + assert_eq!(new_dataset.dimensionality_hint(), (2, Some(2))); + for i in 0..dataset.cardinality() { + assert_eq!(dataset.get(i), new_dataset.get(i)); + } + + let new_dataset = FlatVec::, _>::read_npy(&path)?; + assert_eq!(new_dataset.cardinality(), 3); + assert_eq!(new_dataset.dimensionality_hint(), (2, Some(2))); + for i in 0..dataset.cardinality() { + assert_eq!(dataset.get(i), new_dataset.get(i)); + } + + let new_dataset = FlatVec::, _>::read_npy(&path)?; + assert_eq!(new_dataset.cardinality(), 3); + assert_eq!(new_dataset.dimensionality_hint(), (2, Some(2))); + for i in 0..dataset.cardinality() { + assert_eq!(dataset.get(i), new_dataset.get(i)); + } + + Ok(()) +} diff --git a/crates/abd-clam/tests/needleman_wunsch.rs b/crates/abd-clam/tests/needleman_wunsch.rs new file mode 100644 index 000000000..5969d4d93 --- /dev/null +++ b/crates/abd-clam/tests/needleman_wunsch.rs @@ -0,0 +1,104 @@ +//! Tests for the Needleman-Wunsch aligner in `abd-clam::msa`. + +use abd_clam::msa::{ops::Direction, Aligner, CostMatrix}; + +#[test] +fn distance() { + let matrix = CostMatrix::default(); + let nw_aligner = Aligner::::new(&matrix, b'-'); + + let x = "NAJIBEATSPEPPERS"; + let y = "NAJIBPEPPERSEATS"; + assert_eq!(nw_aligner.distance(&nw_aligner.dp_table(&x, &y)), 8); + assert_eq!(nw_aligner.distance(&nw_aligner.dp_table(&y, &x)), 8); + + let x = "NOTGUILTY".to_string(); + let y = "NOTGUILTY".to_string(); + assert_eq!(nw_aligner.distance(&nw_aligner.dp_table(&x, &y)), 0); + assert_eq!(nw_aligner.distance(&nw_aligner.dp_table(&y, &x)), 0); +} + +#[test] +fn test_compute_table() { + let x = "NAJIBPEPPERSEATS"; + let y = "NAJIBEATSPEPPERS"; + let matrix = CostMatrix::default(); + let nw_aligner = Aligner::new(&matrix, b'-'); + let table = nw_aligner.dp_table(&x, &y); + + #[rustfmt::skip] + let true_table: [[(i16, Direction); 17]; 17] = [ + [( 0, Direction::Diagonal), ( 1, Direction::Left ), ( 2, Direction::Left ), ( 3, Direction::Left ), ( 4, Direction::Left ), ( 5, Direction::Left ), ( 6, Direction::Left ), (7, Direction::Left ), (8, Direction::Left ), (9, Direction::Left ), (10, Direction::Left ), (11, Direction::Left ), (12, Direction::Left ), (13, Direction::Left ), (14, Direction::Left ), (15, Direction::Left ), (16, Direction::Left )], + [( 1, Direction::Up ), ( 0, Direction::Diagonal), ( 1, Direction::Left ), ( 2, Direction::Left ), ( 3, Direction::Left ), ( 4, Direction::Left ), ( 5, Direction::Left ), (6, Direction::Left ), (7, Direction::Left ), (8, Direction::Left ), ( 9, Direction::Left ), (10, Direction::Left ), (11, Direction::Left ), (12, Direction::Left ), (13, Direction::Left ), (14, Direction::Left ), (15, Direction::Left )], + [( 2, Direction::Up ), ( 1, Direction::Up ), ( 0, Direction::Diagonal), ( 1, Direction::Left ), ( 2, Direction::Left ), ( 3, Direction::Left ), ( 4, Direction::Left ), (5, Direction::Left ), (6, Direction::Left ), (7, Direction::Left ), ( 8, Direction::Left ), ( 9, Direction::Left ), (10, Direction::Left ), (11, Direction::Left ), (12, Direction::Diagonal), (13, Direction::Left ), (14, Direction::Left )], + [( 3, Direction::Up ), ( 2, Direction::Up ), ( 1, Direction::Up ), ( 0, Direction::Diagonal), ( 1, Direction::Left ), ( 2, Direction::Left ), ( 3, Direction::Left ), (4, Direction::Left ), (5, Direction::Left ), (6, Direction::Left ), ( 7, Direction::Left ), ( 8, Direction::Left ), ( 9, Direction::Left ), (10, Direction::Left ), (11, Direction::Left ), (12, Direction::Left ), (13, Direction::Left )], + [( 4, 
Direction::Up ), ( 3, Direction::Up ), ( 2, Direction::Up ), ( 1, Direction::Up ), ( 0, Direction::Diagonal), ( 1, Direction::Left ), ( 2, Direction::Left ), (3, Direction::Left ), (4, Direction::Left ), (5, Direction::Left ), ( 6, Direction::Left ), ( 7, Direction::Left ), ( 8, Direction::Left ), ( 9, Direction::Left ), (10, Direction::Left ), (11, Direction::Left ), (12, Direction::Left )], + [( 5, Direction::Up ), ( 4, Direction::Up ), ( 3, Direction::Up ), ( 2, Direction::Up ), ( 1, Direction::Up ), ( 0, Direction::Diagonal), ( 1, Direction::Left ), (2, Direction::Left ), (3, Direction::Left ), (4, Direction::Left ), ( 5, Direction::Left ), ( 6, Direction::Left ), ( 7, Direction::Left ), ( 8, Direction::Left ), ( 9, Direction::Left ), (10, Direction::Left ), (11, Direction::Left )], + [( 6, Direction::Up ), ( 5, Direction::Up ), ( 4, Direction::Up ), ( 3, Direction::Up ), ( 2, Direction::Up ), ( 1, Direction::Up ), ( 1, Direction::Diagonal), (1, Direction::Diagonal), (2, Direction::Left ), (3, Direction::Left ), ( 4, Direction::Diagonal), ( 5, Direction::Left ), ( 6, Direction::Left ), ( 7, Direction::Diagonal), ( 8, Direction::Left ), ( 9, Direction::Left ), (10, Direction::Left )], + [( 7, Direction::Up ), ( 6, Direction::Up ), ( 5, Direction::Diagonal), ( 4, Direction::Up ), ( 3, Direction::Up ), ( 2, Direction::Up ), ( 2, Direction::Diagonal), (2, Direction::Diagonal), (2, Direction::Diagonal), (3, Direction::Diagonal), ( 4, Direction::Diagonal), ( 5, Direction::Diagonal), ( 6, Direction::Diagonal), ( 7, Direction::Diagonal), ( 7, Direction::Diagonal), ( 8, Direction::Left ), ( 9, Direction::Left )], + [( 8, Direction::Up ), ( 7, Direction::Up ), ( 6, Direction::Up ), ( 5, Direction::Up ), ( 4, Direction::Up ), ( 3, Direction::Up ), ( 3, Direction::Diagonal), (3, Direction::Diagonal), (3, Direction::Diagonal), (3, Direction::Diagonal), ( 4, Direction::Diagonal), ( 5, Direction::Diagonal), ( 6, Direction::Diagonal), ( 7, Direction::Diagonal), ( 8, Direction::Diagonal), ( 7, Direction::Diagonal), ( 8, Direction::Left )], + [( 9, Direction::Up ), ( 8, Direction::Up ), ( 7, Direction::Up ), ( 6, Direction::Up ), ( 5, Direction::Up ), ( 4, Direction::Up ), ( 4, Direction::Diagonal), (4, Direction::Diagonal), (4, Direction::Diagonal), (4, Direction::Diagonal), ( 4, Direction::Diagonal), ( 5, Direction::Diagonal), ( 5, Direction::Diagonal), ( 6, Direction::Left ), ( 7, Direction::Left ), ( 8, Direction::Up ), ( 7, Direction::Diagonal)], + [(10, Direction::Up ), ( 9, Direction::Up ), ( 8, Direction::Up ), ( 7, Direction::Up ), ( 6, Direction::Up ), ( 5, Direction::Up ), ( 4, Direction::Diagonal), (5, Direction::Diagonal), (4, Direction::Diagonal), (4, Direction::Diagonal), ( 5, Direction::Diagonal), ( 5, Direction::Diagonal), ( 6, Direction::Diagonal), ( 6, Direction::Diagonal), ( 7, Direction::Diagonal), ( 8, Direction::Diagonal), ( 8, Direction::Up )], + [(11, Direction::Up ), (10, Direction::Up ), ( 9, Direction::Up ), ( 8, Direction::Up ), ( 7, Direction::Up ), ( 6, Direction::Up ), ( 5, Direction::Up ), (4, Direction::Diagonal), (5, Direction::Up ), (5, Direction::Diagonal), ( 4, Direction::Diagonal), ( 5, Direction::Left ), ( 6, Direction::Diagonal), ( 6, Direction::Diagonal), ( 7, Direction::Diagonal), ( 8, Direction::Diagonal), ( 9, Direction::Diagonal)], + [(12, Direction::Up ), (11, Direction::Up ), (10, Direction::Up ), ( 9, Direction::Up ), ( 8, Direction::Up ), ( 7, Direction::Up ), ( 6, Direction::Diagonal), (5, Direction::Up ), (4, Direction::Diagonal), (5, 
Direction::Diagonal), ( 5, Direction::Up ), ( 5, Direction::Diagonal), ( 6, Direction::Diagonal), ( 7, Direction::Diagonal), ( 7, Direction::Diagonal), ( 8, Direction::Diagonal), ( 9, Direction::Diagonal)], + [(13, Direction::Up ), (12, Direction::Up ), (11, Direction::Up ), (10, Direction::Up ), ( 9, Direction::Up ), ( 8, Direction::Up ), ( 7, Direction::Diagonal), (6, Direction::Up ), (5, Direction::Diagonal), (4, Direction::Diagonal), ( 5, Direction::Left ), ( 6, Direction::Diagonal), ( 6, Direction::Diagonal), ( 7, Direction::Diagonal), ( 8, Direction::Diagonal), ( 8, Direction::Diagonal), ( 9, Direction::Diagonal)], + [(14, Direction::Up ), (13, Direction::Up ), (12, Direction::Up ), (11, Direction::Up ), (10, Direction::Up ), ( 9, Direction::Up ), ( 8, Direction::Up ), (7, Direction::Diagonal), (6, Direction::Up ), (5, Direction::Up ), ( 4, Direction::Diagonal), ( 5, Direction::Left ), ( 6, Direction::Left ), ( 6, Direction::Diagonal), ( 7, Direction::Left ), ( 8, Direction::Left ), ( 9, Direction::Diagonal)], + [(15, Direction::Up ), (14, Direction::Up ), (13, Direction::Up ), (12, Direction::Up ), (11, Direction::Up ), (10, Direction::Up ), ( 9, Direction::Up ), (8, Direction::Up ), (7, Direction::Up ), (6, Direction::Up ), ( 5, Direction::Up ), ( 4, Direction::Diagonal), ( 5, Direction::Left ), ( 6, Direction::Left ), ( 7, Direction::Diagonal), ( 8, Direction::Diagonal), ( 9, Direction::Diagonal)], + [(16, Direction::Up ), (15, Direction::Up ), (14, Direction::Up ), (13, Direction::Up ), (12, Direction::Up ), (11, Direction::Up ), (10, Direction::Up ), (9, Direction::Up ), (8, Direction::Up ), (7, Direction::Up ), ( 6, Direction::Up ), ( 5, Direction::Up ), ( 4, Direction::Diagonal), ( 5, Direction::Left ), ( 6, Direction::Left ), ( 7, Direction::Left ), ( 8, Direction::Diagonal)] + ]; + + assert_eq!(table, true_table); +} + +#[test] +fn test_trace_back() { + let matrix = CostMatrix::default(); + let nw_aligner = Aligner::::new(&matrix, b'-'); + + let peppers_x = "NAJIBPEPPERSEATS"; + let peppers_y = "NAJIBEATSPEPPERS"; + let peppers_table = nw_aligner.dp_table(&peppers_x, &peppers_y); + + let d = nw_aligner.distance(&peppers_table); + assert_eq!(d, 8); + + let [aligned_x, aligned_y] = nw_aligner.align_str(&peppers_x, &peppers_y, &peppers_table); + assert_eq!(aligned_x, "NAJIB-PEPPERSEATS"); + assert_eq!(aligned_y, "NAJIBEATSPEPPE-RS"); + + let guilty_x = "NOTGUILTY"; + let guilty_y = "NOTGUILTY"; + let guilty_table = nw_aligner.dp_table(&guilty_x, &guilty_y); + + let d = nw_aligner.distance(&guilty_table); + assert_eq!(d, 0); + + let [aligned_x, aligned_y] = nw_aligner.align_str(&guilty_x, &guilty_y, &guilty_table); + assert_eq!(aligned_x, "NOTGUILTY"); + assert_eq!(aligned_y, "NOTGUILTY"); +} + +#[test] +fn test_alignment_gaps() { + let matrix = CostMatrix::default(); + let nw_aligner = Aligner::::new(&matrix, b'-'); + + let x = "MDIAIHHPWIRRP---"; + let y = "MDIAIHHPWIRRPF"; + let table = nw_aligner.dp_table(&x, &y); + + let d = nw_aligner.distance(&table); + assert_eq!(d, 3); + + let [x_gaps, y_gaps] = nw_aligner.alignment_gaps(&x, &y, &table); + assert!(x_gaps.is_empty()); + assert_eq!(y_gaps, vec![13, 13]); + + let table = nw_aligner.dp_table(&y, &x); + let d = nw_aligner.distance(&table); + assert_eq!(d, 3); + + let [x_gaps, y_gaps] = nw_aligner.alignment_gaps(&y, &x, &table); + assert_eq!(x_gaps, vec![13, 13]); + assert!(y_gaps.is_empty()); +} diff --git a/crates/abd-clam/tests/pancakes.rs b/crates/abd-clam/tests/pancakes.rs new file mode 100644 index 000000000..63ac9620a 
--- /dev/null
+++ b/crates/abd-clam/tests/pancakes.rs
@@ -0,0 +1,142 @@
+//! Tests for the `pancakes` module.
+
+use distances::Number;
+use test_case::test_case;
+
+use abd_clam::{
+    cakes::{KnnBreadthFirst, KnnDepthFirst, KnnRepeatedRnn, PermutedBall, RnnClustered},
+    cluster::{
+        adapter::{Adapter, BallAdapter, ParAdapter},
+        ParPartition, Partition,
+    },
+    dataset::{AssociatesMetadata, AssociatesMetadataMut, Permutable},
+    metric::{AbsoluteDifference, Levenshtein, ParMetric},
+    msa::{Aligner, CostMatrix, Sequence},
+    pancakes::{CodecData, Decodable, Encodable, ParCompressible, SquishyBall},
+    Ball, Cluster, FlatVec,
+};
+
+mod common;
+
+#[test_case(16, 16, 2)]
+#[test_case(32, 16, 3)]
+fn strings(num_clumps: usize, clump_size: usize, clump_radius: u16) -> Result<(), String> {
+    let matrix = CostMatrix::<u16>::default_affine(Some(10));
+    let aligner = Aligner::new(&matrix, b'-');
+
+    let seed_length = 30;
+    let alphabet = "ACTGN".chars().collect::<Vec<_>>();
+    let seed_string = symagen::random_edits::generate_random_string(seed_length, &alphabet);
+    let penalties = distances::strings::Penalties::default();
+    let inter_clump_distance_range = (clump_radius * 5, clump_radius * 7);
+    let len_delta = seed_length / 10;
+    let (metadata, data) = symagen::random_edits::generate_clumped_data(
+        &seed_string,
+        penalties,
+        &alphabet,
+        num_clumps,
+        clump_size,
+        clump_radius,
+        inter_clump_distance_range,
+        len_delta,
+    )
+    .into_iter()
+    .map(|(m, seq)| (m, Sequence::new(seq, Some(&aligner))))
+    .unzip::<_, _, Vec<_>, Vec<_>>();
+
+    let data = FlatVec::new(data)?.with_metadata(&metadata)?;
+    let query = Sequence::new(seed_string.clone(), Some(&aligner));
+    let seed = Some(42);
+
+    let radii = [1, 4, 8];
+    let ks = [1, 10, 20];
+
+    build_and_check_search(&data, &Levenshtein, &query, seed, &radii, &ks);
+
+    Ok(())
+}
+
+#[cfg(feature = "disk-io")]
+#[test]
+fn ser_de() -> Result<(), String> {
+    use abd_clam::Dataset;
+
+    // The items.
+    type I = i32;
+    // The distance values.
+    type U = i32;
+    // The compressible dataset
+    type Co = FlatVec<I, usize>;
+    // The ball for the compressible dataset.
+    type B = Ball<U>;
+    // The decompressible dataset
+    type Dec = CodecData<I, usize>;
+    // The squishy ball
+    type Sb = SquishyBall<U, B>;
+
+    let data: Co = common::data_gen::gen_line_data(100);
+    let metric = AbsoluteDifference;
+    let metadata = data.metadata().to_vec();
+
+    let criteria = |c: &B| c.cardinality() > 1;
+    let ball = B::new_tree(&data, &metric, &criteria, Some(42));
+    let (_, co_data) = Sb::from_ball_tree(ball, data, &metric);
+    let co_data = co_data.with_metadata(&metadata)?;
+
+    let serialized = bitcode::encode(&co_data).map_err(|e| e.to_string())?;
+    let deserialized: Dec = bitcode::decode(&serialized).map_err(|e| e.to_string())?;
+
+    assert_eq!(co_data.cardinality(), deserialized.cardinality());
+    assert_eq!(co_data.dimensionality_hint(), deserialized.dimensionality_hint());
+    assert_eq!(co_data.metadata(), deserialized.metadata());
+    assert_eq!(co_data.permutation(), deserialized.permutation());
+    assert_eq!(co_data.center_map(), deserialized.center_map());
+    assert_eq!(co_data.leaf_bytes(), deserialized.leaf_bytes());
+
+    Ok(())
+}
+
+/// Build trees and check the search results.
+fn build_and_check_search(data: &D, metric: &M, query: &I, seed: Option, radii: &[T], ks: &[usize]) +where + I: core::fmt::Debug + Send + Sync + Encodable + Decodable + Clone + Eq, + T: Number, + D: ParCompressible + Permutable + Clone, + M: ParMetric, +{ + let criterion = |c: &Ball| c.cardinality() > 1; + + let (ball, data) = { + let ball = Ball::par_new_tree(data, metric, &criterion, seed); + + let mut perm_data = data.clone(); + let mut perm_ball = PermutedBall::par_adapt_tree_iterative(ball, None, &perm_data, metric); + + let permutation = as Adapter>::source(&perm_ball).indices(); + perm_data.permute(&permutation); + perm_ball.clear_source_indices(); + + let mut squishy_ball = + as ParAdapter<_, _, _, CodecData, _, _>>::par_adapt_tree_iterative( + perm_ball, None, &perm_data, metric, + ); + + squishy_ball.par_set_costs(&perm_data, metric); + squishy_ball.trim(4); + + let co_data = CodecData::par_from_compressible(&perm_data, &squishy_ball); + + (squishy_ball, co_data) + }; + + for &radius in radii { + let alg = RnnClustered(radius); + common::search::check_rnn(&ball, &data, metric, query, radius, &alg); + } + + for &k in ks { + common::search::check_knn(&ball, &data, metric, query, k, &KnnRepeatedRnn(k, T::ONE.double())); + common::search::check_knn(&ball, &data, metric, query, k, &KnnBreadthFirst(k)); + common::search::check_knn(&ball, &data, metric, query, k, &KnnDepthFirst(k)); + } +} diff --git a/crates/abd-clam/tests/tree.rs b/crates/abd-clam/tests/tree.rs new file mode 100644 index 000000000..b0b797d2d --- /dev/null +++ b/crates/abd-clam/tests/tree.rs @@ -0,0 +1,118 @@ +//! Tests for the `Tree` struct. + +use abd_clam::{cluster::Partition, metric::Euclidean, Ball, Cluster, Tree}; +use test_case::test_case; + +mod common; + +#[test_case(20, 2)] +#[test_case(1_000, 10)] +#[test_case(10_000, 10)] +fn from_root(car: usize, dim: usize) { + let max = 1.0; + let seed = 42; + + let data = common::data_gen::gen_random_data(car, dim, max, seed); + let metric = Euclidean; + let criteria = |c: &Ball<_>| c.cardinality() > 1; + + let root = Ball::new_tree(&data, &metric, &criteria, Some(seed)); + let true_subtree = { + let mut subtree = root.subtree(); + subtree.sort(); + subtree + }; + + // let tree = Tree::new(&data, &metric, &criteria, Some(seed)); + let tree = Tree::from(root.clone()); + let tree_subtree = { + let mut subtree = tree.bft().collect::>(); + subtree.sort(); + subtree + }; + + assert_eq!(true_subtree.len(), tree_subtree.len()); + for (&a, &b) in true_subtree.iter().zip(tree_subtree.iter()) { + check_ball_eq(a, b); + } + + for a in true_subtree { + let res = tree.find(a); + assert!(res.is_some()); + let (depth, index, _, _) = res.unwrap(); + + let b = tree.get(depth, index); + assert!(b.is_some()); + + let b = b.unwrap(); + check_ball_eq(a, b); + + let children = tree.children_of(depth, index); + assert_eq!(a.children().len(), children.len()); + + for (&c, &(d, _, _)) in a.children().iter().zip(children.iter()) { + check_ball_eq(c, d); + } + } +} + +#[test_case(20, 2)] +#[test_case(1_000, 10)] +#[test_case(10_000, 10)] +fn new(car: usize, dim: usize) { + let max = 1.0; + let seed = 42; + + let data = common::data_gen::gen_random_data(car, dim, max, seed); + let metric = Euclidean; + let criteria = |c: &Ball<_>| c.cardinality() > 1; + + let tree = Tree::new(&data, &metric, &criteria, Some(seed)).unwrap(); + check_tree(&tree); + + let tree = Tree::par_new(&data, &metric, &criteria, Some(seed)).unwrap(); + check_tree(&tree); +} + +fn check_ball_eq(a: &Ball, b: &Ball) { + 
assert_eq!(a.depth(), b.depth()); + assert_eq!(a.cardinality(), b.cardinality()); + assert_eq!(a.radius(), b.radius()); + assert_eq!(a.lfd(), b.lfd()); + assert_eq!(a.arg_center(), b.arg_center()); + assert_eq!(a.arg_radial(), b.arg_radial()); + assert_eq!(a.indices(), b.indices()); +} + +fn check_tree(tree: &Tree>) { + let subtree = tree.bft().collect::>(); + let data_car = tree.root().0.cardinality(); + assert_eq!(data_car + data_car - 1, subtree.len()); + + for c in subtree { + let res = tree.find(c); + assert!(res.is_some()); + + let (depth, index, a, b) = res.unwrap(); + let c_found = tree.get(depth, index); + assert!(c_found.is_some()); + + let c_found = c_found.unwrap(); + check_ball_eq(c, c_found); + + let children = tree.children_of(depth, index); + if a == b { + assert!(children.is_empty()); + } else { + assert_eq!(b - a, children.len()); + + let car_sum = children.iter().map(|(c, _, _)| c.cardinality()).sum(); + assert_eq!(c.cardinality(), car_sum); + + for (child, _, _) in children { + assert!(!child.indices().is_empty()); + assert!(child.is_descendant_of(c)); + } + } + } +} diff --git a/crates/abd-clam/tests/utils.rs b/crates/abd-clam/tests/utils.rs new file mode 100644 index 000000000..66abf0b07 --- /dev/null +++ b/crates/abd-clam/tests/utils.rs @@ -0,0 +1,180 @@ +//! Tests for the utility functions in `abd-clam`. + +use abd_clam::utils; +use rand::prelude::*; +use symagen::random_data; + +#[test] +fn test_transpose() { + // Input data: 3 rows x 6 columns + let data: Vec<[f64; 6]> = vec![ + [2.0, 3.0, 5.0, 7.0, 11.0, 13.0], + [4.0, 3.0, 5.0, 9.0, 10.0, 15.0], + [6.0, 2.0, 8.0, 11.0, 9.0, 11.0], + ]; + + // Expected transposed data: 6 rows x 3 columns + let expected_transposed: [Vec; 6] = [ + vec![2.0, 4.0, 6.0], + vec![3.0, 3.0, 2.0], + vec![5.0, 5.0, 8.0], + vec![7.0, 9.0, 11.0], + vec![11.0, 10.0, 9.0], + vec![13.0, 15.0, 11.0], + ]; + + let transposed_data = utils::rows_to_cols(&data); + + // Check if the transposed data matches the expected result + for i in 0..6 { + assert_eq!(transposed_data[i], expected_transposed[i]); + } +} + +#[test] +fn test_means() { + let all_ratios: Vec<[f64; 6]> = vec![ + [2.0, 4.0, 5.0, 6.0, 9.0, 15.0], + [3.0, 3.0, 6.0, 4.0, 7.0, 10.0], + [5.0, 5.0, 8.0, 8.0, 8.0, 1.0], + ]; + + let transposed = utils::rows_to_cols(&all_ratios); + let means = utils::calc_row_means(&transposed); + + let expected_means: [f64; 6] = [ + 3.333_333_333_333_333_5, + 4.0, + 6.333_333_333_333_334, + 6.0, + 8.0, + 8.666_666_666_666_668, + ]; + + means + .iter() + .zip(expected_means.iter()) + .for_each(|(&a, &b)| assert!(float_cmp::approx_eq!(f64, a, b, ulps = 2), "{a}, {b} not equal")); +} + +#[test] +fn test_sds() { + let all_ratios: Vec<[f64; 6]> = vec![ + [2.0, 4.0, 5.0, 6.0, 9.0, 15.0], + [3.0, 3.0, 6.0, 4.0, 7.0, 10.0], + [5.0, 5.0, 8.0, 8.0, 8.0, 1.0], + ]; + + let expected_standard_deviations: [f64; 6] = [ + 1.247_219_128_924_6, + 0.816_496_580_927_73, + 1.247_219_128_924_6, + 1.632_993_161_855_5, + 0.816_496_580_927_73, + 5.792_715_732_327_6, + ]; + let sds = utils::calc_row_sds(&utils::rows_to_cols(&all_ratios)); + + sds.iter() + .zip(expected_standard_deviations.iter()) + .for_each(|(&a, &b)| { + assert!( + float_cmp::approx_eq!(f64, a, b, epsilon = 0.000_000_03), + "{a}, {b} not equal" + ); + }); +} + +#[test] +fn test_mean_variance() { + // Some synthetic cases to test edge results + let mut test_cases: Vec> = vec![ + vec![0.0], + vec![0.0, 0.0], + vec![1.0], + vec![1.0, 2.0], + vec![0.0, 0.25, 0.25, 1.25, 1.5, 1.75, 2.75, 3.25], + ]; + + // Use 
cardinalities of 1, 2, 1000, 100_000 and then 1_000_000 - 10_000_000 in steps of 1_000_000.
+    let cardinalities = vec![1, 2, 1_000, 100_000]
+        .into_iter()
+        .chain((1..=10).map(|i| i * 1_000_000))
+        .collect::<Vec<_>>();
+
+    // Ranges for the values generated by SyMaGen
+    let ranges = vec![
+        (-100_000., 0.),
+        (-10_000., 0.),
+        (-1_000., 0.),
+        (0., 1_000.),
+        (0., 10_000.),
+        (0., 100_000.),
+        // These ranges cause the test to fail due to floating point accuracy issues when the sign switches
+        //(-1_000., 1_000.),
+        //(-10_000., 10_000.),
+        //(-100_000., 100_000.)
+    ];
+
+    let dimensionality = 1;
+    let seed = 42;
+
+    // Generate random data for each cardinality and min/max value where max_val > min_val
+    for (cardinality, (min_val, max_val)) in cardinalities.into_iter().zip(ranges.into_iter()) {
+        let data = random_data::random_tabular(
+            dimensionality,
+            cardinality,
+            min_val,
+            max_val,
+            &mut rand::rngs::StdRng::seed_from_u64(seed),
+        )
+        .into_iter()
+        .flatten()
+        .collect::<Vec<_>>();
+        test_cases.push(data);
+    }
+
+    let (actual_means, actual_variances): (Vec<f64>, Vec<f64>) = test_cases
+        .iter()
+        .map(|values| utils::mean_variance::<f64, f64>(values))
+        .unzip();
+
+    // Calculate expected_means and expected_variances using
+    // statistical::mean and statistical::population_variance
+    let expected_means: Vec<f64> = test_cases.iter().map(|values| statistical::mean(values)).collect();
+    let expected_variances: Vec<f64> = test_cases
+        .iter()
+        .zip(expected_means.iter())
+        .map(|(values, &mean)| statistical::population_variance(values, Some(mean)))
+        .collect();
+
+    actual_means.iter().zip(expected_means.iter()).for_each(|(&a, &b)| {
+        assert!(
+            float_cmp::approx_eq!(f64, a, b, ulps = 2),
+            "Means not equal. Actual: {}. Expected: {}. Difference: {}.",
+            a,
+            b,
+            a - b
+        );
+    });
+
+    actual_variances
+        .iter()
+        .zip(expected_variances.iter())
+        .for_each(|(&a, &b)| {
+            assert!(
+                float_cmp::approx_eq!(f64, a, b, epsilon = 3e-3),
+                "Variances not equal. Actual: {}. Expected: {}. Difference: {}.",
+                a,
+                b,
+                a - b
+            );
+        });
+}
+
+#[test]
+fn test_standard_deviation() {
+    let data = [2., 4., 4., 4., 5., 5., 7., 9.];
+    let std = utils::standard_deviation::<f64, f64>(&data);
+    assert!((std - 2.).abs() < 1e-6);
+}
diff --git a/crates/distances/Cargo.toml b/crates/distances/Cargo.toml
index 211a4825e..947f5ecac 100644
--- a/crates/distances/Cargo.toml
+++ b/crates/distances/Cargo.toml
@@ -20,6 +20,7 @@ publish = true
 rand = { workspace = true }
 serde = { workspace = true }
 libm = { workspace = true }
+num-integer = { workspace = true }
 
 [dev-dependencies]
 symagen = { workspace = true }
diff --git a/crates/distances/src/number/_bool.rs b/crates/distances/src/number/_bool.rs
index 95a45568d..aba752b89 100644
--- a/crates/distances/src/number/_bool.rs
+++ b/crates/distances/src/number/_bool.rs
@@ -198,4 +198,8 @@ impl Number for Bool {
     fn next_random<R: rand::Rng>(rng: &mut R) -> Self {
         Self(rng.gen())
     }
+
+    fn total_cmp(&self, other: &Self) -> core::cmp::Ordering {
+        self.0.total_cmp(&other.0)
+    }
 }
diff --git a/crates/distances/src/number/_number.rs b/crates/distances/src/number/_number.rs
index 8577f6258..f4bbe3bd8 100644
--- a/crates/distances/src/number/_number.rs
+++ b/crates/distances/src/number/_number.rs
@@ -70,6 +70,18 @@ pub trait Number:
     /// Returns the number as a `f64`. This may be a lossy conversion.
     fn as_f64(self) -> f64;
 
+    /// Returns the number as a `usize`. This may be a lossy conversion.
+    #[allow(clippy::cast_possible_truncation)]
+    fn as_usize(self) -> usize {
+        self.as_u64() as usize
+    }
+
+    /// Returns the number as an `isize`. This may be a lossy conversion.
+    #[allow(clippy::cast_possible_truncation)]
+    fn as_isize(self) -> isize {
+        self.as_i64() as isize
+    }
+
     /// Returns the number as a `u64`. This may be a lossy conversion.
     fn as_u64(self) -> u64;
 
@@ -109,6 +121,29 @@ pub trait Number:
 
     /// Returns a random `Number`.
     fn next_random<R: rand::Rng>(rng: &mut R) -> Self;
+
+    /// Returns a total ordering of the number.
+    fn total_cmp(&self, other: &Self) -> core::cmp::Ordering;
+
+    /// Returns the smaller of two numbers.
+    #[must_use]
+    fn min(self, other: Self) -> Self {
+        if self < other {
+            self
+        } else {
+            other
+        }
+    }
+
+    /// Returns the larger of two numbers.
+    #[must_use]
+    fn max(self, other: Self) -> Self {
+        if self > other {
+            self
+        } else {
+            other
+        }
+    }
 }
 
 impl Number for f32 {
@@ -173,6 +208,10 @@ impl Number for f32 {
     fn next_random<R: rand::Rng>(rng: &mut R) -> Self {
         rng.gen()
     }
+
+    fn total_cmp(&self, other: &Self) -> core::cmp::Ordering {
+        self.total_cmp(other)
+    }
 }
 
 impl Number for f64 {
@@ -237,6 +276,10 @@ impl Number for f64 {
     fn next_random<R: rand::Rng>(rng: &mut R) -> Self {
         rng.gen()
     }
+
+    fn total_cmp(&self, other: &Self) -> core::cmp::Ordering {
+        self.total_cmp(other)
+    }
 }
 
 /// A macro to implement the `Number` trait for primitive types.
@@ -301,6 +344,10 @@ macro_rules! impl_number_iint {
             fn next_random<R: rand::Rng>(rng: &mut R) -> Self {
                 rng.gen()
             }
+
+            fn total_cmp(&self, other: &Self) -> core::cmp::Ordering {
+                self.cmp(other)
+            }
         }
     )*
 }
@@ -372,6 +419,10 @@ macro_rules! impl_number_uint {
             fn next_random<R: rand::Rng>(rng: &mut R) -> Self {
                 rng.gen()
             }
+
+            fn total_cmp(&self, other: &Self) -> core::cmp::Ordering {
+                self.cmp(other)
+            }
         }
     )*
 }
diff --git a/crates/distances/src/number/_variants.rs b/crates/distances/src/number/_variants.rs
index 0e34715b1..8dcf7be1c 100644
--- a/crates/distances/src/number/_variants.rs
+++ b/crates/distances/src/number/_variants.rs
@@ -1,17 +1,35 @@
 //! Number variants for floats, integers, and unsigned integers.
 
-use core::hash::Hash;
+use core::{hash::Hash, ops::Neg};
+
+use num_integer::Integer;
 
 use crate::Number;
 
 /// Sub-trait of `Number` for all integer types.
-pub trait Int: Number + Hash + Eq + Ord {}
+pub trait Int: Number + Hash + Eq + Ord {
+    /// Returns the Greatest Common Divisor of two integers.
+    #[must_use]
+    fn gcd(&self, other: &Self) -> Self;
+
+    /// Returns the Least Common Multiple of two integers.
+    #[must_use]
+    fn lcm(&self, other: &Self) -> Self;
+}
 
 /// Macro to implement `IntNumber` for all integer types.
 macro_rules! impl_int {
     ($($ty:ty),*) => {
         $(
-            impl Int for $ty {}
+            impl Int for $ty {
+                fn gcd(&self, other: &Self) -> Self {
+                    Integer::gcd(&self, other)
+                }
+
+                fn lcm(&self, other: &Self) -> Self {
+                    Integer::lcm(&self, other)
+                }
+            }
         )*
     }
 }
@@ -19,7 +37,7 @@ macro_rules! impl_int {
 impl_int!(u8, i8, u16, i16, u32, i32, u64, i64, u128, i128, usize, isize);
 
 /// Sub-trait of `Number` for all signed integer types.
-pub trait IInt: Number + Hash + Eq + Ord {}
+pub trait IInt: Number + Neg<Output = Self> + Hash + Eq + Ord {}
 
 /// Macro to implement `IIntNumber` for all signed integer types.
 macro_rules!
impl_iint { diff --git a/crates/distances/src/strings/mod.rs b/crates/distances/src/strings/mod.rs index 2aa156c9e..69841dc06 100644 --- a/crates/distances/src/strings/mod.rs +++ b/crates/distances/src/strings/mod.rs @@ -5,7 +5,8 @@ pub mod needleman_wunsch; use crate::number::UInt; pub use needleman_wunsch::{ - _x_to_y, aligned_x_to_y, aligned_x_to_y_no_sub, apply_edits, nw_distance, unaligned_x_to_y, x_to_y_alignment, Edit, + aligned_x_to_y, aligned_x_to_y_no_sub, apply_edits, nw_distance, unaligned_x_to_y, x2y_helper, x_to_y_alignment, + Edit, }; /// Penalties to use in the Needleman-Wunsch distance calculation. @@ -76,9 +77,9 @@ pub fn levenshtein_custom(penalties: Penalties) -> impl Fn(&str, &st U::from(x.len()) } else if x.len() < y.len() { // require tat a is no shorter than b - _levenshtein(y, x, penalties) + lev_helper(y, x, penalties) } else { - _levenshtein(x, y, penalties) + lev_helper(x, y, penalties) } } } @@ -139,9 +140,9 @@ pub fn levenshtein(x: &str, y: &str) -> U { U::from(x.len()) } else if x.len() < y.len() { // require tat a is no shorter than b - _levenshtein(y, x, Penalties::default()) + lev_helper(y, x, Penalties::default()) } else { - _levenshtein(x, y, Penalties::default()) + lev_helper(x, y, Penalties::default()) } } @@ -149,7 +150,7 @@ pub fn levenshtein(x: &str, y: &str) -> U { /// This function actually performs the dynamic programming for the /// Levenshtein edit distance, using the `penalties` struct. #[allow(unused_variables)] -fn _levenshtein(x: &str, y: &str, penalties: Penalties) -> U { +fn lev_helper(x: &str, y: &str, penalties: Penalties) -> U { // initialize DP table for string y // this is a bit ugly with the U casts let mut cur = (0..=y.len()).map(U::from).collect::>(); diff --git a/crates/distances/src/strings/needleman_wunsch/helpers.rs b/crates/distances/src/strings/needleman_wunsch/helpers.rs index 1966c87ba..58ddc0338 100644 --- a/crates/distances/src/strings/needleman_wunsch/helpers.rs +++ b/crates/distances/src/strings/needleman_wunsch/helpers.rs @@ -162,7 +162,7 @@ pub fn trace_back_iterative(table: &[Vec<(U, Direction)>], [x, y]: [&st pub fn trace_back_recursive(table: &[Vec<(U, Direction)>], [x, y]: [&str; 2]) -> (String, String) { let (mut aligned_x, mut aligned_y) = (Vec::new(), Vec::new()); - _trace_back_recursive( + trb_helper( table, [y.len(), x.len()], [x.as_bytes(), y.as_bytes()], @@ -187,7 +187,7 @@ pub fn trace_back_recursive(table: &[Vec<(U, Direction)>], [x, y]: [&st /// * `[x, y]`: The two sequences to align, passed as slices of bytes. /// * `[aligned_x, aligned_y]`: mutable aligned sequences that will be built /// up from initially empty vectors. -fn _trace_back_recursive( +fn trb_helper( table: &[Vec<(U, Direction)>], [mut row_i, mut col_i]: [usize; 2], [x, y]: [&[u8]; 2], @@ -212,7 +212,7 @@ fn _trace_back_recursive( row_i -= 1; } }; - _trace_back_recursive(table, [row_i, col_i], [x, y], [aligned_x, aligned_y]); + trb_helper(table, [row_i, col_i], [x, y], [aligned_x, aligned_y]); } } @@ -225,12 +225,12 @@ fn _trace_back_recursive( /// /// # Returns /// -/// A 2-slice of Vec, each containing the edits needed to convert one aligned +/// A 2-slice of `Vec`, each containing the edits needed to convert one aligned /// sequence into the other. /// Since both input sequences are aligned, all edits are substitutions in the returned vectors are Substitutions. 
#[must_use] pub fn compute_edits(x: &str, y: &str) -> [Vec; 2] { - [_x_to_y(x, y), _x_to_y(y, x)] + [x2y_helper(x, y), x2y_helper(y, x)] } /// Helper for `compute_edits` to compute the edits for turning aligned `x` into aligned `y`. @@ -247,7 +247,7 @@ pub fn compute_edits(x: &str, y: &str) -> [Vec; 2] { /// /// A vector of edits needed to convert `x` into `y`. #[must_use] -pub fn _x_to_y(x: &str, y: &str) -> Vec { +pub fn x2y_helper(x: &str, y: &str) -> Vec { x.chars() .zip(y.chars()) .enumerate() diff --git a/crates/distances/src/strings/needleman_wunsch/mod.rs b/crates/distances/src/strings/needleman_wunsch/mod.rs index 2e6e2e07f..ed36fef90 100644 --- a/crates/distances/src/strings/needleman_wunsch/mod.rs +++ b/crates/distances/src/strings/needleman_wunsch/mod.rs @@ -9,8 +9,8 @@ use crate::number::UInt; use super::Penalties; pub use helpers::{ - _x_to_y, aligned_x_to_y, aligned_x_to_y_no_sub, apply_edits, compute_edits, compute_table, trace_back_iterative, - trace_back_recursive, unaligned_x_to_y, x_to_y_alignment, Edit, + aligned_x_to_y, aligned_x_to_y_no_sub, apply_edits, compute_edits, compute_table, trace_back_iterative, + trace_back_recursive, unaligned_x_to_y, x2y_helper, x_to_y_alignment, Edit, }; /// Use a custom set of penalties to create a function to that calculates the diff --git a/crates/distances/src/vectors/lp_norms.rs b/crates/distances/src/vectors/lp_norms.rs index be3356c6b..bfc01acfd 100644 --- a/crates/distances/src/vectors/lp_norms.rs +++ b/crates/distances/src/vectors/lp_norms.rs @@ -1,7 +1,5 @@ //! Provides functions for calculating Lp-norms between two vectors. -use core::cmp::Ordering; - use crate::{number::Float, Number}; use super::utils::abs_diff_iter; @@ -186,9 +184,7 @@ pub fn l4_norm(x: &[T], y: &[T]) -> U { /// assert!((distance - 5.0).abs() <= f64::EPSILON); /// ``` pub fn chebyshev(x: &[T], y: &[T]) -> T { - abs_diff_iter(x, y) - .max_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Less)) - .unwrap_or(T::ZERO) + abs_diff_iter(x, y).max_by(Number::total_cmp).unwrap_or(T::ZERO) } /// General (Lp-norm)^p between two vectors. 
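The switch to `Number::total_cmp` above is what lets `chebyshev` drop its `partial_cmp(...).unwrap_or(Ordering::Less)` fallback: the new method gives every `Number` a total order, so `max_by` needs no unwrapping. Below is a minimal sketch of the new trait surface, not code from this patch: `max_abs_diff` and `reduce` are hypothetical helpers, and the sketch assumes `Number` carries its usual arithmetic supertraits and that `Int` is reachable at `distances::number::Int`.

    use distances::{number::Int, Number};

    // Hypothetical helper: largest absolute pairwise difference, mirroring how
    // `chebyshev` uses `Number::total_cmp` (a NaN-safe total order for floats)
    // and the new `Number::min`/`max`, which floats get without an `Ord` bound.
    fn max_abs_diff<T: Number>(x: &[T], y: &[T]) -> T {
        x.iter()
            .zip(y)
            .map(|(&a, &b)| a.max(b) - a.min(b))
            .max_by(Number::total_cmp)
            .unwrap_or(T::ZERO)
    }

    // Hypothetical helper: reduce a fraction with `Int::gcd`, which now
    // delegates to the `num-integer` crate added to the dependencies above.
    fn reduce<T: Int>(num: T, den: T) -> (T, T) {
        let g = num.gcd(&den);
        (num / g, den / g)
    }

    fn main() {
        assert_eq!(max_abs_diff(&[1.0_f32, 5.0], &[2.0, 9.5]), 4.5);
        assert_eq!(reduce(6_u32, 4_u32), (3, 2));
    }

Routing the float implementations through the inherent `f32::total_cmp`/`f64::total_cmp` and the integer ones through `Ord::cmp`, as the patch does, keeps the trait method zero-cost while making generic code like the sketch above work uniformly across all `Number` types.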
diff --git a/crates/distances/tests/test_edits.rs b/crates/distances/tests/test_edits.rs index 105d38555..ccd78fa9f 100644 --- a/crates/distances/tests/test_edits.rs +++ b/crates/distances/tests/test_edits.rs @@ -1,5 +1,5 @@ use distances::strings::{ - Edit, _x_to_y, aligned_x_to_y, aligned_x_to_y_no_sub, apply_edits, unaligned_x_to_y, x_to_y_alignment, + aligned_x_to_y, aligned_x_to_y_no_sub, apply_edits, unaligned_x_to_y, x2y_helper, x_to_y_alignment, Edit, }; #[test] @@ -7,7 +7,7 @@ fn tiny_aligned() { let x = "A-C"; let y = "AAC"; - let actual = _x_to_y(x, y); + let actual = x2y_helper(x, y); let expected = vec![Edit::Sub(1, 'A')]; assert_eq!(actual, expected); @@ -18,7 +18,7 @@ fn medium_aligned() { let x = "NAJIB-PEPPERSEATS"; let y = "NAJIBEATSPEPPE-RS"; - let actual = _x_to_y(x, y); + let actual = x2y_helper(x, y); let expected = vec![ Edit::Sub(5, 'E'), Edit::Sub(6, 'A'), diff --git a/crates/distances/tests/test_sets.rs b/crates/distances/tests/test_sets.rs index 22d01e7c9..6033bbd71 100644 --- a/crates/distances/tests/test_sets.rs +++ b/crates/distances/tests/test_sets.rs @@ -100,21 +100,21 @@ fn bounds_test() { distance = jaccard(&x, &x); assert!(distance < f32::EPSILON); distance = jaccard(&x, &y); - assert!(distance - 1.0 < f32::EPSILON); + assert!((distance - 1.0).abs() < f32::EPSILON); distance = jaccard(&y, &y); - assert!(distance - 1.0 < f32::EPSILON); + assert!((distance - 1.0).abs() < f32::EPSILON); distance = dice(&x, &x); assert!(distance < f32::EPSILON); distance = dice(&x, &y); - assert!(distance - 1.0 < f32::EPSILON); + assert!((distance - 1.0).abs() < f32::EPSILON); distance = dice(&y, &y); - assert!(distance - 1.0 < f32::EPSILON); + assert!((distance - 1.0).abs() < f32::EPSILON); distance = kulsinski(&x, &x); assert!(distance < f32::EPSILON); distance = kulsinski(&x, &y); - assert!(distance - 1.0 < f32::EPSILON); + assert!((distance - 1.0).abs() < f32::EPSILON); distance = kulsinski(&y, &y); - assert!(distance - 1.0 < f32::EPSILON); + assert!((distance - 1.0).abs() < f32::EPSILON); } diff --git a/crates/distances/tests/test_vectors_f32.rs b/crates/distances/tests/test_vectors_f32.rs index 018984f92..e269a660b 100644 --- a/crates/distances/tests/test_vectors_f32.rs +++ b/crates/distances/tests/test_vectors_f32.rs @@ -1,5 +1,3 @@ -use core::f32::EPSILON; - use rand::prelude::*; use symagen::random_data; @@ -55,7 +53,7 @@ fn lp_f32() { let e_l1 = l1(x, y); let a_l1: f32 = manhattan(x, y); assert!( - (e_l1 - a_l1).abs() <= EPSILON, + (e_l1 - a_l1).abs() <= f32::EPSILON, "Manhattan: expected: {}, actual: {}", e_l1, a_l1 @@ -64,7 +62,7 @@ fn lp_f32() { let expected = l2_sq(x, y); let actual: f32 = euclidean_sq(x, y); assert!( - (expected - actual).abs() <= EPSILON, + (expected - actual).abs() <= f32::EPSILON, "Euclidean squared: expected: {}, actual: {}", expected, actual @@ -73,7 +71,7 @@ fn lp_f32() { let expected = l2(x, y); let actual: f32 = euclidean(x, y); assert!( - (expected - actual).abs() <= EPSILON, + (expected - actual).abs() <= f32::EPSILON, "Euclidean: expected: {}, actual: {}", expected, actual @@ -82,7 +80,7 @@ fn lp_f32() { let e_l3 = l3(x, y); let a_l3: f32 = l3_norm(x, y); assert!( - (e_l3 - a_l3).abs() <= EPSILON, + (e_l3 - a_l3).abs() <= f32::EPSILON, "L3 norm: expected: {}, actual: {}", e_l3, a_l3 @@ -91,7 +89,7 @@ fn lp_f32() { let e_l4 = l4(x, y); let a_l4: f32 = l4_norm(x, y); assert!( - (e_l4 - a_l4).abs() <= EPSILON, + (e_l4 - a_l4).abs() <= f32::EPSILON, "L4 norm: expected: {}, actual: {}", e_l4, a_l4 @@ -100,7 +98,7 @@ fn lp_f32() { let 
e_l_inf = l_inf(x, y); let a_l_inf: f32 = chebyshev(x, y); assert!( - (e_l_inf - a_l_inf).abs() <= EPSILON, + (e_l_inf - a_l_inf).abs() <= f32::EPSILON, "Chebyshev: expected: {}, actual: {}", e_l_inf, a_l_inf diff --git a/crates/distances/tests/test_vectors_u32.rs b/crates/distances/tests/test_vectors_u32.rs index 2d2854c66..e57858ef3 100644 --- a/crates/distances/tests/test_vectors_u32.rs +++ b/crates/distances/tests/test_vectors_u32.rs @@ -1,5 +1,3 @@ -use core::f32::EPSILON; - use rand::prelude::*; use symagen::random_data; @@ -71,28 +69,28 @@ fn lp_u32() { let e_l2s = l2_sq(x, y); let a_l2s: f32 = euclidean_sq(x, y); assert!( - (e_l2s - a_l2s).abs() <= EPSILON, + (e_l2s - a_l2s).abs() <= f32::EPSILON, "Euclidean squared: expected: {e_l2s}, actual: {a_l2s}" ); let e_l2 = l2(x, y); let a_l2: f32 = euclidean(x, y); assert!( - (e_l2 - a_l2).abs() <= EPSILON, + (e_l2 - a_l2).abs() <= f32::EPSILON, "Euclidean: expected: {e_l2}, actual: {a_l2}" ); let e_l3 = l3(x, y); let a_l3: f32 = l3_norm(x, y); assert!( - (e_l3 - a_l3).abs() <= EPSILON, + (e_l3 - a_l3).abs() <= f32::EPSILON, "L3 norm: expected: {e_l3}, actual: {a_l3}" ); let e_l4 = l4(x, y); let a_l4: f32 = l4_norm(x, y); assert!( - (e_l4 - a_l4).abs() <= EPSILON, + (e_l4 - a_l4).abs() <= f32::EPSILON, "L4 norm: expected: {e_l4}, actual: {a_l4}" ); diff --git a/crates/results/cakes/Cargo.toml b/crates/results/cakes/Cargo.toml index 08cdff5c4..6f97f38f2 100644 --- a/crates/results/cakes/Cargo.toml +++ b/crates/results/cakes/Cargo.toml @@ -8,13 +8,15 @@ clap = { version = "4.5.16", features = ["derive"] } ftlog = { version = "0.2.0" } distances = { workspace = true } stringzilla = "3.9.5" -bio = "2.0.1" +bio = { workspace = true } rand = { workspace = true } -abd-clam = { workspace = true, features = ["csv"] } +abd-clam = { workspace = true, features = ["disk-io"] } hdf5 = { package = "hdf5-metno", version = "0.9.0" } serde = { workspace = true } -bincode = { workspace = true } serde_json = "1.0" +rayon = { workspace = true } +bitcode = { workspace = true } +bench-utils = { path = "../../../benches/utils" } # ndarray = { workspace = true } diff --git a/crates/results/cakes/src/data/mod.rs b/crates/results/cakes/src/data/mod.rs index c06d1e692..104bbdd53 100644 --- a/crates/results/cakes/src/data/mod.rs +++ b/crates/results/cakes/src/data/mod.rs @@ -4,4 +4,8 @@ mod raw; mod tree; #[allow(clippy::module_name_repetitions)] -pub use raw::RawData; +#[allow(unused_imports)] +pub use raw::{fasta, RawData}; + +#[allow(unused_imports)] +pub use tree::PathManager; diff --git a/crates/results/cakes/src/data/raw/fasta.rs b/crates/results/cakes/src/data/raw/fasta.rs index e50d44667..a6e23ef3d 100644 --- a/crates/results/cakes/src/data/raw/fasta.rs +++ b/crates/results/cakes/src/data/raw/fasta.rs @@ -1,12 +1,7 @@ //! Reading FASTA files -use std::path::Path; - +use abd_clam::{dataset::AssociatesMetadata, Dataset}; use clap::error::Result; -use rand::seq::SliceRandom; - -/// A collection of named sequences. -type NamedSequences = Vec<(String, String)>; /// Reads a FASTA file from the given path. /// @@ -14,6 +9,7 @@ type NamedSequences = Vec<(String, String)>; /// /// * `path`: The path to the FASTA file. /// * `holdout`: The number of sequences to hold out for queries. +/// * `remove_gaps`: Whether to remove gaps from the sequences. /// /// # Returns /// @@ -26,48 +22,20 @@ type NamedSequences = Vec<(String, String)>; /// * If the extension is not `.fasta`. /// * If the file cannot be read as a FASTA file. /// * If any ID or sequence is empty. 
-pub fn read>(path: &P, holdout: usize) -> Result<([NamedSequences; 2], [usize; 2]), String> { - let path = path.as_ref(); - if !path.exists() { - return Err(format!("Path {path:?} does not exist!")); - } - - if !path.extension().map_or(false, |ext| ext == "fasta") { - return Err(format!("Path {path:?} does not have the `.fasta` extension!")); - } - - ftlog::info!("Reading FASTA file from {path:?}."); - - let mut records = bio::io::fasta::Reader::from_file(path) - .map_err(|e| e.to_string())? - .records(); - - let mut seqs = Vec::new(); - let (mut min_len, mut max_len) = (usize::MAX, 0); - - while let Some(Ok(record)) = records.next() { - let name = record.id().to_string(); - if name.is_empty() { - return Err(format!("Empty ID for record {}.", seqs.len())); - } - - let seq = record.seq().iter().map(|&b| b as char).collect::(); - if seq.is_empty() { - return Err(format!("Empty sequence for record {} with ID {name}.", seqs.len())); - } - - min_len = min_len.min(seq.len()); - max_len = max_len.max(seq.len()); - - seqs.push((name, seq)); - } - - ftlog::info!("Read {} sequences from {path:?}.", seqs.len()); - ftlog::info!("Minimum length: {min_len}, Maximum length: {max_len}."); - - // Shuffle the sequences and hold out a query set. - seqs.shuffle(&mut rand::thread_rng()); - let queries = seqs.split_off(seqs.len() - holdout); +#[allow(clippy::type_complexity)] +pub fn read>( + path: &P, + holdout: usize, + remove_gaps: bool, +) -> Result<([Vec<(String, String)>; 2], [usize; 2]), String> { + let (data, queries) = bench_utils::fasta::read(path, holdout, remove_gaps)?; + let (min_len, max_len) = data.dimensionality_hint(); + let max_len = max_len.unwrap_or(min_len); + + let seqs = data.items(); + let ids = data.metadata(); + + let seqs = ids.iter().cloned().zip(seqs.iter().cloned()).collect(); #[allow(clippy::tuple_array_conversions)] Ok(([seqs, queries], [min_len, max_len])) diff --git a/crates/results/cakes/src/data/raw/mod.rs b/crates/results/cakes/src/data/raw/mod.rs index 7c14d9dc2..2c757079e 100644 --- a/crates/results/cakes/src/data/raw/mod.rs +++ b/crates/results/cakes/src/data/raw/mod.rs @@ -1,13 +1,11 @@ //! Reading data from various sources. -use std::fs::File; - -use abd_clam::MetricSpace; +use abd_clam::{dataset::AssociatesMetadataMut, FlatVec}; use super::tree::instances::{Aligned, MemberSet, Unaligned}; mod ann_benchmarks; -mod fasta; +pub mod fasta; /// The datasets we use for benchmarks. #[derive(clap::ValueEnum, Debug, Clone)] @@ -42,6 +40,7 @@ pub enum RawData { impl RawData { /// Returns the name of the dataset as a string. + #[must_use] pub const fn name(&self) -> &str { match self { Self::GreenGenes_12_10 => "gg_12_10", @@ -89,23 +88,23 @@ impl RawData { match self { Self::GreenGenes_12_10 | Self::GreenGenes_13_5 | Self::Silva_18S | Self::PdbSeq => { - let (mut data, queries) = if data_path.exists() && queries_path.exists() { + let (data, queries) = if data_path.exists() && queries_path.exists() { ftlog::info!("Reading data from {data_path:?}"); - let data = bincode::deserialize_from(File::open(&data_path).map_err(|e| e.to_string())?) - .map_err(|e| e.to_string())?; + let bytes: Vec = std::fs::read(&data_path).map_err(|e| e.to_string())?; + let data = bitcode::decode(&bytes).map_err(|e| e.to_string())?; ftlog::info!("Reading queries from {queries_path:?}"); - let queries = bincode::deserialize_from(File::open(&queries_path).map_err(|e| e.to_string())?) 
- .map_err(|e| e.to_string())?; + let bytes: Vec = std::fs::read(&queries_path).map_err(|e| e.to_string())?; + let queries = bitcode::decode(&bytes).map_err(|e| e.to_string())?; (data, queries) } else { - let ([data, queries], [min_len, max_len]) = fasta::read(inp_path, 1000)?; + let ([data, queries], [min_len, max_len]) = fasta::read(inp_path, 1000, false)?; let (metadata, data): (Vec<_>, Vec<_>) = data.into_iter().map(|(name, seq)| (name, Unaligned::from(seq))).unzip(); - let data = abd_clam::FlatVec::new(data, Unaligned::metric())? - .with_metadata(metadata)? + let data = FlatVec::new(data)? + .with_metadata(&metadata)? .with_dim_lower_bound(min_len) .with_dim_upper_bound(max_len); @@ -115,38 +114,36 @@ impl RawData { .collect(); ftlog::info!("Writing data to {data_path:?}"); - bincode::serialize_into(File::create(&data_path).map_err(|e| e.to_string())?, &data) - .map_err(|e| e.to_string())?; + let bytes = bitcode::encode(&data).map_err(|e| e.to_string())?; + std::fs::write(&data_path, &bytes).map_err(|e| e.to_string())?; ftlog::info!("Writing queries to {queries_path:?}"); - bincode::serialize_into(File::create(&queries_path).map_err(|e| e.to_string())?, &queries) - .map_err(|e| e.to_string())?; + let bytes = bitcode::encode(&queries).map_err(|e| e.to_string())?; + std::fs::write(&queries_path, &bytes).map_err(|e| e.to_string())?; (data, queries) }; - // Set the metric for the data, incase it was deserialized. - data.set_metric(Unaligned::metric()); super::tree::Tree::new_unaligned(self.name(), out_dir, data, queries) } Self::GreenGenesAligned_12_10 | Self::SilvaAligned_18S => { - let (mut data, queries) = if data_path.exists() && queries_path.exists() { + let (data, queries) = if data_path.exists() && queries_path.exists() { ftlog::info!("Reading data from {data_path:?}"); - let data = bincode::deserialize_from(File::open(&data_path).map_err(|e| e.to_string())?) - .map_err(|e| e.to_string())?; + let bytes: Vec = std::fs::read(&data_path).map_err(|e| e.to_string())?; + let data = bitcode::decode(&bytes).map_err(|e| e.to_string())?; ftlog::info!("Reading queries from {queries_path:?}"); - let queries = bincode::deserialize_from(File::open(&queries_path).map_err(|e| e.to_string())?) - .map_err(|e| e.to_string())?; + let bytes: Vec = std::fs::read(&queries_path).map_err(|e| e.to_string())?; + let queries = bitcode::decode(&bytes).map_err(|e| e.to_string())?; (data, queries) } else { - let ([data, queries], [min_len, max_len]) = fasta::read(inp_path, 1000)?; + let ([data, queries], [min_len, max_len]) = fasta::read(inp_path, 1000, false)?; let (metadata, data): (Vec<_>, Vec<_>) = data.into_iter().map(|(name, seq)| (name, Aligned::from(seq))).unzip(); - let data = abd_clam::FlatVec::new(data, Aligned::metric())? - .with_metadata(metadata)? + let data = FlatVec::new(data)? + .with_metadata(&metadata)? 
.with_dim_lower_bound(min_len) .with_dim_upper_bound(max_len); @@ -156,62 +153,56 @@ impl RawData { .collect(); ftlog::info!("Writing data to {data_path:?}"); - bincode::serialize_into(File::create(&data_path).map_err(|e| e.to_string())?, &data) - .map_err(|e| e.to_string())?; + let bytes = bitcode::encode(&data).map_err(|e| e.to_string())?; + std::fs::write(&data_path, &bytes).map_err(|e| e.to_string())?; ftlog::info!("Writing queries to {queries_path:?}"); - bincode::serialize_into(File::create(&queries_path).map_err(|e| e.to_string())?, &queries) - .map_err(|e| e.to_string())?; + let bytes = bitcode::encode(&queries).map_err(|e| e.to_string())?; + std::fs::write(&queries_path, &bytes).map_err(|e| e.to_string())?; (data, queries) }; - // Set the metric for the data, incase it was deserialized. - data.set_metric(Aligned::metric()); super::tree::Tree::new_aligned(self.name(), out_dir, data, queries) } Self::Kosarak | Self::MovieLens_10M => { - let (mut data, queries, ground_truth) = - if data_path.exists() && queries_path.exists() && gt_path.exists() { - ftlog::info!("Reading data from {data_path:?}"); - let data = bincode::deserialize_from(File::open(&data_path).map_err(|e| e.to_string())?) - .map_err(|e| e.to_string())?; - - ftlog::info!("Reading queries from {queries_path:?}"); - let queries = bincode::deserialize_from(File::open(&queries_path).map_err(|e| e.to_string())?) - .map_err(|e| e.to_string())?; - - ftlog::info!("Reading ground truth from {gt_path:?}"); - let ground_truth = bincode::deserialize_from(File::open(>_path).map_err(|e| e.to_string())?) - .map_err(|e| e.to_string())?; - - (data, queries, ground_truth) - } else { - let data = ann_benchmarks::read::<_, usize>(inp_path, true)?; - let (data, queries, ground_truth) = (data.train, data.queries, data.neighbors); - - let metric = MemberSet::metric(); - let data = data.iter().map(MemberSet::<_, f32>::from).collect(); - let data = abd_clam::FlatVec::new(data, metric)?; - - let queries = queries.iter().map(MemberSet::<_, f32>::from).enumerate().collect(); - - ftlog::info!("Writing data to {data_path:?}"); - bincode::serialize_into(File::create(&data_path).map_err(|e| e.to_string())?, &data) - .map_err(|e| e.to_string())?; - - ftlog::info!("Writing queries to {queries_path:?}"); - bincode::serialize_into(File::create(&queries_path).map_err(|e| e.to_string())?, &queries) - .map_err(|e| e.to_string())?; - - ftlog::info!("Writing ground truth to {gt_path:?}"); - bincode::serialize_into(File::create(>_path).map_err(|e| e.to_string())?, &ground_truth) - .map_err(|e| e.to_string())?; - - (data, queries, ground_truth) - }; - // Set the metric for the data, incase it was deserialized. 
- data.set_metric(MemberSet::metric()); + let (data, queries, ground_truth) = if data_path.exists() && queries_path.exists() && gt_path.exists() { + ftlog::info!("Reading data from {data_path:?}"); + let bytes = std::fs::read(&data_path).map_err(|e| e.to_string())?; + let data = bitcode::decode(&bytes).map_err(|e| e.to_string())?; + + ftlog::info!("Reading queries from {queries_path:?}"); + let bytes = std::fs::read(&queries_path).map_err(|e| e.to_string())?; + let queries = bitcode::decode(&bytes).map_err(|e| e.to_string())?; + + ftlog::info!("Reading ground truth from {gt_path:?}"); + let bytes = std::fs::read(>_path).map_err(|e| e.to_string())?; + let ground_truth = bitcode::decode(&bytes).map_err(|e| e.to_string())?; + + (data, queries, ground_truth) + } else { + let data = ann_benchmarks::read::<_, usize>(inp_path, true)?; + let (data, queries, ground_truth) = (data.train, data.queries, data.neighbors); + + let data = data.iter().map(MemberSet::<_>::from).collect(); + let data = FlatVec::new(data)?; + + let queries = queries.iter().map(MemberSet::<_>::from).enumerate().collect(); + + ftlog::info!("Writing data to {data_path:?}"); + let bytes = bitcode::encode(&data).map_err(|e| e.to_string())?; + std::fs::write(&data_path, &bytes).map_err(|e| e.to_string())?; + + ftlog::info!("Writing queries to {queries_path:?}"); + let bytes = bitcode::encode(&queries).map_err(|e| e.to_string())?; + std::fs::write(&queries_path, &bytes).map_err(|e| e.to_string())?; + + ftlog::info!("Writing ground truth to {gt_path:?}"); + let bytes = bitcode::encode(&ground_truth).map_err(|e| e.to_string())?; + std::fs::write(>_path, &bytes).map_err(|e| e.to_string())?; + + (data, queries, ground_truth) + }; super::tree::Tree::new_ann_set(self.name(), out_dir, data, queries, ground_truth) } diff --git a/crates/results/cakes/src/data/tree/aligned.rs b/crates/results/cakes/src/data/tree/aligned.rs index 8c2f1f90e..f9dd85da7 100644 --- a/crates/results/cakes/src/data/tree/aligned.rs +++ b/crates/results/cakes/src/data/tree/aligned.rs @@ -3,23 +3,26 @@ use std::collections::HashMap; use abd_clam::{ - adapter::{ParAdapter, ParBallAdapter}, - cakes::{Algorithm, CodecData, Decompressible, OffBall, SquishyBall}, - cluster::WriteCsv, - partition::ParPartition, - Ball, Cluster, Dataset, FlatVec, MetricSpace, + cakes::{KnnBreadthFirst, KnnDepthFirst, KnnRepeatedRnn, ParSearchAlgorithm, RnnClustered}, + cluster::{adapter::ParBallAdapter, ClusterIO, Csv, ParPartition}, + dataset::{AssociatesMetadata, AssociatesMetadataMut, ParDatasetIO}, + pancakes::{CodecData, SquishyBall}, + Ball, Cluster, Dataset, FlatVec, }; use distances::Number; -use super::{instances::Aligned, PathManager}; +use super::{ + instances::{Aligned, Hamming}, + PathManager, +}; -type I = Aligned; +type I = Aligned; type U = u32; type M = String; -type Co = FlatVec; -type B = Ball; -type Dec = CodecData; -type Sb = SquishyBall; +type Co = FlatVec; +type B = Ball; +type Dec = CodecData; +type Sb = SquishyBall; type Hits = Vec>; /// The group of types used for the datasets of aligned sequences. @@ -46,17 +49,15 @@ impl Group { let query_path = path_manager.queries_path(); if !query_path.exists() { // Serialize the queries to disk. 
- bincode::serialize_into(std::fs::File::create(&query_path).map_err(|e| e.to_string())?, &queries) - .map_err(|e| e.to_string())?; + let bytes = bitcode::encode(&queries).map_err(|e| e.to_string())?; + std::fs::write(path_manager.queries_path(), &bytes).map_err(|e| e.to_string())?; } let (query_ids, queries) = queries.into_iter().unzip(); + let metric = Hamming; let ball_path = path_manager.ball_path(); let ball = if ball_path.exists() { - // Deserialize the ball from disk. - ftlog::info!("Reading ball from {ball_path:?}"); - bincode::deserialize_from(std::fs::File::open(&ball_path).map_err(|e| e.to_string())?) - .map_err(|e| e.to_string())? + B::read_from(&ball_path)? } else { // Create the ball from scratch. ftlog::info!("Creating ball."); @@ -64,22 +65,22 @@ impl Group { let seed = Some(42); let indices = (0..uncompressed.cardinality()).collect::>(); - let mut ball = Ball::par_new(&uncompressed, &indices, 0, seed); - let depth_delta = ball.max_recursion_depth(); + let mut ball = Ball::par_new(&uncompressed, &metric, &indices, 0, seed) + .unwrap_or_else(|e| unreachable!("We ensured that indices is non-empty: {e}")); + let depth_delta = abd_clam::utils::max_recursion_depth(); - let criteria = |c: &Ball<_, _, _>| c.depth() < 1; - ball.par_partition(&uncompressed, &criteria, seed); + let criteria = |c: &B| c.depth() < 1; + ball.par_partition(&uncompressed, &metric, &criteria, seed); while ball.leaves().into_iter().any(|c| !c.is_singleton()) { depth += depth_delta; - let criteria = |c: &Ball<_, _, _>| c.depth() < depth; - ball.par_partition_further(&uncompressed, &criteria, seed); + let criteria = |c: &B| c.depth() < depth; + ball.par_partition_further(&uncompressed, &metric, &criteria, seed); } // Serialize the ball to disk. ftlog::info!("Writing ball to {ball_path:?}"); - bincode::serialize_into(std::fs::File::create(&ball_path).map_err(|e| e.to_string())?, &ball) - .map_err(|e| e.to_string())?; + ball.write_to(&ball_path)?; let num_leaves = ball.leaves().len(); ftlog::info!("Ball has {num_leaves} leaves."); @@ -94,47 +95,30 @@ impl Group { let squishy_ball_path = path_manager.squishy_ball_path(); let compressed_path = path_manager.compressed_path(); - let (squishy_ball, mut compressed) = if squishy_ball_path.exists() && compressed_path.exists() { + let (squishy_ball, compressed) = if squishy_ball_path.exists() && compressed_path.exists() { ftlog::info!("Reading squishy ball from {squishy_ball_path:?}"); - let squishy_ball = - bincode::deserialize_from(std::fs::File::open(&squishy_ball_path).map_err(|e| e.to_string())?) - .map_err(|e| e.to_string())?; + let squishy_ball = Sb::read_from(&squishy_ball_path)?; ftlog::info!("Reading compressed dataset from {compressed_path:?}"); - let codec_data = - bincode::deserialize_from(std::fs::File::open(&compressed_path).map_err(|e| e.to_string())?) - .map_err(|e| e.to_string())?; + let codec_data = Dec::par_read_from(&compressed_path)?; (squishy_ball, codec_data) } else { - ftlog::info!("Creating squishy ball and permuted dataset."); - let (mut squishy_ball, perm_data) = { - let (off_ball, data) = OffBall::par_from_ball_tree(ball.clone(), uncompressed.clone()); - let mut squishy_ball = SquishyBall::par_adapt_tree_iterative(off_ball, None); + ftlog::info!("Creating squishy ball and compressed dataset."); + let (squishy_ball, codec_data) = { + let (squishy_ball, data) = SquishyBall::par_from_ball_tree(ball.clone(), uncompressed.clone(), &metric); - // Set the costs of the squishy ball and write it to a CSV file. 
- ftlog::info!("Setting costs and writing pre-trim ball to CSV."); - squishy_ball.par_set_costs(&data); - squishy_ball.write_to_csv(&path_manager.pre_trim_csv_path())?; + // Write it to a CSV file. + squishy_ball.write_to_csv(&path_manager.squishy_ball_csv_path())?; - (squishy_ball, data) + (squishy_ball, data.with_metadata(uncompressed.metadata())?) }; - // Trim the squishy ball and write it to a CSV file. - ftlog::info!("Trimming squishy ball and writing to CSV."); - squishy_ball.trim(4); - squishy_ball.write_to_csv(&path_manager.squishy_csv_path())?; - let num_leaves = squishy_ball.leaves().len(); ftlog::info!("Built squishy ball with {num_leaves} leaves."); assert!(num_leaves > 1, "Squishy ball has only one leaf."); - // Create the compressed dataset and set its metadata. - ftlog::info!("Creating compressed dataset."); - let codec_data = CodecData::from_compressible(&perm_data, &squishy_ball) - .with_metadata(uncompressed.metadata().to_vec())?; - let num_bytes = codec_data .leaf_bytes() .iter() @@ -142,27 +126,15 @@ impl Group { .sum::(); ftlog::info!("Built compressed dataset with {num_bytes} leaf bytes."); - // Change the metadata type of the squishy ball to match the compressed dataset. - let squishy_ball = squishy_ball.with_metadata_type::(); - // Serialize the squishy ball and the compressed dataset to disk. ftlog::info!("Writing squishy ball to {squishy_ball_path:?}"); - bincode::serialize_into( - std::fs::File::create(&squishy_ball_path).map_err(|e| e.to_string())?, - &squishy_ball, - ) - .map_err(|e| e.to_string())?; + squishy_ball.write_to(&squishy_ball_path)?; ftlog::info!("Writing compressed dataset to {compressed_path:?}"); - bincode::serialize_into( - std::fs::File::create(&compressed_path).map_err(|e| e.to_string())?, - &codec_data, - ) - .map_err(|e| e.to_string())?; + codec_data.par_write_to(&compressed_path)?; (squishy_ball, codec_data) }; - compressed.set_metric(I::metric()); Ok(Self { path_manager, @@ -175,76 +147,69 @@ impl Group { }) } + fn bench_search(&self, num_queries: usize, alg_a: &Aco, alg_b: &Adec) -> Result, String> + where + Aco: ParSearchAlgorithm, + Adec: ParSearchAlgorithm, + { + let metric = &Hamming; + let name = alg_a.name(); + + let queries = &self.queries[..num_queries]; + ftlog::info!("Running benchmarks for compressive search on {num_queries} queries with {name}"); + + let uncompressed_start = std::time::Instant::now(); + let uncompressed_hits = alg_a.par_batch_search(&self.uncompressed, metric, &self.ball, queries); + let uncompressed_time = uncompressed_start.elapsed().as_secs_f32() / num_queries.as_f32(); + ftlog::info!( + "Algorithm {name} took {uncompressed_time:.3e} seconds per query uncompressed time on {}", + self.path_manager.name() + ); + + let compressed_start = std::time::Instant::now(); + let compressed_hits = alg_b.par_batch_search(&self.compressed, metric, &self.squishy_ball, queries); + let compressed_time = compressed_start.elapsed().as_secs_f32() / num_queries.as_f32(); + ftlog::info!( + "Algorithm {name} took {compressed_time:.3e} seconds per query compressed time on {}", + self.path_manager.name() + ); + + self.verify_hits(uncompressed_hits, compressed_hits)?; + + let slowdown = compressed_time / uncompressed_time; + Ok(vec![ + format!("uncompressed: {uncompressed_time:.4e}"), + format!("uncompressed_throughput: {:.4e}", 1.0 / uncompressed_time), + format!("compressed: {compressed_time:.4e}"), + format!("compressed_throughput: {:.4e}", 1.0 / compressed_time), + format!("slowdown: {slowdown:.4}"), + ]) + } + /// Run benchmarks 
for compressive search on the dataset. /// /// # Errors /// /// - If there is an error writing the times to disk. - pub fn bench_compressive_search(&mut self, num_queries: usize) -> Result<(), String> { - self.uncompressed.set_metric(Aligned::levenshtein_metric()); - self.compressed.set_metric(Aligned::levenshtein_metric()); - - let radius = 5; - let k = 10; - let algorithms = [ - // Algorithm::RnnLinear(radius), - Algorithm::RnnClustered(radius), - // Algorithm::KnnLinear(k), - Algorithm::KnnRepeatedRnn(k, 2), - Algorithm::KnnBreadthFirst(k), - Algorithm::KnnDepthFirst(k), - ]; - - let num_queries = num_queries.min(self.queries.len()); - let queries = &self.queries[..num_queries]; - ftlog::info!( - "Running benchmarks for compressive search on {num_queries} queries with {} algorithms", - algorithms.len() - ); + pub fn bench_compressive_search(&self, num_queries: usize) -> Result<(), String> { + let num_queries = Ord::min(num_queries, self.queries.len()); + ftlog::info!("Running benchmarks for compressive search on {num_queries} queries."); let mut times = HashMap::new(); - for (i, alg) in algorithms.iter().enumerate() { - ftlog::info!( - "Running algorithm {} ({}/{}) on {}", - alg.name(), - i + 1, - algorithms.len(), - self.path_manager.name() - ); - - let uncompressed_start = std::time::Instant::now(); - let uncompressed_hits = alg.par_batch_par_search(&self.uncompressed, &self.ball, queries); - let uncompressed_time = uncompressed_start.elapsed().as_secs_f32() / num_queries.as_f32(); - ftlog::info!( - "Algorithm {} took {:.3e} seconds per query uncompressed time on {}", - alg.name(), - uncompressed_time, - self.path_manager.name() - ); - - let compressed_start = std::time::Instant::now(); - let compressed_hits = alg.par_batch_par_search(&self.compressed, &self.squishy_ball, queries); - let compressed_time = compressed_start.elapsed().as_secs_f32() / num_queries.as_f32(); - ftlog::info!( - "Algorithm {} took {:.3e} seconds per query compressed time on {}", - alg.name(), - compressed_time, - self.path_manager.name() - ); - - self.verify_hits(uncompressed_hits, compressed_hits)?; - - let slowdown = compressed_time / uncompressed_time; - times.insert( - alg.name(), - ( - format!("uncompressed: {uncompressed_time:.4e}"), - format!("uncompressed_throughput: {:.4e}", 1.0 / uncompressed_time), - format!("compressed: {compressed_time:.4e}"), - format!("compressed_throughput: {:.4e}", 1.0 / compressed_time), - format!("slowdown: {slowdown:.4}"), - ), - ); + for radius in [5, 10, 100] { + let times_inner = self.bench_search(num_queries, &RnnClustered(radius), &RnnClustered(radius))?; + times.insert(format!("RnnClustered({radius})"), times_inner); + } + + for k in [1, 10, 100] { + let times_inner = self.bench_search(num_queries, &KnnRepeatedRnn(k, 2), &KnnRepeatedRnn(k, 2))?; + times.insert(format!("KnnRepeatedRnn({k}, 2)"), times_inner); + + let times_inner = self.bench_search(num_queries, &KnnBreadthFirst(k), &KnnBreadthFirst(k))?; + times.insert(format!("KnnBreadthFirst({k})"), times_inner); + + let times_inner = self.bench_search(num_queries, &KnnDepthFirst(k), &KnnDepthFirst(k))?; + times.insert(format!("KnnDepthFirst({k})"), times_inner); } ftlog::info!("Writing times to disk."); @@ -254,9 +219,6 @@ impl Group { ) .map_err(|e| e.to_string())?; - self.compressed.set_metric(Aligned::metric()); - self.uncompressed.set_metric(Aligned::metric()); - Ok(()) } diff --git a/crates/results/cakes/src/data/tree/ann_set.rs b/crates/results/cakes/src/data/tree/ann_set.rs index 05c66453c..acaa1dbd8 100644 
--- a/crates/results/cakes/src/data/tree/ann_set.rs +++ b/crates/results/cakes/src/data/tree/ann_set.rs @@ -3,23 +3,26 @@ use std::collections::HashMap; use abd_clam::{ - adapter::{ParAdapter, ParBallAdapter}, - cakes::{Algorithm, CodecData, Decompressible, OffBall, SquishyBall}, - cluster::WriteCsv, - partition::ParPartition, - BalancedBall, Ball, Cluster, Dataset, FlatVec, MetricSpace, + cakes::{KnnBreadthFirst, KnnDepthFirst, KnnRepeatedRnn, ParSearchAlgorithm, RnnClustered}, + cluster::{adapter::ParBallAdapter, ClusterIO, Csv, ParPartition}, + dataset::{AssociatesMetadata, AssociatesMetadataMut, ParDatasetIO}, + pancakes::{CodecData, SquishyBall}, + Ball, Cluster, FlatVec, }; use distances::Number; -use super::{instances::MemberSet, PathManager}; +use super::{ + instances::{Jaccard, MemberSet}, + PathManager, +}; -type I = MemberSet; +type I = MemberSet; type U = f32; type M = usize; -type Co = FlatVec; -type B = Ball; -type Dec = CodecData; -type Sb = SquishyBall; +type Co = FlatVec; +type B = Ball; +type Dec = CodecData; +type Sb = SquishyBall; type Hits = Vec>; /// The group of types used for the datasets of named member sets. @@ -54,76 +57,36 @@ impl Group { let query_path = path_manager.queries_path(); if !query_path.exists() { // Serialize the queries to disk. - bincode::serialize_into(std::fs::File::create(&query_path).map_err(|e| e.to_string())?, &queries) - .map_err(|e| e.to_string())?; + let bytes = bitcode::encode(&queries).map_err(|e| e.to_string())?; + std::fs::write(&query_path, &bytes).map_err(|e| e.to_string())?; } let (query_ids, queries) = queries.into_iter().unzip(); let gt_path = path_manager.ground_truth_path(); if !gt_path.exists() { // Serialize the ground truth to disk. - bincode::serialize_into( - std::fs::File::create(&gt_path).map_err(|e| e.to_string())?, - &ground_truth, - ) - .map_err(|e| e.to_string())?; + let bytes = bitcode::encode(&ground_truth).map_err(|e| e.to_string())?; + std::fs::write(&gt_path, &bytes).map_err(|e| e.to_string())?; } + let metric = Jaccard; + let ball_path = path_manager.ball_path(); let ball = if ball_path.exists() { - // Deserialize the ball from disk. - ftlog::info!("Deserializing ball from {ball_path:?}"); - bincode::deserialize_from(std::fs::File::open(&ball_path).map_err(|e| e.to_string())?) - .map_err(|e| e.to_string())? + B::read_from(&ball_path)? } else { // Create the ball from scratch.
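For context, the hunk below swaps the staged `BalancedBall` construction for a single call that takes the metric as an explicit argument, which is the core of this PR's separation of `Metric` from `Dataset`. A minimal before/after sketch, assuming the signatures shown in this patch (generic parameters elided; `Jaccard` and the `B` alias are the ones defined in this file):

```rust
// Before this patch: the dataset carried its metric internally, and the tree
// was built in stages (sketch of the removed code below).
let seed = Some(42);
let indices = (0..uncompressed.cardinality()).collect::<Vec<_>>();
let mut ball = BalancedBall::par_new(&uncompressed, &indices, 0, seed);
ball.par_partition_further(&uncompressed, &criteria, seed);

// After this patch: the metric is its own value, passed alongside the data,
// and the whole tree is built in one call.
let metric = Jaccard;
let criteria = |c: &B| c.cardinality() > 1; // split until leaves are singletons
let ball = Ball::par_new_tree(&uncompressed, &metric, &criteria, seed);
```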
- ftlog::info!("Creating ball with balanced partition."); - let mut max_depth = 0; + ftlog::info!("Creating Ball with default partition."); let seed = Some(42); - - let indices = (0..uncompressed.cardinality()).collect::>(); - let mut ball = BalancedBall::par_new(&uncompressed, &indices, 0, seed); - let depth_delta = ball.max_recursion_depth(); - - loop { - max_depth += depth_delta; - let criteria = - |c: &BalancedBall<_, _, _>| (c.depth() < max_depth && c.radius() > 0.75) || c.cardinality() > 1_000; - ball.par_partition_further(&uncompressed, &criteria, seed); - - // If there are no leaves at the current maximum depth, break - if !ball.leaves().into_iter().any(|c| c.depth() == max_depth) { - break; - } - } - - let num_leaves = ball.leaves().len(); - ftlog::info!("Balanced ball has {} leaves", num_leaves); - - ftlog::info!("Switching to default partition."); - let mut ball = Ball::par_from_balanced_ball(ball); - - for leaf in ball.leaves_mut() { - max_depth = leaf.depth(); - loop { - max_depth += depth_delta; - let criteria = |c: &Ball<_, _, _>| c.depth() < max_depth; - leaf.par_partition_further(&uncompressed, &criteria, seed); - - // If there are no leaves at the current maximum depth, break - if !leaf.leaves().into_iter().any(|c| c.depth() == max_depth) { - break; - } - } - } + let criteria = |c: &B| c.cardinality() > 1; + let ball = Ball::par_new_tree(&uncompressed, &metric, &criteria, seed); let num_leaves = ball.leaves().len(); ftlog::info!("Ball has {} leaves", num_leaves); // Serialize the ball to disk. ftlog::info!("Serializing ball to {ball_path:?}"); - bincode::serialize_into(std::fs::File::create(&ball_path).map_err(|e| e.to_string())?, &ball) - .map_err(|e| e.to_string())?; + ball.write_to(&ball_path)?; // Write the ball to a CSV file. ftlog::info!("Writing ball to CSV file."); @@ -135,45 +98,28 @@ impl Group { let squishy_ball_path = path_manager.squishy_ball_path(); let compressed_path = path_manager.compressed_path(); - let (squishy_ball, mut compressed) = if squishy_ball_path.exists() && compressed_path.exists() { + let (squishy_ball, compressed) = if squishy_ball_path.exists() && compressed_path.exists() { ftlog::info!("Deserializing squishy ball from {squishy_ball_path:?}"); - let squishy_ball = - bincode::deserialize_from(std::fs::File::open(&squishy_ball_path).map_err(|e| e.to_string())?) - .map_err(|e| e.to_string())?; + let squishy_ball = Sb::read_from(&squishy_ball_path)?; ftlog::info!("Deserializing compressed dataset from {compressed_path:?}"); - let codec_data = - bincode::deserialize_from(std::fs::File::open(&compressed_path).map_err(|e| e.to_string())?) - .map_err(|e| e.to_string())?; + let codec_data = Dec::par_read_from(&compressed_path)?; (squishy_ball, codec_data) } else { - ftlog::info!("Creating squishy ball and permuted dataset."); - let (mut squishy_ball, perm_data) = { - let (off_ball, data) = OffBall::par_from_ball_tree(ball.clone(), uncompressed.clone()); - let mut squishy_ball = SquishyBall::par_adapt_tree_iterative(off_ball, None); + ftlog::info!("Creating squishy ball and compressed dataset."); + let (squishy_ball, codec_data) = { + let (squishy_ball, data) = SquishyBall::par_from_ball_tree(ball.clone(), uncompressed.clone(), &metric); - // Set the costs of the squishy ball and write it to a CSV file. - ftlog::info!("Setting costs and writing pre-trim ball to CSV file."); - squishy_ball.par_set_costs(&data); - squishy_ball.write_to_csv(&path_manager.pre_trim_csv_path())?; + // Write it to a CSV file. 
+ squishy_ball.write_to_csv(&path_manager.squishy_ball_csv_path())?; - (squishy_ball, data) + (squishy_ball, data.with_metadata(uncompressed.metadata())?) }; - // Trim the squishy ball and write it to a CSV file. - ftlog::info!("Trimming squishy ball and writing to CSV file."); - squishy_ball.trim(4); - squishy_ball.write_to_csv(&path_manager.squishy_csv_path())?; - let num_leaves = squishy_ball.leaves().len(); ftlog::info!("Squishy ball has {num_leaves} leaves"); - // Create the compressed dataset and set its metadata. - ftlog::info!("Creating compressed dataset."); - let codec_data = CodecData::from_compressible(&perm_data, &squishy_ball) - .with_metadata(uncompressed.metadata().to_vec())?; - let num_bytes = codec_data .leaf_bytes() .iter() @@ -183,22 +129,13 @@ impl Group { // Serialize the squishy ball and the compressed dataset to disk. ftlog::info!("Serializing squishy ball to {squishy_ball_path:?}"); - bincode::serialize_into( - std::fs::File::create(&squishy_ball_path).map_err(|e| e.to_string())?, - &squishy_ball, - ) - .map_err(|e| e.to_string())?; + squishy_ball.write_to(&squishy_ball_path)?; ftlog::info!("Serializing compressed dataset to {compressed_path:?}"); - bincode::serialize_into( - std::fs::File::create(&compressed_path).map_err(|e| e.to_string())?, - &codec_data, - ) - .map_err(|e| e.to_string())?; + codec_data.par_write_to(&compressed_path)?; (squishy_ball, codec_data) }; - compressed.set_metric(I::metric()); Ok(Self { path_manager, @@ -212,75 +149,72 @@ impl Group { }) } + fn bench_search(&self, num_queries: usize, alg_a: &Aco, alg_b: &Adec) -> Result, String> + where + Aco: ParSearchAlgorithm, + Adec: ParSearchAlgorithm, + { + let metric = &Jaccard; + let name = alg_a.name(); + + let queries = &self.queries[..num_queries]; + ftlog::info!("Running benchmarks for compressive search on {num_queries} queries with {name}"); + + let uncompressed_start = std::time::Instant::now(); + let uncompressed_hits = alg_a.par_batch_search(&self.uncompressed, metric, &self.ball, queries); + let uncompressed_time = uncompressed_start.elapsed().as_secs_f32() / num_queries.as_f32(); + ftlog::info!( + "Algorithm {name} took {uncompressed_time:.3e} seconds per query uncompressed time on {}", + self.path_manager.name() + ); + + let compressed_start = std::time::Instant::now(); + let compressed_hits = alg_b.par_batch_search(&self.compressed, metric, &self.squishy_ball, queries); + let compressed_time = compressed_start.elapsed().as_secs_f32() / num_queries.as_f32(); + ftlog::info!( + "Algorithm {name} took {compressed_time:.3e} seconds per query compressed time on {}", + self.path_manager.name() + ); + + self.verify_hits(uncompressed_hits, compressed_hits)?; + + let slowdown = compressed_time / uncompressed_time; + Ok(vec![ + format!("uncompressed: {uncompressed_time:.4e}"), + format!("uncompressed_throughput: {:.4e}", 1.0 / uncompressed_time), + format!("compressed: {compressed_time:.4e}"), + format!("compressed_throughput: {:.4e}", 1.0 / compressed_time), + format!("slowdown: {slowdown:.4}"), + ]) + } + /// Run benchmarks for compressive search on the dataset. /// /// # Errors /// /// - If there is an error writing the times to disk. 
pub fn bench_compressive_search(&self, num_queries: usize) -> Result<(), String> { - let radius = 0.02; - let k = 10; - let algorithms = [ - // Algorithm::RnnLinear(radius), - Algorithm::RnnClustered(radius), - // Algorithm::KnnLinear(k), - // Algorithm::KnnRepeatedRnn(k, 2_f32), - Algorithm::KnnBreadthFirst(k), - Algorithm::KnnDepthFirst(k), - ]; - - let num_queries = num_queries.min(self.queries.len()); - let queries = &self.queries[..num_queries]; - ftlog::info!( - "Running benchmarks for compressive search on {num_queries} queries with {} algorithms", - algorithms.len() - ); + let num_queries = Ord::min(num_queries, self.queries.len()); + ftlog::info!("Running benchmarks for compressive search on {num_queries} queries."); let mut times = HashMap::new(); - for (i, alg) in algorithms.iter().enumerate() { - ftlog::info!( - "Running algorithm {} ({}/{}) on {}", - alg.name(), - i + 1, - algorithms.len(), - self.path_manager.name() - ); - - let uncompressed_start = std::time::Instant::now(); - let uncompressed_hits = alg.par_batch_par_search(&self.uncompressed, &self.ball, queries); - let uncompressed_time = uncompressed_start.elapsed().as_secs_f32() / num_queries.as_f32(); - ftlog::info!( - "Algorithm {} took {:.3e} seconds per query uncompressed time on {}", - alg.name(), - uncompressed_time, - self.path_manager.name() - ); - - let compressed_start = std::time::Instant::now(); - let compressed_hits = alg.par_batch_par_search(&self.compressed, &self.squishy_ball, queries); - let compressed_time = compressed_start.elapsed().as_secs_f32() / num_queries.as_f32(); - ftlog::info!( - "Algorithm {} took {:.3e} seconds per query compressed time on {}", - alg.name(), - compressed_time, - self.path_manager.name() - ); - - self.verify_hits(uncompressed_hits, compressed_hits)?; - - let slowdown = compressed_time / uncompressed_time; - times.insert( - alg.name(), - ( - format!("uncompressed: {uncompressed_time:.4e}"), - format!("uncompressed_throughput: {:.4e}", 1.0 / uncompressed_time), - format!("compressed: {compressed_time:.4e}"), - format!("compressed_throughput: {:.4e}", 1.0 / compressed_time), - format!("slowdown: {slowdown:.4}"), - ), - ); + for radius in [0.01, 0.02, 0.1] { + let times_inner = self.bench_search(num_queries, &RnnClustered(radius), &RnnClustered(radius))?; + times.insert(format!("RnnClustered({radius})"), times_inner); + } + + for k in [1, 10, 100] { + let times_inner = self.bench_search(num_queries, &KnnRepeatedRnn(k, 2.0), &KnnRepeatedRnn(k, 2.0))?; + times.insert(format!("KnnRepeatedRnn({k}, 2)"), times_inner); + + let times_inner = self.bench_search(num_queries, &KnnBreadthFirst(k), &KnnBreadthFirst(k))?; + times.insert(format!("KnnBreadthFirst({k})"), times_inner); + + let times_inner = self.bench_search(num_queries, &KnnDepthFirst(k), &KnnDepthFirst(k))?; + times.insert(format!("KnnDepthFirst({k})"), times_inner); } + ftlog::info!("Writing times to disk."); serde_json::to_writer_pretty( std::fs::File::create(self.path_manager.times_path()).map_err(|e| e.to_string())?, &times, diff --git a/crates/results/cakes/src/data/tree/instances/aligned_sequence.rs b/crates/results/cakes/src/data/tree/instances/aligned_sequence.rs index 3cf4af76d..b562df6fc 100644 --- a/crates/results/cakes/src/data/tree/instances/aligned_sequence.rs +++ b/crates/results/cakes/src/data/tree/instances/aligned_sequence.rs @@ -1,92 +1,52 @@ //! Aligned sequence with Hamming distance and substitutions for Edits.
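For orientation before the trait impls below: `Aligned` compresses a sequence by storing only its substitutions against a reference, three bytes per edit, a big-endian `u16` position followed by the replacement character. A standalone, byte-oriented sketch of that wire format (the actual impls below work on `char`s through the `Encodable`/`Decodable` traits, and, as there, positions are assumed to fit in a `u16`):

```rust
/// Pack the positions where `seq` differs from `reference` as
/// (u16 big-endian index, replacement byte) triples.
fn encode_substitutions(seq: &str, reference: &str) -> Vec<u8> {
    seq.bytes()
        .zip(reference.bytes())
        .enumerate()
        .filter(|(_, (c, r))| c != r)
        .flat_map(|(i, (c, _))| {
            let [hi, lo] = (i as u16).to_be_bytes();
            [hi, lo, c]
        })
        .collect()
}

/// Rebuild a sequence by applying packed substitutions to a copy of the reference.
fn decode_substitutions(reference: &str, bytes: &[u8]) -> String {
    let mut seq = reference.as_bytes().to_vec();
    for chunk in bytes.chunks_exact(3) {
        let i = u16::from_be_bytes([chunk[0], chunk[1]]) as usize;
        seq[i] = chunk[2];
    }
    String::from_utf8(seq).unwrap_or_else(|e| unreachable!("ASCII in, ASCII out: {e:?}"))
}
```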
-use abd_clam::cakes::{Decodable, Encodable}; -use distances::number::UInt; +use abd_clam::{ + metric::ParMetric, + msa::Sequence, + pancakes::{Decodable, Encodable}, + Metric, +}; +use distances::{number::UInt, Number}; use serde::{Deserialize, Serialize}; /// A sequence from a FASTA file. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Aligned { - /// The aligned sequence. - seq: String, - /// To keep the type parameter. - _phantom: std::marker::PhantomData, -} - -impl Default for Aligned { - fn default() -> Self { - Self { - seq: String::default(), - _phantom: std::marker::PhantomData, - } - } -} - -impl Aligned { - /// Returns the Hamming metric for `Aligned` sequences. - #[must_use] - pub fn metric() -> abd_clam::Metric { - let distance_function = |first: &Self, second: &Self| { - U::from( - first - .seq - .chars() - .zip(second.seq.chars()) - .filter(|(a, b)| a != b) - .count(), - ) - }; - abd_clam::Metric::new(distance_function, false) - } - - /// Returns the Levenshtein metric for unaligned version of the `Aligned` sequences. - #[must_use] - pub fn levenshtein_metric() -> abd_clam::Metric { - let distance_function = |first: &Self, second: &Self| { - U::from(stringzilla::sz::edit_distance( - first.as_unaligned(), - second.as_unaligned(), - )) - }; - abd_clam::Metric::new(distance_function, false) - } +#[derive(Debug, Clone, Serialize, Deserialize, Default, bitcode::Encode, bitcode::Decode)] +pub struct Aligned(String); - pub fn as_unaligned(&self) -> String { - self.seq.chars().filter(|&c| c != '-' && c != '.').collect() - } -} - -impl AsRef for Aligned { +impl AsRef for Aligned { fn as_ref(&self) -> &str { - &self.seq + &self.0 } } -impl From<&str> for Aligned { +impl From<&str> for Aligned { fn from(seq: &str) -> Self { Self::from(seq.to_string()) } } -impl From for Aligned { +impl From for Aligned { fn from(seq: String) -> Self { - Self { - seq, - _phantom: std::marker::PhantomData, - } + Self(seq) } } -impl Encodable for Aligned { +impl<'a, T: Number> From> for Aligned { + fn from(seq: Sequence<'a, T>) -> Self { + Self(seq.seq().to_string()) + } +} + +impl Encodable for Aligned { fn as_bytes(&self) -> Box<[u8]> { - self.seq.as_bytes().into() + self.0.as_bytes().into() } #[allow(clippy::cast_possible_truncation)] fn encode(&self, reference: &Self) -> Box<[u8]> { - self.seq + self.0 .chars() - .zip(reference.seq.chars()) + .zip(reference.0.chars()) .enumerate() .filter_map(|(i, (c, r))| if c == r { None } else { Some((i as u16, c)) }) .flat_map(|(i, c)| { @@ -101,7 +61,7 @@ impl Encodable for Aligned { } /// This uses the Needleman-Wunsch algorithm to decode strings. -impl Decodable for Aligned { +impl Decodable for Aligned { fn from_bytes(bytes: &[u8]) -> Self { Self::from( String::from_utf8(bytes.to_vec()).unwrap_or_else(|e| unreachable!("Could not cast back to string: {e:?}")), @@ -109,7 +69,7 @@ impl Decodable for Aligned { } fn decode(reference: &Self, bytes: &[u8]) -> Self { - let mut sequence = reference.seq.chars().collect::>(); + let mut sequence = reference.0.chars().collect::>(); for chunk in bytes.chunks_exact(3) { let i = u16::from_be_bytes([chunk[0], chunk[1]]) as usize; @@ -120,3 +80,38 @@ impl Decodable for Aligned { Self::from(sequence.into_iter().collect::()) } } + +/// The `Hamming` distance metric. 
+pub struct Hamming; + +impl, U: UInt> Metric for Hamming { + fn distance(&self, a: &I, b: &I) -> U { + distances::vectors::hamming(a.as_ref().as_ref(), b.as_ref().as_ref()) + } + + fn name(&self) -> &str { + "hamming" + } + + fn has_identity(&self) -> bool { + true + } + + fn has_non_negativity(&self) -> bool { + true + } + + fn has_symmetry(&self) -> bool { + true + } + + fn obeys_triangle_inequality(&self) -> bool { + true + } + + fn is_expensive(&self) -> bool { + false + } +} + +impl + Send + Sync, U: UInt> ParMetric for Hamming {} diff --git a/crates/results/cakes/src/data/tree/instances/member_set.rs b/crates/results/cakes/src/data/tree/instances/member_set.rs index ee66f863f..97f2d60d9 100644 --- a/crates/results/cakes/src/data/tree/instances/member_set.rs +++ b/crates/results/cakes/src/data/tree/instances/member_set.rs @@ -4,84 +4,49 @@ use core::hash::Hash; use std::collections::HashSet; -use abd_clam::cakes::{Decodable, Encodable}; +use abd_clam::{ + metric::ParMetric, + pancakes::{Decodable, Encodable}, + Metric, +}; use distances::number::{Float, UInt}; use serde::{Deserialize, Serialize}; /// A set of named members. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct MemberSet { - /// The members of the set. - inner: HashSet, - /// To keep the type parameter. - _phantom: std::marker::PhantomData, -} - -impl Default for MemberSet { - fn default() -> Self { - Self { - inner: HashSet::default(), - _phantom: std::marker::PhantomData, - } - } -} - -impl MemberSet { - /// Returns the Jaccard distance metric for `MemberSet`s. - #[must_use] - pub fn metric() -> abd_clam::Metric { - let distance_function = |first: &Self, second: &Self| { - let intersection = first.inner.intersection(&second.inner).count(); - let union = first.inner.len() + second.inner.len() - intersection; - let sim = if union == 0 { - F::ZERO - } else { - F::from(intersection) / F::from(union) - }; - F::ONE - sim - }; - - abd_clam::Metric::new(distance_function, false) - } -} +#[derive(Debug, Clone, Serialize, Deserialize, Default, bitcode::Encode, bitcode::Decode)] +pub struct MemberSet(HashSet); -impl From<&[T]> for MemberSet { +impl From<&[T]> for MemberSet { fn from(items: &[T]) -> Self { - Self { - inner: items.iter().copied().collect(), - _phantom: std::marker::PhantomData, - } + Self(items.iter().copied().collect()) } } -impl From<&Vec> for MemberSet { +impl From<&Vec> for MemberSet { fn from(items: &Vec) -> Self { Self::from(items.as_slice()) } } -impl From<&HashSet> for MemberSet { +impl From<&HashSet> for MemberSet { fn from(items: &HashSet) -> Self { - Self { - inner: items.clone(), - _phantom: std::marker::PhantomData, - } + Self(items.clone()) } } -impl AsRef> for MemberSet { +impl AsRef> for MemberSet { fn as_ref(&self) -> &HashSet { - &self.inner + &self.0 } } -impl From<&MemberSet> for Vec { - fn from(set: &MemberSet) -> Self { - set.inner.iter().copied().collect() +impl From<&MemberSet> for Vec { + fn from(set: &MemberSet) -> Self { + set.0.iter().copied().collect() } } -impl Encodable for MemberSet { +impl Encodable for MemberSet { fn as_bytes(&self) -> Box<[u8]> { Vec::from(self) .into_iter() @@ -94,11 +59,11 @@ impl Encodable for MemberSet { fn encode(&self, reference: &Self) -> Box<[u8]> { let mut bytes = vec![]; - let new_items = self.inner.difference(&reference.inner).copied().collect::>(); + let new_items = self.0.difference(&reference.0).copied().collect::>(); bytes.extend_from_slice(&new_items.len().to_le_bytes()); bytes.extend(new_items.into_iter().flat_map(T::to_le_bytes)); - let removed_items = reference.inner.difference(&self.inner).copied().collect::>(); + let removed_items = reference.0.difference(&self.0).copied().collect::>(); bytes.extend_from_slice(&removed_items.len().to_le_bytes()); bytes.extend(removed_items.into_iter().flat_map(T::to_le_bytes)); @@ -106,21 +71,14 @@ impl Encodable for MemberSet { } } -impl Decodable for MemberSet { +impl Decodable for MemberSet { fn from_bytes(bytes: &[u8]) -> Self { - let items = bytes - .chunks_exact(T::NUM_BYTES) - .map(T::from_le_bytes) - .collect::>(); - - Self { - inner: items, - _phantom: std::marker::PhantomData, - } + let items = bytes.chunks_exact(T::NUM_BYTES).map(T::from_le_bytes).collect(); + Self(items) } fn decode(reference: &Self, bytes: &[u8]) -> Self { - let mut inner = reference.inner.clone(); + let mut inner = reference.0.clone(); let mut offset = 0; @@ -134,9 +92,48 @@ impl Decodable for MemberSet { inner.remove(&abd_clam::utils::read_number(bytes, &mut offset)); } - Self { - inner, - _phantom: std::marker::PhantomData, + Self(inner) + } +} + +/// The `Jaccard` distance metric. +pub struct Jaccard; + +impl Metric, U> for Jaccard { + fn distance(&self, a: &MemberSet, b: &MemberSet) -> U { + let intersection = a.0.intersection(&b.0).count(); + let union = a.0.len() + b.0.len() - intersection; + if union == 0 { + U::ZERO + } else { + let sim = U::from(intersection) / U::from(union); + U::ONE - sim } } + + fn name(&self) -> &str { + "jaccard" + } + + fn has_identity(&self) -> bool { + true + } + + fn has_non_negativity(&self) -> bool { + true + } + + fn has_symmetry(&self) -> bool { + true + } + + fn obeys_triangle_inequality(&self) -> bool { + true + } + + fn is_expensive(&self) -> bool { + false + } } + +impl ParMetric, U> for Jaccard {} diff --git a/crates/results/cakes/src/data/tree/instances/mod.rs b/crates/results/cakes/src/data/tree/instances/mod.rs index e2b12b009..37e741513 100644 --- a/crates/results/cakes/src/data/tree/instances/mod.rs +++ b/crates/results/cakes/src/data/tree/instances/mod.rs @@ -4,6 +4,6 @@ mod aligned_sequence; mod member_set; mod unaligned_sequence; -pub use aligned_sequence::Aligned; -pub use member_set::MemberSet; +pub use aligned_sequence::{Aligned, Hamming}; +pub use member_set::{Jaccard, MemberSet}; pub use unaligned_sequence::Unaligned; diff --git a/crates/results/cakes/src/data/tree/instances/unaligned_sequence.rs b/crates/results/cakes/src/data/tree/instances/unaligned_sequence.rs index 939c3dd98..e1aded512 100644 --- a/crates/results/cakes/src/data/tree/instances/unaligned_sequence.rs +++ b/crates/results/cakes/src/data/tree/instances/unaligned_sequence.rs @@ -1,73 +1,62 @@ //! Unaligned sequence with Levenshtein distance and Needleman-Wunsch Edits. -use abd_clam::cakes::{Decodable, Encodable}; -use distances::number::UInt; +use abd_clam::{ + msa::Sequence, + pancakes::{Decodable, Encodable}, +}; +use distances::Number; use serde::{Deserialize, Serialize}; /// A sequence from a FASTA file. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Unaligned { - seq: String, - _phantom: std::marker::PhantomData, -} +#[derive(Debug, Clone, Serialize, Deserialize, Default, bitcode::Encode, bitcode::Decode)] +pub struct Unaligned(String); -impl Default for Unaligned { - fn default() -> Self { - Self { - seq: String::default(), - _phantom: std::marker::PhantomData, - } +impl AsRef for Unaligned { + fn as_ref(&self) -> &str { + &self.0 } } -impl Unaligned { - /// Returns the Levenshtein metric for `Unaligned` sequences.
- #[must_use] - pub fn metric() -> abd_clam::Metric { - let distance_function = - |first: &Self, second: &Self| U::from(stringzilla::sz::edit_distance(first.as_ref(), second.as_ref())); - abd_clam::Metric::new(distance_function, true) +impl AsRef<[u8]> for Unaligned { + fn as_ref(&self) -> &[u8] { + self.0.as_ref() } } -impl AsRef for Unaligned { - fn as_ref(&self) -> &str { - &self.seq - } -} - -impl From<&str> for Unaligned { +impl From<&str> for Unaligned { fn from(seq: &str) -> Self { - let seq = seq.chars().filter(|&c| c != '-' && c != '.').collect(); - Self { - seq, - _phantom: std::marker::PhantomData, - } + Self(seq.chars().filter(|&c| c != '-' && c != '.').collect()) } } -impl From for Unaligned { +impl From for Unaligned { fn from(seq: String) -> Self { Self::from(seq.as_str()) } } -impl From> for String { - fn from(seq: Unaligned) -> Self { - seq.seq +impl<'a, T: Number> From> for Unaligned { + fn from(seq: Sequence<'a, T>) -> Self { + Self::from(seq.as_ref()) + } +} + +impl From for String { + fn from(seq: Unaligned) -> Self { + seq.0 } } -impl Encodable for Unaligned { +impl Encodable for Unaligned { fn as_bytes(&self) -> Box<[u8]> { - self.seq.as_bytes().into() + self.0.as_bytes().into() } fn encode(&self, reference: &Self) -> Box<[u8]> { let (x, y) = (self.as_ref(), reference.as_ref()); let penalties = distances::strings::Penalties::default(); - let table = distances::strings::needleman_wunsch::compute_table::(x, y, penalties); + let table = distances::strings::needleman_wunsch::compute_table::(x, y, penalties); #[allow(clippy::tuple_array_conversions)] let (x, y) = distances::strings::needleman_wunsch::trace_back_recursive(&table, [x, y]); @@ -77,7 +66,7 @@ impl Encodable for Unaligned { } /// This uses the Needleman-Wunsch algorithm to decode strings. -impl Decodable for Unaligned { +impl Decodable for Unaligned { fn from_bytes(bytes: &[u8]) -> Self { let seq = String::from_utf8(bytes.to_vec()).unwrap_or_else(|e| unreachable!("Could not cast back to string: {e:?}")); diff --git a/crates/results/cakes/src/data/tree/mod.rs b/crates/results/cakes/src/data/tree/mod.rs index 3955b410d..c56b95860 100644 --- a/crates/results/cakes/src/data/tree/mod.rs +++ b/crates/results/cakes/src/data/tree/mod.rs @@ -4,6 +4,7 @@ use std::path::{Path, PathBuf}; use abd_clam::FlatVec; use bio::io::fastq::Result; +use instances::{Aligned, MemberSet, Unaligned}; pub mod instances; @@ -28,8 +29,8 @@ impl Tree { pub fn new_unaligned( name: &str, out_dir: &Path, - data: FlatVec, u32, String>, - queries: Vec<(String, instances::Unaligned)>, + data: FlatVec, + queries: Vec<(String, Unaligned)>, ) -> Result { let path_manager = PathManager { name: name.to_string(), @@ -46,8 +47,8 @@ impl Tree { pub fn new_aligned( name: &str, out_dir: &Path, - data: FlatVec, u32, String>, - queries: Vec<(String, instances::Aligned)>, + data: FlatVec, + queries: Vec<(String, Aligned)>, ) -> Result { let path_manager = PathManager { name: name.to_string(), @@ -64,8 +65,8 @@ impl Tree { pub fn new_ann_set( name: &str, out_dir: &Path, - data: FlatVec, f32, usize>, - queries: Vec<(usize, instances::MemberSet)>, + data: FlatVec, usize>, + queries: Vec<(usize, MemberSet)>, ground_truth: Vec>, ) -> Result { let path_manager = PathManager { @@ -99,35 +100,51 @@ pub struct PathManager { } impl PathManager { + /// Creates a new `PathManager`. 
+ #[allow(dead_code)] + #[must_use] + pub fn new<P: AsRef<Path>>(name: &str, out_dir: P) -> Self { + Self { + name: name.to_string(), + out_dir: out_dir.as_ref().to_path_buf(), + } + } + /// The name of the dataset. + #[must_use] pub fn name(&self) -> &str { &self.name } /// The directory where the dataset is stored. + #[must_use] pub fn out_dir(&self) -> &Path { &self.out_dir } /// The path to the binary file containing the `Ball` tree. + #[must_use] pub fn ball_path(&self) -> PathBuf { let name = format!("{}.ball", self.name()); self.out_dir().join(name) } /// The path to the binary file containing the `SquishyBall` tree. + #[must_use] pub fn squishy_ball_path(&self) -> PathBuf { let name = format!("{}.squishy_ball", self.name()); self.out_dir().join(name) } /// The path to the binary file containing the compressed data. + #[must_use] pub fn compressed_path(&self) -> PathBuf { let name = format!("{}.compressed", self.name()); self.out_dir().join(name) } /// The path to the binary file containing the queries. + #[must_use] pub fn queries_path(&self) -> PathBuf { let name = format!("{}.queries", self.name()); self.out_dir().join(name) @@ -136,32 +153,54 @@ impl PathManager { /// The path to the binary file containing the ground truth. /// /// This is only relevant for ANN sets. + #[must_use] pub fn ground_truth_path(&self) -> PathBuf { let name = format!("{}.ground_truth", self.name()); self.out_dir().join(name) } /// The path to the CSV file containing the `Ball` tree. + #[must_use] pub fn ball_csv_path(&self) -> PathBuf { let name = format!("{}_ball.csv", self.name()); self.out_dir().join(name) } /// The path to the CSV file containing the `SquishyBall` tree. - pub fn pre_trim_csv_path(&self) -> PathBuf { + #[must_use] + pub fn squishy_ball_csv_path(&self) -> PathBuf { let name = format!("{}_pre_trim.csv", self.name()); self.out_dir().join(name) } - /// The path to the CSV file containing the `SquishyBall` tree after it is trimmed. - pub fn squishy_csv_path(&self) -> PathBuf { - let name = format!("{}_squishy.csv", self.name()); - self.out_dir().join(name) - } - /// Path to the JSON file containing the times taken to search the dataset. + #[must_use] pub fn times_path(&self) -> PathBuf { let name = format!("{}_times.json", self.name()); self.out_dir().join(name) } + + /// Path to the file containing the tree used for making an MSA. + #[allow(dead_code)] + #[must_use] + pub fn msa_ball_path(&self) -> PathBuf { + let name = format!("{}_ball.msa", self.name()); + self.out_dir().join(name) + } + + /// Path to the file containing the MSA of the dataset. + #[allow(dead_code)] + #[must_use] + pub fn msa_data_path(&self) -> PathBuf { + let name = format!("{}.msa", self.name()); + self.out_dir().join(name) + } + + /// Path to the file containing the MSA of the dataset in FASTA format.
+ #[allow(dead_code)] + #[must_use] + pub fn msa_fasta_path(&self) -> PathBuf { + let name = format!("{}_msa.fasta", self.name()); + self.out_dir().join(name) + } } diff --git a/crates/results/cakes/src/data/tree/unaligned.rs b/crates/results/cakes/src/data/tree/unaligned.rs index c1a206188..7e51cdb03 100644 --- a/crates/results/cakes/src/data/tree/unaligned.rs +++ b/crates/results/cakes/src/data/tree/unaligned.rs @@ -3,23 +3,24 @@ use std::collections::HashMap; use abd_clam::{ - adapter::{ParAdapter, ParBallAdapter}, - cakes::{Algorithm, CodecData, Decompressible, OffBall, SquishyBall}, - cluster::WriteCsv, - partition::ParPartition, - Ball, Cluster, Dataset, FlatVec, MetricSpace, + cakes::{KnnBreadthFirst, KnnDepthFirst, KnnRepeatedRnn, ParSearchAlgorithm, RnnClustered}, + cluster::{adapter::ParBallAdapter, ClusterIO, Csv, ParPartition}, + dataset::{AssociatesMetadata, AssociatesMetadataMut, ParDatasetIO}, + metric::Levenshtein, + pancakes::{CodecData, SquishyBall}, + Ball, Cluster, Dataset, FlatVec, }; use distances::Number; use super::{instances::Unaligned, PathManager}; -type I = Unaligned; +type I = Unaligned; type U = u32; type M = String; -type Co = FlatVec; -type B = Ball; -type Dec = CodecData; -type Sb = SquishyBall; +type Co = FlatVec; +type B = Ball; +type Dec = CodecData; +type Sb = SquishyBall; type Hits = Vec>; /// The group of types used for the datasets of unaligned sequences. @@ -43,20 +44,19 @@ impl Group { /// - If there is an error reading/writing serialized data to/from disk. /// - If there is an error writing the trees to csv files. pub fn new(path_manager: PathManager, uncompressed: Co, queries: Vec<(M, I)>) -> Result { + let metric = Levenshtein; let query_path = path_manager.queries_path(); if !query_path.exists() { // Serialize the queries to disk. - bincode::serialize_into(std::fs::File::create(&query_path).map_err(|e| e.to_string())?, &queries) - .map_err(|e| e.to_string())?; + let bytes = bitcode::encode(&queries).map_err(|e| e.to_string())?; + std::fs::write(&query_path, &bytes).map_err(|e| e.to_string())?; } let (query_ids, queries) = queries.into_iter().unzip(); let ball_path = path_manager.ball_path(); let ball = if ball_path.exists() { ftlog::info!("Reading ball from {ball_path:?}"); - // Deserialize the ball from disk. - bincode::deserialize_from(std::fs::File::open(&ball_path).map_err(|e| e.to_string())?) - .map_err(|e| e.to_string())? + B::read_from(&ball_path)? } else { // Create the ball from scratch. 
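Unlike the ANN-set file above, this file keeps a staged construction: partition one level, then deepen the tree a slab of levels at a time until no leaf can split further. A condensed sketch of that loop, assuming the signatures shown in this patch (the `B` alias is the one from the top of this file):

```rust
// Staged construction (mirrors the hunk below): deepen by `depth_delta`
// levels per pass until every leaf is a singleton.
let seed = Some(42);
let mut depth = 0;
let indices = (0..uncompressed.cardinality()).collect::<Vec<_>>();
let mut ball = Ball::par_new(&uncompressed, &metric, &indices, 0, seed)
    .unwrap_or_else(|e| unreachable!("indices is non-empty: {e}"));
let depth_delta = abd_clam::utils::max_recursion_depth();

ball.par_partition(&uncompressed, &metric, &|c: &B| c.depth() < 1, seed);
while ball.leaves().into_iter().any(|c| !c.is_singleton()) {
    depth += depth_delta;
    let criteria = |c: &B| c.depth() < depth;
    ball.par_partition_further(&uncompressed, &metric, &criteria, seed);
}
```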
ftlog::info!("Building ball."); @@ -64,16 +64,17 @@ impl Group { let seed = Some(42); let indices = (0..uncompressed.cardinality()).collect::>(); - let mut ball = Ball::par_new(&uncompressed, &indices, 0, seed); - let depth_delta = ball.max_recursion_depth(); + let mut ball = Ball::par_new(&uncompressed, &metric, &indices, 0, seed) + .unwrap_or_else(|e| unreachable!("We ensured that indices is non-empty: {e}")); + let depth_delta = abd_clam::utils::max_recursion_depth(); - let criteria = |c: &Ball<_, _, _>| c.depth() < 1; - ball.par_partition(&uncompressed, &criteria, seed); + let criteria = |c: &B| c.depth() < 1; + ball.par_partition(&uncompressed, &metric, &criteria, seed); while ball.leaves().into_iter().any(|c| !c.is_singleton()) { depth += depth_delta; - let criteria = |c: &Ball<_, _, _>| c.depth() < depth; - ball.par_partition_further(&uncompressed, &criteria, seed); + let criteria = |c: &B| c.depth() < depth; + ball.par_partition_further(&uncompressed, &metric, &criteria, seed); } let num_leaves = ball.leaves().len(); @@ -81,8 +82,7 @@ impl Group { // Serialize the ball to disk. ftlog::info!("Writing ball to {ball_path:?}"); - bincode::serialize_into(std::fs::File::create(&ball_path).map_err(|e| e.to_string())?, &ball) - .map_err(|e| e.to_string())?; + ball.write_to(&ball_path)?; // Write the ball to a CSV file. let csv_path = path_manager.ball_csv_path(); @@ -95,45 +95,28 @@ impl Group { let squishy_ball_path = path_manager.squishy_ball_path(); let compressed_path = path_manager.compressed_path(); - let (squishy_ball, mut compressed) = if squishy_ball_path.exists() && compressed_path.exists() { + let (squishy_ball, compressed) = if squishy_ball_path.exists() && compressed_path.exists() { ftlog::info!("Reading squishy ball from {squishy_ball_path:?}"); - let squishy_ball = - bincode::deserialize_from(std::fs::File::open(&squishy_ball_path).map_err(|e| e.to_string())?) - .map_err(|e| e.to_string())?; + let squishy_ball = Sb::read_from(&squishy_ball_path)?; ftlog::info!("Reading compressed data from {compressed_path:?}"); - let codec_data = - bincode::deserialize_from(std::fs::File::open(&compressed_path).map_err(|e| e.to_string())?) - .map_err(|e| e.to_string())?; + let codec_data = Dec::par_read_from(&compressed_path)?; (squishy_ball, codec_data) } else { - ftlog::info!("Building squishy ball and permuted data."); - let (mut squishy_ball, perm_data) = { - let (off_ball, data) = OffBall::par_from_ball_tree(ball.clone(), uncompressed.clone()); - let mut squishy_ball = SquishyBall::par_adapt_tree_iterative(off_ball, None); + ftlog::info!("Building squishy ball and compressed data."); + let (squishy_ball, codec_data) = { + let (squishy_ball, data) = SquishyBall::par_from_ball_tree(ball.clone(), uncompressed.clone(), &metric); - // Set the costs of the squishy ball and write it to a CSV file. - ftlog::info!("Setting costs and writing pre-trim CSV."); - squishy_ball.par_set_costs(&data); - squishy_ball.write_to_csv(&path_manager.pre_trim_csv_path())?; + // Write it to a CSV file. + squishy_ball.write_to_csv(&path_manager.squishy_ball_csv_path())?; - (squishy_ball, data) + (squishy_ball, data.with_metadata(uncompressed.metadata())?) }; - // Trim the squishy ball and write it to a CSV file. 
- ftlog::info!("Trimming squishy ball and writing post-trim CSV."); - squishy_ball.trim(4); - squishy_ball.write_to_csv(&path_manager.squishy_csv_path())?; - let num_leaves = squishy_ball.leaves().len(); ftlog::info!("Built squishy ball with {num_leaves} leaves."); - // Create the compressed dataset and set its metadata. - ftlog::info!("Building compressed dataset."); - let codec_data = CodecData::from_compressible(&perm_data, &squishy_ball) - .with_metadata(uncompressed.metadata().to_vec())?; - let num_bytes = codec_data .leaf_bytes() .iter() @@ -141,27 +124,15 @@ impl Group { .sum::(); ftlog::info!("Built compressed dataset with {num_bytes} leaf bytes."); - // Change the metadata type of the squishy ball to match the compressed dataset. - let squishy_ball = squishy_ball.with_metadata_type::(); - // Serialize the squishy ball and the compressed dataset to disk. ftlog::info!("Writing squishy ball to {squishy_ball_path:?}"); - bincode::serialize_into( - std::fs::File::create(&squishy_ball_path).map_err(|e| e.to_string())?, - &squishy_ball, - ) - .map_err(|e| e.to_string())?; + squishy_ball.write_to(&squishy_ball_path)?; ftlog::info!("Writing compressed data to {compressed_path:?}"); - bincode::serialize_into( - std::fs::File::create(&compressed_path).map_err(|e| e.to_string())?, - &codec_data, - ) - .map_err(|e| e.to_string())?; + codec_data.par_write_to(&compressed_path)?; (squishy_ball, codec_data) }; - compressed.set_metric(I::metric()); Ok(Self { path_manager, @@ -174,81 +145,77 @@ impl Group { }) } + fn bench_search(&self, num_queries: usize, alg_a: &Aco, alg_b: &Adec) -> Result, String> + where + Aco: ParSearchAlgorithm, + Adec: ParSearchAlgorithm, + { + let metric = &Levenshtein; + let name = alg_a.name(); + + let queries = &self.queries[..num_queries]; + ftlog::info!("Running benchmarks for compressive search on {num_queries} queries with {name}"); + + let uncompressed_start = std::time::Instant::now(); + let uncompressed_hits = alg_a.par_batch_search(&self.uncompressed, metric, &self.ball, queries); + let uncompressed_time = uncompressed_start.elapsed().as_secs_f32() / num_queries.as_f32(); + ftlog::info!( + "Algorithm {name} took {uncompressed_time:.3e} seconds per query uncompressed time on {}", + self.path_manager.name() + ); + + let compressed_start = std::time::Instant::now(); + let compressed_hits = alg_b.par_batch_search(&self.compressed, metric, &self.squishy_ball, queries); + let compressed_time = compressed_start.elapsed().as_secs_f32() / num_queries.as_f32(); + ftlog::info!( + "Algorithm {name} took {compressed_time:.3e} seconds per query compressed time on {}", + self.path_manager.name() + ); + + self.verify_hits(uncompressed_hits, compressed_hits)?; + + let slowdown = compressed_time / uncompressed_time; + Ok(vec![ + format!("uncompressed: {uncompressed_time:.4e}"), + format!("uncompressed_throughput: {:.4e}", 1.0 / uncompressed_time), + format!("compressed: {compressed_time:.4e}"), + format!("compressed_throughput: {:.4e}", 1.0 / compressed_time), + format!("slowdown: {slowdown:.4}"), + ]) + } + /// Run benchmarks for compressive search on the dataset. /// /// # Errors /// /// - If there is an error writing the times to disk. 
pub fn bench_compressive_search(&self, num_queries: usize) -> Result<(), String> { - let radius = 5; - let k = 10; - let algorithms = [ - Algorithm::RnnLinear(radius), - Algorithm::RnnClustered(radius), - Algorithm::KnnLinear(k), - Algorithm::KnnRepeatedRnn(k, 2), - Algorithm::KnnBreadthFirst(k), - Algorithm::KnnDepthFirst(k), - ]; - - let num_queries = num_queries.min(self.queries.len()); - let queries = &self.queries[..num_queries]; - ftlog::info!( - "Running benchmarks for compressive search on {num_queries} queries with {} algorithms", - algorithms.len() - ); + let num_queries = Ord::min(num_queries, self.queries.len()); + ftlog::info!("Running benchmarks for compressive search on {num_queries} queries."); let mut times = HashMap::new(); - for (i, alg) in algorithms.iter().enumerate() { - ftlog::info!( - "Running algorithm {} ({}/{}) on {}", - alg.name(), - i + 1, - algorithms.len(), - self.path_manager.name() - ); - - let uncompressed_start = std::time::Instant::now(); - let uncompressed_hits = alg.par_batch_par_search(&self.uncompressed, &self.ball, queries); - let uncompressed_time = uncompressed_start.elapsed().as_secs_f32() / num_queries.as_f32(); - ftlog::info!( - "Algorithm {} took {:.3e} seconds per query uncompressed time on {}", - alg.name(), - uncompressed_time, - self.path_manager.name() - ); - - let compressed_start = std::time::Instant::now(); - let compressed_hits = alg.par_batch_par_search(&self.compressed, &self.squishy_ball, queries); - let compressed_time = compressed_start.elapsed().as_secs_f32() / num_queries.as_f32(); - ftlog::info!( - "Algorithm {} took {:.3e} seconds per query compressed time on {}", - alg.name(), - compressed_time, - self.path_manager.name() - ); - - self.verify_hits(uncompressed_hits, compressed_hits)?; - - let slowdown = compressed_time / uncompressed_time; - times.insert( - alg.name(), - ( - format!("uncompressed: {uncompressed_time:.4e}"), - format!("uncompressed_throughput: {:.4e}", 1.0 / uncompressed_time), - format!("compressed: {compressed_time:.4e}"), - format!("compressed_throughput: {:.4e}", 1.0 / compressed_time), - format!("slowdown: {slowdown:.4}"), - ), - ); + for radius in [5, 10, 100] { + let times_inner = self.bench_search(num_queries, &RnnClustered(radius), &RnnClustered(radius))?; + times.insert(format!("RnnClustered({radius})"), times_inner); } - let times_path = self.path_manager.times_path(); - if times_path.exists() { - std::fs::remove_file(&times_path).map_err(|e| e.to_string())?; + for k in [1, 10, 100] { + let times_inner = self.bench_search(num_queries, &KnnRepeatedRnn(k, 2), &KnnRepeatedRnn(k, 2))?; + times.insert(format!("KnnRepeatedRnn({k}, 2)"), times_inner); + + let times_inner = self.bench_search(num_queries, &KnnBreadthFirst(k), &KnnBreadthFirst(k))?; + times.insert(format!("KnnBreadthFirst({k})"), times_inner); + + let times_inner = self.bench_search(num_queries, &KnnDepthFirst(k), &KnnDepthFirst(k))?; + times.insert(format!("KnnDepthFirst({k})"), times_inner); } - serde_json::to_writer_pretty(std::fs::File::create(&times_path).map_err(|e| e.to_string())?, &times) - .map_err(|e| e.to_string())?; + + ftlog::info!("Writing times to disk."); + serde_json::to_writer_pretty( + std::fs::File::create(self.path_manager.times_path()).map_err(|e| e.to_string())?, + &times, + ) + .map_err(|e| e.to_string())?; Ok(()) } diff --git a/crates/results/cakes/src/lib.rs b/crates/results/cakes/src/lib.rs new file mode 100644 index 000000000..ef341ea80 --- /dev/null +++ b/crates/results/cakes/src/lib.rs @@ -0,0 +1,19 @@
+#![deny(clippy::correctness)] +#![warn( + missing_docs, + clippy::all, + clippy::suspicious, + clippy::style, + clippy::complexity, + clippy::perf, + clippy::pedantic, + clippy::nursery, + clippy::unwrap_used, + clippy::expect_used, + clippy::panic, + clippy::cast_lossless +)] +#![doc = include_str!("../README.md")] + +pub mod data; +pub mod utils; diff --git a/crates/results/cakes/src/main.rs b/crates/results/cakes/src/main.rs index 2d634e145..0ec2cb33e 100644 --- a/crates/results/cakes/src/main.rs +++ b/crates/results/cakes/src/main.rs @@ -77,6 +77,13 @@ fn main() -> Result<(), String> { let mut data = args.dataset.read(&inp_path, &out_dir)?; ftlog::info!("Finished reading dataset and queries."); + // let pool = rayon::ThreadPoolBuilder::new() + // .num_threads(1) + // .build() + // .map_err(|e| e.to_string())?; + + // pool.install(|| data.benchmark(args.num_queries))?; + data.benchmark(args.num_queries)?; ftlog::info!("Finished benchmarking."); diff --git a/crates/results/cakes/src/utils.rs b/crates/results/cakes/src/utils.rs index b61dac2e8..26ce40b9a 100644 --- a/crates/results/cakes/src/utils.rs +++ b/crates/results/cakes/src/utils.rs @@ -7,6 +7,12 @@ use ftlog::{ LevelFilter, LoggerGuard, }; +/// Configures the logger. +/// +/// # Errors +/// +/// - If a logs directory could not be located/created. +/// - If the logger could not be initialized. pub fn configure_logger(file_name: &str) -> Result<(LoggerGuard, PathBuf), String> { let root_dir = PathBuf::from(".").canonicalize().map_err(|e| e.to_string())?; let logs_dir = root_dir.join("logs"); diff --git a/crates/results/chaoda/Cargo.toml b/crates/results/chaoda/Cargo.toml index 82e2fc241..53d0867df 100644 --- a/crates/results/chaoda/Cargo.toml +++ b/crates/results/chaoda/Cargo.toml @@ -5,10 +5,10 @@ edition = "2021" [dependencies] clap = { version = "4.5.16", features = ["derive"] } -abd-clam = { workspace = true, features = ["chaoda", "ndarray-bindings"] } +abd-clam = { workspace = true, features = ["chaoda", "disk-io"] } distances = { workspace = true } rand = { workspace = true } ndarray = { workspace = true } ndarray-npy = { workspace = true } ftlog = { workspace = true } -bincode = { workspace = true } +bincode = { version = "1.3" } diff --git a/crates/results/chaoda/src/data/mod.rs b/crates/results/chaoda/src/data/mod.rs index 8181c03ab..c41cf1c5b 100644 --- a/crates/results/chaoda/src/data/mod.rs +++ b/crates/results/chaoda/src/data/mod.rs @@ -2,11 +2,11 @@ use std::path::Path; -use abd_clam::{Dataset, FlatVec, Metric}; +use abd_clam::{dataset::AssociatesMetadataMut, Dataset, FlatVec}; use ndarray::prelude::*; use ndarray_npy::ReadNpyExt; -type ChaodaDataset = FlatVec, f64, bool>; +type ChaodaDataset = FlatVec, bool>; /// The datasets used for anomaly detection. 
/// @@ -163,7 +163,7 @@ impl Data { } /// Read the training datasets from the paper - pub fn read_paper_train(data_dir: &Path) -> Result<[ChaodaDataset; 4], String> { + pub fn read_train_data(data_dir: &Path) -> Result<[ChaodaDataset; 4], String> { Ok([ Self::Annthyroid.read(data_dir)?, // Self::Mnist.read(data_dir)?, @@ -175,7 +175,7 @@ impl Data { } /// Read the inference datasets from the paper - pub fn read_paper_inference(data_dir: &Path) -> Result, String> { + pub fn read_infer_data(data_dir: &Path) -> Result, String> { Ok(vec![ Self::Arrhythmia.read(data_dir)?, Self::BreastW.read(data_dir)?, @@ -200,20 +200,19 @@ impl Data { /// Read all the datasets pub fn read_all(data_dir: &Path) -> Result, String> { - let mut datasets = Self::read_paper_train(data_dir)?.to_vec(); - datasets.extend(Self::read_paper_inference(data_dir)?); + let mut datasets = Self::read_train_data(data_dir)?.to_vec(); + datasets.extend(Self::read_infer_data(data_dir)?); Ok(datasets) } } fn read_xy(path: &Path, name: &str) -> Result { let labels_path = path.join(format!("{name}_labels.npy")); - let reader = std::fs::File::open(labels_path).map_err(|e| e.to_string())?; - let labels = Array1::::read_npy(reader).map_err(|e| e.to_string())?; + let reader = std::fs::File::open(labels_path).map_err(|e| format!("Could not open file: {e}"))?; + let labels = Array1::::read_npy(reader).map_err(|e| format!("Could not read labels from file {path:?}: {e}"))?; let labels = labels.mapv(|y| y == 1).to_vec(); let x_path = path.join(format!("{name}.npy")); - let fv: FlatVec, f64, usize> = FlatVec::read_npy(x_path, Metric::default())?; - - fv.with_name(name).with_metadata(labels) + let fv: FlatVec, usize> = FlatVec::read_npy(&x_path)?; + fv.with_name(name).with_metadata(&labels) } diff --git a/crates/results/chaoda/src/main.rs b/crates/results/chaoda/src/main.rs index b191ab19d..32f3ec9cb 100644 --- a/crates/results/chaoda/src/main.rs +++ b/crates/results/chaoda/src/main.rs @@ -20,8 +20,10 @@ use std::path::PathBuf; use clap::Parser; use abd_clam::{ - chaoda::{ChaodaTrainer, GraphAlgorithm, TrainableMetaMlModel}, - Ball, Cluster, Dataset, Metric, + chaoda::{GraphAlgorithm, TrainableMetaMlModel, TrainableSmc}, + dataset::AssociatesMetadata, + metric::{Euclidean, Manhattan}, + Ball, Cluster, Dataset, }; use distances::Number; @@ -86,27 +88,8 @@ fn main() -> Result<(), String> { // Set some parameters for tree building let seed = Some(42); - - let metrics = [ - Metric::new(|x: &Vec, y: &Vec| distances::vectors::euclidean(x, y), false).with_name("euclidean"), - Metric::new(|x: &Vec, y: &Vec| distances::vectors::manhattan(x, y), false).with_name("manhattan"), - // Metric::new(|x: &Vec, y: &Vec| distances::vectors::cosine(x, y), false).with_name("cosine"), - // Metric::new(|x: &Vec, y: &Vec| distances::vectors::canberra(x, y), false).with_name("canberra"), - // Metric::new(|x: &Vec, y: &Vec| distances::vectors::bray_curtis(x, y), false).with_name("bray_curtis"), - ]; - ftlog::info!("Using {} metrics...", metrics.len()); - - let mut train_datasets = data::Data::read_paper_train(&data_dir)?; - - let criteria = { - let mut criteria = Vec::new(); - for _ in 0..train_datasets.len() { - criteria.push(default_criteria::<_, _, _, 2>()); - } - criteria - .try_into() - .unwrap_or_else(|_| unreachable!("We have a criterion for each pair of metric and dataset.")) - }; + let train_datasets = data::Data::read_train_data(&data_dir)?; + let criteria = default_criteria(&train_datasets); let labels = { let mut labels = Vec::new(); for data in 
&train_datasets { @@ -116,13 +99,14 @@ fn main() -> Result<(), String> { .try_into() .unwrap_or_else(|_| unreachable!("We have labels for each dataset.")) }; + let depths = (args.min_depth..).step_by(5).take(5).collect::>(); ftlog::info!("Training datasets:"); for d in &train_datasets { ftlog::info!("{}", d.name()); } - let model = if use_pre_trained { + let [model_euc, model_man] = if use_pre_trained { // Load the pre-trained CHAODA model ftlog::info!("Loading pre-trained model from: {model_path:?}"); bincode::deserialize_from(std::fs::File::open(&model_path).map_err(|e| e.to_string())?) @@ -131,53 +115,81 @@ fn main() -> Result<(), String> { // Create a Chaoda trainer let meta_ml_models = TrainableMetaMlModel::default_models(); let graph_algorithms = GraphAlgorithm::default_algorithms(); - let mut model = ChaodaTrainer::new_all_pairs(metrics.clone(), meta_ml_models, graph_algorithms); - // Create the trees for use in training the model. - let trees = model.par_create_trees(&mut train_datasets, &criteria, seed); - - // Train the model - let trained_model = model.par_train(&mut train_datasets, &trees, &labels, args.min_depth, args.num_epochs)?; - ftlog::info!("Completed training for {} epochs", args.num_epochs); + let mut smc_euc = TrainableSmc::new(&meta_ml_models, &graph_algorithms); + let trees = smc_euc.par_create_trees(&train_datasets, &criteria, &Euclidean, seed); + let trained_euc = smc_euc.par_train( + &train_datasets, + &Euclidean, + &trees, + &labels, + args.min_depth, + &depths, + args.num_epochs, + )?; + ftlog::info!( + "Completed training for {} epochs with Euclidean metric", + args.num_epochs + ); + + let mut smc_man = TrainableSmc::new(&meta_ml_models, &graph_algorithms); + let trees = smc_man.par_create_trees(&train_datasets, &criteria, &Manhattan, seed); + let trained_man = smc_man.par_train( + &train_datasets, + &Manhattan, + &trees, + &labels, + args.min_depth, + &depths, + args.num_epochs, + )?; + ftlog::info!( + "Completed training for {} epochs with Manhattan metric", + args.num_epochs + ); + + let trained_models = [trained_euc, trained_man]; // Save the trained model ftlog::info!("Saving model to: {model_path:?}"); bincode::serialize_into( std::fs::File::create(&model_path).map_err(|e| e.to_string())?, - &trained_model, + &trained_models, ) .map_err(|e| e.to_string())?; ftlog::info!("Model saved to: {model_path:?}"); - trained_model - }; - - let model = if use_pre_trained { - let mut model = model; - model.set_metrics(metrics); - model - } else { - model + trained_models }; // Print the ROC scores for all datasets - for mut data in data::Data::read_all(&data_dir)? { + for data in data::Data::read_all(&data_dir)? { ftlog::info!("Starting evaluation for: {}", data.name()); let labels = data.metadata().to_vec(); - let criteria = default_criteria::<_, _, _, 2>(); - let roc_score = model.par_evaluate(&mut data, &criteria, &labels, seed, args.min_depth); - ftlog::info!("Dataset: {} ROC-AUC score: {roc_score:.6}", data.name()); + let criteria = |c: &Ball<_>| c.cardinality() > 10; + + let roc_score = model_euc.par_evaluate(&data, &labels, &Euclidean, &criteria, seed, args.min_depth, 0.02); + ftlog::info!( + "Dataset: {}, Metric: Euclidean ROC-AUC score: {roc_score:.6}", + data.name() + ); + + let roc_score = model_man.par_evaluate(&data, &labels, &Manhattan, &criteria, seed, args.min_depth, 0.02); + ftlog::info!( + "Dataset: {}, Metric: Manhattan ROC-AUC score: {roc_score:.6}", + data.name() + ); } Ok(()) } /// Returns the default partitioning criteria, repeated `N` times. 
-fn default_criteria, const N: usize>() -> [impl Fn(&Ball) -> bool; N] { +fn default_criteria(_: &[A; N]) -> [impl Fn(&Ball) -> bool; N] { let mut criteria = Vec::with_capacity(N); for _ in 0..N { - criteria.push(|c: &Ball<_, _, _>| c.cardinality() > 10); + criteria.push(|c: &Ball<_>| c.cardinality() > 10); } criteria .try_into() diff --git a/crates/results/msa/Cargo.toml b/crates/results/msa/Cargo.toml new file mode 100644 index 000000000..527aa0336 --- /dev/null +++ b/crates/results/msa/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "results-msa" +version = "0.1.0" +edition = "2021" + +[dependencies] +clap = { workspace = true } +ftlog = { workspace = true } +results-cakes = { path = "../cakes" } +abd-clam = { workspace = true, features = ["msa", "disk-io"] } +distances = { workspace = true } +bio = { workspace = true } +rayon = { workspace = true } +stringzilla = { workspace = true } +bench-utils = { path = "../../../benches/utils" } diff --git a/crates/results/msa/README.md b/crates/results/msa/README.md new file mode 100644 index 000000000..362b245e6 --- /dev/null +++ b/crates/results/msa/README.md @@ -0,0 +1,32 @@ +# Results for MSA with CLAM + +## Usage + +Run the following command to see the usage information: + +```shell +cargo run -rp results-msa -- --help +``` + +If you want to run the MSA on all the sequences in a fasta file, you can use the following command: + +```shell +cargo run -rp results-msa -- \ + -i ../data/string-data/greengenes/gg_13_5.fasta \ + -o ../data/string-data/greengenes/msa-results \ + -m extended-iupac +``` + +If you want to run the MSA on a subset of the sequences in a fasta file, you can use the optional `-n` flag to specify the number of sequences to use: + +```shell +cargo run -rp results-msa -- \ + -i ../data/string-data/greengenes/gg_13_5.fasta \ + -o ../data/string-data/greengenes/msa-results \ + -m extended-iupac \ + -n 1000 +``` + +## Citation + +TODO... diff --git a/crates/results/msa/src/data/mod.rs b/crates/results/msa/src/data/mod.rs new file mode 100644 index 000000000..6788f7a7d --- /dev/null +++ b/crates/results/msa/src/data/mod.rs @@ -0,0 +1,29 @@ +//! Datasets for use in MSA experiments. + +use std::path::Path; + +use abd_clam::{dataset::AssociatesMetadata, msa::MSA}; + +mod raw; + +use distances::Number; +pub use raw::FastaFile; + +/// Write a CLAM `Dataset` to a FASTA file. +pub fn write_fasta, T: Number>(data: &MSA, path: P) -> Result<(), String> { + let path = path.as_ref(); + let file = std::fs::File::create(path).map_err(|e| format!("Failed to create file {path:?}: {e}"))?; + let mut writer = bio::io::fasta::Writer::new(file); + + let metadata = data.metadata(); + let sequences = data.data().items(); + + for (id, seq) in metadata.iter().zip(sequences.iter()) { + let record = bio::io::fasta::Record::with_attrs(id, None, seq.as_str().as_bytes()); + writer + .write_record(&record) + .map_err(|e| format!("Failed to write record: {e}"))?; + } + + Ok(()) +} diff --git a/crates/results/msa/src/data/raw.rs b/crates/results/msa/src/data/raw.rs new file mode 100644 index 000000000..84a211461 --- /dev/null +++ b/crates/results/msa/src/data/raw.rs @@ -0,0 +1,164 @@ +//! Reading data from various sources. + +use std::path::{Path, PathBuf}; + +use abd_clam::{ + dataset::{AssociatesMetadataMut, DatasetIO}, + msa::{Aligner, Sequence}, + Dataset, FlatVec, +}; +use distances::Number; +use results_cakes::data::fasta; + +/// We exclusively use Fasta files for the raw data.
+pub struct FastaFile { + raw_path: PathBuf, + out_dir: PathBuf, + name: String, +} + +impl FastaFile { + /// Creates a new `FastaFile` from the given path. + pub fn new<P: Into<PathBuf>>(raw_path: P, out_dir: Option<P>) -> Result<Self, String> { + let raw_path: PathBuf = raw_path.into().canonicalize().map_err(|e| e.to_string())?; + let name = raw_path + .file_stem() + .map(|s| s.to_string_lossy().to_string()) + .ok_or("No file name found")?; + + let out_dir = if let Some(out_dir) = out_dir { + out_dir.into() + } else { + ftlog::info!("No output directory specified. Using default."); + let mut out_dir = raw_path + .parent() + .ok_or("No parent directory of `inp_dir`")? + .to_path_buf(); + out_dir.push(format!("{name}_results")); + if !out_dir.exists() { + std::fs::create_dir(&out_dir).map_err(|e| e.to_string())?; + } + out_dir + } + .canonicalize() + .map_err(|e| e.to_string())?; + + if !raw_path.exists() { + return Err(format!("Path does not exist: {raw_path:?}")); + } + + if !raw_path.is_file() { + return Err(format!("Path is not a file: {raw_path:?}")); + } + + if !out_dir.exists() { + return Err(format!("Output directory does not exist: {out_dir:?}")); + } + + if !out_dir.is_dir() { + return Err(format!("Output directory is not a directory: {out_dir:?}")); + } + + Ok(Self { + raw_path, + out_dir, + name, + }) + } + + /// Returns the name of the fasta file without the extension. + pub fn name(&self) -> &str { + self.name.as_str() + } + + /// Returns the path to the raw fasta file. + pub fn raw_path(&self) -> &Path { + &self.raw_path + } + + /// Returns the path to the output directory. + pub fn out_dir(&self) -> &Path { + &self.out_dir + } + + /// Reads the dataset from the given path. + /// + /// # Arguments + /// + /// * `num_samples` - The number of samples to read from the dataset. If `None`, all samples are read. + /// + /// # Returns + /// + /// The dataset and queries, if they were read successfully. + /// + /// # Errors + /// + /// * If the dataset is not readable. + /// * If the dataset is not in the expected format. + #[allow(clippy::too_many_lines)] + pub fn read<'a, T: Number>( + &self, + num_samples: Option<usize>, + remove_gaps: bool, + aligner: &'a Aligner<T>, + ) -> Result<FlatVec<Sequence<'a, T>, String>, String> { + let data_path = { + let mut data_path = self.out_dir.clone(); + data_path.push(self.data_name(num_samples)); + data_path + }; + + let mut data = if data_path.exists() { + ftlog::info!("Reading data from {data_path:?}"); + let data = FlatVec::::read_from(&data_path)?; + let transformer = |s: String| Sequence::new(s, Some(aligner)); + data.transform_items(transformer) + } else { + let (data, min_len, max_len) = { + let ([mut data, _], _) = fasta::read(&self.raw_path, 0, remove_gaps)?; + if let Some(num_samples) = num_samples { + data.truncate(num_samples); + } + let (min_len, max_len) = data + .iter() + .fold((usize::MAX, usize::MIN), |(min_len, max_len), (_, s)| { + let len = s.len(); + (Ord::min(min_len, len), Ord::max(max_len, len)) + }); + (data, min_len, max_len) + }; + + ftlog::info!("Kept {} sequences with lengths in [{min_len}, {max_len}].", data.len()); + + let (metadata, data): (Vec<_>, Vec<_>) = data.into_iter().unzip(); + + let data = abd_clam::FlatVec::new(data)? + .with_metadata(&metadata)?
+                .with_dim_lower_bound(min_len)
+                .with_dim_upper_bound(max_len)
+                .transform_items(|s| Sequence::new(s, Some(aligner)));
+
+            ftlog::info!("Writing data to {data_path:?}");
+            let writable_data = data.clone().transform_items(|s| s.seq().to_string());
+            writable_data.write_to(&data_path)?;
+
+            data
+        };
+
+        let name = num_samples.map_or_else(
+            || self.name().to_string(),
+            |num_samples| format!("{}-{}", self.name(), num_samples),
+        );
+        data = data.with_name(&name);
+
+        Ok(data)
+    }
+
+    /// Returns the name of the file containing the uncompressed data as a serialized `FlatVec`.
+    fn data_name(&self, num_samples: Option<usize>) -> String {
+        num_samples.map_or_else(
+            || format!("{}.flat_data", self.name()),
+            |num_samples| format!("{}-{num_samples}.flat_data", self.name()),
+        )
+    }
+}
diff --git a/crates/results/msa/src/main.rs b/crates/results/msa/src/main.rs
new file mode 100644
index 000000000..a854a00bc
--- /dev/null
+++ b/crates/results/msa/src/main.rs
@@ -0,0 +1,210 @@
+#![deny(clippy::correctness)]
+#![warn(
+    missing_docs,
+    clippy::all,
+    clippy::suspicious,
+    clippy::style,
+    clippy::complexity,
+    clippy::perf,
+    clippy::pedantic,
+    clippy::nursery,
+    clippy::unwrap_used,
+    clippy::expect_used,
+    clippy::panic,
+    clippy::cast_lossless
+)]
+#![doc = include_str!("../README.md")]
+
+use core::ops::Neg;
+
+use std::path::PathBuf;
+
+use abd_clam::{metric::Levenshtein, msa, Cluster, Dataset};
+use clap::Parser;
+use distances::Number;
+use results_cakes::{data::PathManager, utils::configure_logger};
+
+mod data;
+mod steps;
+
+/// Reproducible results for the MSA paper.
+#[derive(Parser, Debug)]
+#[command(version, about, long_about = None)]
+struct Args {
+    /// Path to the input fasta file.
+    #[arg(short('i'), long)]
+    inp_path: PathBuf,
+
+    /// Whether the original fasta file was pre-aligned by the provider.
+    #[arg(short('p'), long)]
+    pre_aligned: bool,
+
+    /// Whether to use a balanced partition.
+    #[arg(short('b'), long)]
+    balanced: bool,
+
+    /// Optional cost of opening a gap.
+    #[arg(short('g'), long)]
+    gap_open: Option<usize>,
+
+    /// The number of samples to use for the dataset.
+    #[arg(short('n'), long)]
+    num_samples: Option<usize>,
+
+    /// The cost matrix to use for the alignment.
+    #[arg(short('m'), long)]
+    cost_matrix: CostMatrix,
+
+    /// Path to the output directory.
+    #[arg(short('o'), long)]
+    out_dir: Option<PathBuf>,
+}
+
+/// The cost matrix to use for the alignment.
+#[derive(clap::ValueEnum, Debug, Clone)]
+#[allow(non_camel_case_types, clippy::doc_markdown)]
+#[non_exhaustive]
+pub enum CostMatrix {
+    /// The default matrix.
+    #[clap(name = "default")]
+    Default,
+    /// Default but with affine gap penalties. Gap opening is 10 and ext is 1.
+    #[clap(name = "default-affine")]
+    DefaultAffine,
+    /// Extended IUPAC matrix.
+    #[clap(name = "extended-iupac")]
+    ExtendedIupac,
+    /// Blosum62 matrix.
+    #[clap(name = "blosum62")]
+    Blosum62,
+}
+
+impl CostMatrix {
+    /// Get the cost matrix.
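+    ///
+    /// `gap_open` is forwarded to the `msa::CostMatrix` constructors, which
+    /// presumably fall back to their default gap-opening penalty when it is
+    /// `None`.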
+    #[must_use]
+    pub fn cost_matrix<T: Number + Neg<Output = T>>(&self, gap_open: Option<usize>) -> msa::CostMatrix<T> {
+        match self {
+            Self::Default => msa::CostMatrix::default(),
+            Self::DefaultAffine => msa::CostMatrix::default_affine(gap_open),
+            Self::ExtendedIupac => msa::CostMatrix::extended_iupac(gap_open),
+            Self::Blosum62 => msa::CostMatrix::blosum62(gap_open),
+        }
+    }
+}
+
+#[allow(clippy::similar_names)]
+fn main() -> Result<(), String> {
+    let args = Args::parse();
+    ftlog::info!("{args:?}");
+
+    let fasta_file = data::FastaFile::new(args.inp_path, args.out_dir)?;
+
+    let log_name = format!("msa-{}", fasta_file.name());
+    // We need the `_guard` in scope to ensure proper logging.
+    let (_guard, log_path) = configure_logger(&log_name)?;
+    println!("Log file: {log_path:?}");
+
+    let cost_matrix = args.cost_matrix.cost_matrix::<i32>(args.gap_open);
+    let aligner = msa::Aligner::new(&cost_matrix, b'-');
+
+    ftlog::info!("Input file: {:?}", fasta_file.raw_path());
+    ftlog::info!("Output directory: {:?}", fasta_file.out_dir());
+
+    let data = fasta_file.read::<i32>(args.num_samples, args.pre_aligned, &aligner)?;
+    ftlog::info!(
+        "Finished reading original dataset: length range = {:?}",
+        data.dimensionality_hint()
+    );
+    let path_manager = PathManager::new(data.name(), fasta_file.out_dir());
+
+    let metric = Levenshtein;
+
+    let msa_fasta_path = path_manager.msa_fasta_path();
+    if !msa_fasta_path.exists() {
+        let msa_ball_path = path_manager.msa_ball_path();
+        let msa_data_path = path_manager.msa_data_path();
+
+        let (off_ball, perm_data) = if msa_ball_path.exists() && msa_data_path.exists() {
+            // Read the `PermutedBall` and the permuted dataset.
+            steps::read_permuted_ball(&msa_ball_path, &msa_data_path, &aligner)?
+        } else {
+            let ball_path = path_manager.ball_path();
+            let ball = if ball_path.exists() {
+                // Read the Ball.
+                steps::read_ball(&ball_path)?
+            } else {
+                // Build the Ball.
+                if args.balanced {
+                    steps::build_balanced_ball(&data, &metric, &ball_path, &path_manager.ball_csv_path())?
+                } else {
+                    steps::build_ball(&data, &metric, &ball_path, &path_manager.ball_csv_path())?
+                }
+            };
+            ftlog::info!("Ball has {} leaves.", ball.leaves().len());
+
+            // Build the `PermutedBall` and the permuted dataset.
+            steps::build_perm_ball(ball, data, &metric, &msa_ball_path, &msa_data_path)?
+        };
+        ftlog::info!("PermutedBall has {} leaves.", off_ball.leaves().len());
+
+        // Build the MSA.
+        steps::build_aligned(&args.cost_matrix, args.gap_open, &off_ball, &perm_data, &msa_fasta_path)?;
+        ftlog::info!("Finished building MSA.");
+    }
+
+    // Read the aligned sequences and load the data.
+    ftlog::info!("Reading aligned sequences from: {msa_fasta_path:?}");
+    let msa_data = steps::read_aligned(&msa_fasta_path, &aligner)?;
+
+    ftlog::info!(
+        "Finished reading {} aligned sequences with width = {:?}.",
+        msa_data.cardinality(),
+        msa_data.dimensionality_hint()
+    );
+
+    // Compute the quality metrics.
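+    // Each quality metric below is estimated from a subsample of sequence
+    // pairs; the exact all-pairs variants are left commented out because
+    // they scale quadratically with the number of sequences.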
+
+    let gap_char = b'-';
+    let gap_penalty = 1;
+    let mismatch_penalty = 1;
+    let gap_open_penalty = 10;
+    let gap_ext_penalty = 1;
+
+    let ps_quality = msa_data.par_scoring_pairwise_subsample(gap_char, gap_penalty, mismatch_penalty);
+    ftlog::info!("Pairwise scoring metric estimate: {ps_quality}");
+
+    // let ps_quality = msa_data.par_scoring_pairwise(gap_char, gap_penalty, mismatch_penalty);
+    // ftlog::info!("Pairwise scoring metric: {ps_quality}");
+
+    let wps_quality =
+        msa_data.par_weighted_scoring_pairwise_subsample(gap_char, gap_open_penalty, gap_ext_penalty, mismatch_penalty);
+    ftlog::info!("Weighted pairwise scoring metric estimate: {wps_quality}");
+
+    // let wps_quality = msa_data.par_weighted_scoring_pairwise(gap_char, gap_open_penalty, gap_ext_penalty, mismatch_penalty);
+    // ftlog::info!("Weighted pairwise scoring metric: {wps_quality}");
+
+    let (avg_p, max_p) = msa_data.par_p_distance_stats_subsample(gap_char);
+    ftlog::info!("Pairwise distance stats estimate: avg = {avg_p:.4}, max = {max_p:.4}");
+
+    // let (avg_p, max_p) = msa_data.par_p_distance_stats(gap_char);
+    // ftlog::info!("Pairwise distance stats: avg = {avg_p}, max = {max_p}");
+
+    let dd_quality = msa_data.par_distance_distortion_subsample(gap_char);
+    ftlog::info!("Distance distortion metric estimate: {dd_quality}");
+
+    // let dd_quality = msa_data.par_distance_distortion(gap_char);
+    // ftlog::info!("Distance distortion metric: {dd_quality}");
+
+    ftlog::info!("Finished scoring row-wise.");
+
+    ftlog::info!("Converting to column-major format.");
+    let col_msa_data = msa_data.par_change_major();
+    ftlog::info!("Finished converting to column-major format.");
+
+    let cs_quality = col_msa_data.par_scoring_columns(gap_char, gap_penalty, mismatch_penalty);
+    ftlog::info!("Column scoring metric: {cs_quality}");
+
+    ftlog::info!("Finished scoring column-wise.");
+
+    Ok(())
+}
diff --git a/crates/results/msa/src/steps.rs b/crates/results/msa/src/steps.rs
new file mode 100644
index 000000000..848974697
--- /dev/null
+++ b/crates/results/msa/src/steps.rs
@@ -0,0 +1,179 @@
+//! Steps for the MSA pipeline.
+
+use std::path::Path;
+
+use abd_clam::{
+    cakes::PermutedBall,
+    cluster::{adapter::ParBallAdapter, BalancedBall, ClusterIO, Csv, ParPartition},
+    dataset::{AssociatesMetadata, AssociatesMetadataMut, DatasetIO},
+    metric::ParMetric,
+    msa::{self, Aligner, Sequence},
+    Ball, Cluster, Dataset, FlatVec,
+};
+
+type B = Ball<i32>;
+type Pb = PermutedBall<i32, Ball<i32>>;
+
+/// Build the aligned datasets.
+pub fn build_aligned<P: AsRef<Path>>(
+    matrix: &crate::CostMatrix,
+    gap_open: Option<usize>,
+    perm_ball: &Pb,
+    data: &FlatVec<Sequence<'_, i32>, String>,
+    out_path: P,
+) -> Result<(), String> {
+    ftlog::info!("Setting up aligner...");
+    let cost_matrix = matrix.cost_matrix::<i32>(gap_open);
+    let aligner = msa::Aligner::new(&cost_matrix, b'-');
+
+    ftlog::info!("Aligning sequences...");
+    // let builder = msa::Columnar::<i32>::new(&aligner).par_with_binary_tree(perm_ball, data);
+    let builder = msa::Columnar::<i32>::new(&aligner).par_with_tree(perm_ball, data);
+
+    ftlog::info!("Extracting aligned sequences...");
+    let msa = builder.to_flat_vec_rows().with_metadata(data.metadata())?;
+    let transformer = |s: Vec<u8>| s.into_iter().map(|c| c as char).collect::<String>();
+    let msa = msa.transform_items(transformer);
+
+    ftlog::info!("Finished aligning {} sequences.", builder.len());
+    let data = msa::MSA::new(&aligner, msa)?;
+
+    let path = out_path.as_ref();
+    ftlog::info!("Writing MSA to {path:?}");
+    crate::data::write_fasta(&data, path)?;
+
+    Ok(())
+}
+
+/// Read the aligned fasta file.
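+///
+/// The width of the alignment is used for both dimensionality bounds, since
+/// every row of an MSA has the same length.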
+pub fn read_aligned<P: AsRef<Path>>(path: &P, aligner: &Aligner<i32>) -> Result<msa::MSA<i32>, String> {
+    ftlog::info!("Reading aligned sequences from {:?}", path.as_ref());
+
+    let ([aligned_sequences, _], [width, _]) = results_cakes::data::fasta::read(path, 0, false)?;
+    let (metadata, aligned_sequences): (Vec<_>, Vec<_>) = aligned_sequences.into_iter().unzip();
+
+    let data = FlatVec::new(aligned_sequences)?
+        .with_dim_lower_bound(width)
+        .with_dim_upper_bound(width)
+        .with_metadata(&metadata)?;
+
+    msa::MSA::new(aligner, data)
+}
+
+/// Build the `PermutedBall` and the permuted dataset.
+#[allow(clippy::type_complexity)]
+pub fn build_perm_ball<'a, P: AsRef<Path>, M: ParMetric<Sequence<'a, i32>, i32>>(
+    ball: B,
+    data: FlatVec<Sequence<'a, i32>, String>,
+    metric: &M,
+    ball_path: &P,
+    data_path: &P,
+) -> Result<(Pb, FlatVec<Sequence<'a, i32>, String>), String> {
+    ftlog::info!("Building PermutedBall and permuted dataset.");
+    let (ball, data) = PermutedBall::par_from_ball_tree(ball, data, metric);
+
+    ftlog::info!("Writing PermutedBall to {:?}", ball_path.as_ref());
+    ball.write_to(ball_path)?;
+
+    ftlog::info!("Writing PermutedData to {:?}", data_path.as_ref());
+    let transformer = |seq: Sequence<'a, i32>| seq.seq().to_string();
+    let writable_data = data.clone().transform_items(transformer);
+    writable_data.write_to(data_path)?;
+
+    Ok((ball, data))
+}
+
+/// Read the `PermutedBall` and the permuted dataset from disk.
+#[allow(clippy::type_complexity)]
+pub fn read_permuted_ball<'a, P: AsRef<Path>>(
+    ball_path: &P,
+    data_path: &P,
+    aligner: &'a Aligner<i32>,
+) -> Result<(Pb, FlatVec<Sequence<'a, i32>, String>), String> {
+    ftlog::info!("Reading PermutedBall from {:?}", ball_path.as_ref());
+    let ball = Pb::read_from(ball_path)?;
+
+    ftlog::info!("Reading PermutedData from {:?}", data_path.as_ref());
+    let data = FlatVec::<String, String>::read_from(data_path)?;
+    let transformer = |s: String| Sequence::new(s, Some(aligner));
+    let data = data.transform_items(transformer);
+
+    Ok((ball, data))
+}
+
+/// Build the Ball and the dataset.
+pub fn build_ball<'a, P: AsRef<Path>, M: ParMetric<Sequence<'a, i32>, i32>>(
+    data: &FlatVec<Sequence<'a, i32>, String>,
+    metric: &M,
+    ball_path: &P,
+    csv_path: &P,
+) -> Result<Ball<i32>, String> {
+    // Create the ball from scratch.
+    ftlog::info!("Building ball.");
+    let mut depth = 0;
+    let seed = Some(42);
+
+    let indices = (0..data.cardinality()).collect::<Vec<_>>();
+    let mut ball = Ball::par_new(data, metric, &indices, 0, seed)
+        .unwrap_or_else(|e| unreachable!("We ensured that indices is non-empty: {e}"));
+    let depth_delta = abd_clam::utils::max_recursion_depth();
+
+    let criteria = |c: &Ball<_>| c.depth() < 1;
+    ball.par_partition(data, metric, &criteria, seed);
+
+    while ball.leaves().into_iter().any(|c| !c.is_singleton()) {
+        depth += depth_delta;
+        let criteria = |c: &Ball<_>| c.depth() < depth;
+        ball.par_partition_further(data, metric, &criteria, seed);
+    }
+
+    let num_leaves = ball.leaves().len();
+    ftlog::info!("Built ball with {num_leaves} leaves.");
+
+    // Serialize the ball to disk.
+    ftlog::info!("Writing ball to {:?}", ball_path.as_ref());
+    ball.write_to(ball_path)?;
+
+    // Write the ball to a CSV file.
+    ftlog::info!("Writing ball to CSV at {:?}", csv_path.as_ref());
+    ball.write_to_csv(&csv_path)?;
+
+    Ok(ball)
+}
+
+/// Read the Ball from disk.
+pub fn read_ball<P: AsRef<Path>>(path: &P) -> Result<Ball<i32>, String> {
+    ftlog::info!("Reading ball from {:?}", path.as_ref());
+    let ball = Ball::read_from(path)?;
+    ftlog::info!("Finished reading Ball.");
+
+    Ok(ball)
+}
+
+/// Build the `Ball` with a balanced partition.
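+///
+/// Unlike `build_ball`, this partitions each cluster into children of
+/// (roughly) equal cardinality and then converts the result into a plain
+/// `Ball` tree via `into_ball`.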
+pub fn build_balanced_ball<'a, P: AsRef<Path>, M: ParMetric<Sequence<'a, i32>, i32>>(
+    data: &FlatVec<Sequence<'a, i32>, String>,
+    metric: &M,
+    ball_path: &P,
+    csv_path: &P,
+) -> Result<Ball<i32>, String> {
+    // Create the ball from scratch.
+    ftlog::info!("Building Balanced ball.");
+    let seed = Some(42);
+
+    let criteria = |c: &BalancedBall<_>| c.cardinality() > 1;
+    let ball = BalancedBall::par_new_tree(data, metric, &criteria, seed).into_ball();
+
+    let num_leaves = ball.leaves().len();
+    ftlog::info!("Built BalancedBall with {num_leaves} leaves.");
+
+    // Serialize the `BalancedBall` to disk.
+    ftlog::info!("Writing BalancedBall to {:?}", ball_path.as_ref());
+    ball.write_to(ball_path)?;
+
+    // Write the `BalancedBall` to a CSV file.
+    ftlog::info!("Writing BalancedBall to CSV at {:?}", csv_path.as_ref());
+    ball.write_to_csv(&csv_path)?;
+
+    Ok(ball)
+}
diff --git a/crates/results/rite-solutions/Cargo.toml b/crates/results/rite-solutions/Cargo.toml
index d3e685400..f5a7c5616 100644
--- a/crates/results/rite-solutions/Cargo.toml
+++ b/crates/results/rite-solutions/Cargo.toml
@@ -5,7 +5,7 @@ edition = "2021"
 
 [dependencies]
 clap = { version = "4.5.16", features = ["derive"] }
-abd-clam = { workspace = true, features = ["chaoda", "ndarray-bindings", "csv"] }
+abd-clam = { workspace = true, features = ["chaoda", "disk-io"] }
 distances = { workspace = true }
 symagen = { workspace = true }
 rand = { workspace = true }
diff --git a/crates/results/rite-solutions/src/data/mod.rs b/crates/results/rite-solutions/src/data/mod.rs
index 88c726fc5..c46bca9d2 100644
--- a/crates/results/rite-solutions/src/data/mod.rs
+++ b/crates/results/rite-solutions/src/data/mod.rs
@@ -2,13 +2,12 @@
 
 use std::path::Path;
 
-use abd_clam::FlatVec;
-
 mod gen_random;
 mod neighborhood_aware;
 mod vec_metric;
 mod wasserstein;
 
+use abd_clam::FlatVec;
 pub use gen_random::gen_random;
 pub use neighborhood_aware::NeighborhoodAware;
 pub use vec_metric::VecMetric;
@@ -16,15 +15,12 @@ pub use vec_metric::VecMetric;
 
 /// Read data from the given path or generate random data.
 pub fn read_or_generate<P: AsRef<Path>>(
     path: Option<P>,
-    metric: &VecMetric,
     num_inliers: Option<usize>,
     dimensionality: Option<usize>,
     inlier_mean: Option<f32>,
     inlier_std: Option<f32>,
     seed: Option<u64>,
-) -> Result<FlatVec<Vec<f32>, f32, usize>, String> {
-    let metric = metric.metric::<f32, f32>();
-
+) -> Result<FlatVec<Vec<f32>, usize>, String> {
     let data = if let Some(path) = path {
         let path = path.as_ref();
         if !path.exists() {
@@ -35,8 +31,8 @@ pub fn read_or_generate<P: AsRef<Path>>(
         ext.map_or_else(
             || Err(format!("File extension not found in {path:?}")),
             |ext| match ext {
-                "csv" => FlatVec::read_csv(path, metric),
-                "npy" => FlatVec::read_npy(path, metric),
+                "csv" => FlatVec::read_csv(&path),
+                "npy" => FlatVec::read_npy(&path),
                 _ => Err(format!("Unsupported file extension: {ext}. Must be `csv` or `npy`")),
             },
         )
@@ -48,11 +44,11 @@ pub fn read_or_generate<P: AsRef<Path>>(
         let std = inlier_std.ok_or("inlier_std must be provided")?;
         let data = gen_random(mean, std, car, dim, seed);
 
-        FlatVec::new_array(data, metric)
+        FlatVec::new_array(data)
     }?;
 
-    let dim = data.instances().first().map(Vec::len).ok_or("No instances found")?;
-    if data.instances().iter().any(|v| v.len() != dim) {
+    let dim = data.items().first().map(Vec::len).ok_or("No instances found")?;
+    if data.items().iter().any(|v| v.len() != dim) {
         return Err("Inconsistent dimensionality".to_string());
     }
diff --git a/crates/results/rite-solutions/src/data/neighborhood_aware.rs b/crates/results/rite-solutions/src/data/neighborhood_aware.rs
index 25580d74b..1df026c01 100644
--- a/crates/results/rite-solutions/src/data/neighborhood_aware.rs
+++ b/crates/results/rite-solutions/src/data/neighborhood_aware.rs
@@ -1,20 +1,22 @@
 //! A `Dataset` in which every point stores the distances to its `k` nearest neighbors.
 
 use abd_clam::{
+    cakes::{self, ParSearchAlgorithm, ParSearchable, SearchAlgorithm, Searchable},
     cluster::ParCluster,
-    dataset::{metric_space::ParMetricSpace, ParDataset},
-    Cluster, Dataset, FlatVec, Metric, MetricSpace, Permutable,
+    dataset::{AssociatesMetadata, AssociatesMetadataMut, ParDataset, Permutable},
+    metric::ParMetric,
+    Cluster, Dataset, FlatVec, Metric,
 };
 use rayon::prelude::*;
 
 use super::wasserstein;
 
-type Fv = FlatVec<Vec<f32>, f32, usize>;
+type Fv = FlatVec<Vec<f32>, usize>;
 
 /// A `Dataset` in which every point stores the distances to its `k` nearest neighbors.
 #[allow(clippy::type_complexity)]
 pub struct NeighborhoodAware {
-    data: FlatVec<Vec<f32>, f32, (usize, Vec<(usize, f32)>)>,
+    data: FlatVec<Vec<f32>, (usize, Vec<(usize, f32)>)>,
     k: usize,
 }
 
@@ -23,48 +25,54 @@ impl NeighborhoodAware {
     ///
     /// This will run knn-search on every point in the dataset and store the
     /// results in the dataset.
-    pub fn new<C: Cluster<Vec<f32>, f32, Fv>>(data: &Fv, root: &C, k: usize) -> Self {
-        let alg = abd_clam::cakes::Algorithm::KnnLinear(k);
+    pub fn new<C: Cluster<f32>, M: Metric<Vec<f32>, f32>>(data: &Fv, metric: &M, root: &C, k: usize) -> Self {
+        let alg = cakes::KnnLinear(k);
         let results = data
-            .instances()
+            .items()
             .iter()
-            .map(|query| alg.search(data, root, query))
+            .map(|query| alg.search(data, metric, root, query))
             .zip(data.metadata().iter())
             .map(|(h, &i)| (i, h))
-            .collect();
+            .collect::<Vec<_>>();
         let data = data
             .clone()
-            .with_metadata(results)
+            .with_metadata(&results)
            .unwrap_or_else(|e| unreachable!("We created the correct size for neighborhood aware data: {e}"));
         Self { data, k }
     }
 
     /// Parallel version of `new`.
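+    ///
+    /// The linear knn searches are distributed across worker threads with
+    /// `rayon`.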
-    pub fn par_new<C: ParCluster<Vec<f32>, f32, Fv>>(data: &Fv, root: &C, k: usize) -> Self {
-        let alg = abd_clam::cakes::Algorithm::KnnLinear(k);
+    pub fn par_new<C: ParCluster<f32>, M: ParMetric<Vec<f32>, f32>>(data: &Fv, metric: &M, root: &C, k: usize) -> Self {
+        let alg = cakes::KnnLinear(k);
         let results = data
-            .instances()
+            .items()
             .par_iter()
-            .map(|query| alg.par_search(data, root, query))
+            .map(|query| alg.par_search(data, metric, root, query))
             .zip(data.metadata().par_iter())
             .map(|(h, &i)| (i, h))
-            .collect();
+            .collect::<Vec<_>>();
         let data = data
             .clone()
-            .with_metadata(results)
+            .with_metadata(&results)
             .unwrap_or_else(|e| unreachable!("We created the correct size for neighborhood aware data: {e}"));
         Self { data, k }
     }
 
     /// Check if a point is an outlier.
-    pub fn is_outlier<C: Cluster<Vec<f32>, f32, Self>>(&self, root: &C, query: &Vec<f32>, threshold: f32) -> bool {
-        let alg = abd_clam::cakes::Algorithm::KnnLinear(self.k);
-
-        let hits = alg.search(self, root, query);
+    pub fn is_outlier<C: Cluster<f32>, M: Metric<Vec<f32>, f32>>(
+        &self,
+        metric: &M,
+        root: &C,
+        query: &Vec<f32>,
+        threshold: f32,
+    ) -> bool {
+        let alg = cakes::KnnLinear(self.k);
+
+        let hits = alg.search(self, metric, root, query);
         let neighbors_distances = hits
             .iter()
             .map(|&(i, _)| self.neighbor_distances(i))
@@ -93,17 +101,7 @@ impl NeighborhoodAware {
     }
 }
 
-impl MetricSpace<Vec<f32>, f32> for NeighborhoodAware {
-    fn metric(&self) -> &Metric<Vec<f32>, f32> {
-        self.data.metric()
-    }
-
-    fn set_metric(&mut self, metric: Metric<Vec<f32>, f32>) {
-        self.data.set_metric(metric);
-    }
-}
-
-impl Dataset<Vec<f32>, f32> for NeighborhoodAware {
+impl Dataset<Vec<f32>> for NeighborhoodAware {
     fn name(&self) -> &str {
         self.data.name()
     }
@@ -142,6 +140,24 @@ impl Permutable for NeighborhoodAware {
     }
 }
 
-impl ParMetricSpace<Vec<f32>, f32> for NeighborhoodAware {}
+impl ParDataset<Vec<f32>> for NeighborhoodAware {}
 
-impl ParDataset<Vec<f32>, f32> for NeighborhoodAware {}
+impl<C: Cluster<f32>, M: Metric<Vec<f32>, f32>> Searchable<Vec<f32>, f32, C, M> for NeighborhoodAware {
+    fn query_to_center(&self, metric: &M, query: &Vec<f32>, cluster: &C) -> f32 {
+        self.data.query_to_center(metric, query, cluster)
+    }
+
+    fn query_to_all(&self, metric: &M, query: &Vec<f32>, cluster: &C) -> impl Iterator<Item = (usize, f32)> {
+        self.data.query_to_all(metric, query, cluster)
+    }
+}
+
+impl<C: ParCluster<f32>, M: ParMetric<Vec<f32>, f32>> ParSearchable<Vec<f32>, f32, C, M> for NeighborhoodAware {
+    fn par_query_to_center(&self, metric: &M, query: &Vec<f32>, cluster: &C) -> f32 {
+        self.data.par_query_to_center(metric, query, cluster)
+    }
+
+    fn par_query_to_all(&self, metric: &M, query: &Vec<f32>, cluster: &C) -> impl ParallelIterator<Item = (usize, f32)> {
+        self.data.par_query_to_all(metric, query, cluster)
+    }
+}
diff --git a/crates/results/rite-solutions/src/data/vec_metric.rs b/crates/results/rite-solutions/src/data/vec_metric.rs
index 759d7a0a3..69d39c39d 100644
--- a/crates/results/rite-solutions/src/data/vec_metric.rs
+++ b/crates/results/rite-solutions/src/data/vec_metric.rs
@@ -1,7 +1,7 @@
 //! Metrics for Vector datasets.
 
-use abd_clam::Metric;
-use distances::{number::Float, Number};
+use abd_clam::metric::{Cosine, Euclidean, Manhattan, ParMetric};
+use distances::number::Float;
 
 #[derive(clap::ValueEnum, Debug, Clone)]
 pub enum VecMetric {
@@ -21,12 +21,11 @@ pub enum VecMetric {
 
 impl VecMetric {
     /// Returns the metric.
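+    ///
+    /// Maps the CLI choice to a boxed `ParMetric` trait object so the
+    /// distance function can be selected at runtime.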
     #[must_use]
-    pub fn metric<T: Number, U: Float>(&self) -> Metric<Vec<T>, U> {
-        let distance_fn = match self {
-            Self::Euclidean => |x: &Vec<T>, y: &Vec<T>| distances::vectors::euclidean::<T, U>(x, y),
-            Self::Manhattan => |x: &Vec<T>, y: &Vec<T>| U::from(distances::vectors::manhattan::<T>(x, y)),
-            Self::Cosine => |x: &Vec<T>, y: &Vec<T>| distances::vectors::cosine::<T, U>(x, y),
-        };
-        Metric::new(distance_fn, false)
+    pub fn metric<I: AsRef<[T]> + Send + Sync, T: Float>(&self) -> Box<dyn ParMetric<I, T>> {
+        match self {
+            Self::Euclidean => Box::new(Euclidean),
+            Self::Manhattan => Box::new(Manhattan),
+            Self::Cosine => Box::new(Cosine),
+        }
     }
 }
diff --git a/crates/results/rite-solutions/src/main.rs b/crates/results/rite-solutions/src/main.rs
index 00f6789b5..edd46f0f3 100644
--- a/crates/results/rite-solutions/src/main.rs
+++ b/crates/results/rite-solutions/src/main.rs
@@ -17,9 +17,11 @@
 
 use std::path::PathBuf;
 
-use abd_clam::{partition::ParPartition, Ball, Cluster, Dataset, Partition};
+use abd_clam::{
+    cluster::{ParPartition, Partition},
+    Ball, Cluster, Dataset,
+};
 use clap::Parser;
-use data::NeighborhoodAware;
 use rayon::prelude::*;
 
 mod data;
@@ -103,7 +105,6 @@ fn main() -> Result<(), String> {
 
     let data = data::read_or_generate(
         args.inp_path,
-        &args.metric,
         args.num_inliers,
         args.dimensionality,
         args.inlier_mean,
@@ -111,19 +112,19 @@ fn main() -> Result<(), String> {
         args.seed,
     )?;
 
-    let criteria = |c: &Ball<_, _, _>| c.cardinality() > 1;
+    let metric = args.metric.metric();
+    let criteria = |c: &Ball<_>| c.cardinality() > 1;
     let root = if args.parallelize {
-        Ball::par_new_tree(&data, &criteria, args.seed)
+        Ball::par_new_tree(&data, &metric, &criteria, args.seed)
     } else {
-        Ball::new_tree(&data, &criteria, args.seed)
+        Ball::new_tree(&data, &metric, &criteria, args.seed)
     };
 
     let data = if args.parallelize {
-        data::NeighborhoodAware::par_new(&data, &root, args.neighborhood_size)
+        data::NeighborhoodAware::par_new(&data, &metric, &root, args.neighborhood_size)
     } else {
-        data::NeighborhoodAware::new(&data, &root, args.neighborhood_size)
+        data::NeighborhoodAware::new(&data, &metric, &root, args.neighborhood_size)
     };
-    let root = root.with_dataset_type::<NeighborhoodAware>();
 
     let dim = data.dimensionality_hint().0;
     let outliers = data::gen_random(args.outlier_mean, args.outlier_std, args.num_outliers, dim, args.seed);
@@ -131,12 +132,12 @@ fn main() -> Result<(), String> {
     let results = if args.parallelize {
         outliers
             .par_iter()
-            .map(|outlier| data.is_outlier(&root, outlier, 0.5))
+            .map(|outlier| data.is_outlier(&metric, &root, outlier, 0.5))
             .collect::<Vec<_>>()
     } else {
         outliers
             .iter()
-            .map(|outlier| data.is_outlier(&root, outlier, 0.5))
+            .map(|outlier| data.is_outlier(&metric, &root, outlier, 0.5))
             .collect()
     };
     let results = results.into_iter().enumerate().collect::<Vec<_>>();
diff --git a/pypi/distances/.bumpversion.cfg b/pypi/distances/.bumpversion.cfg
index 9620e76ee..eda9db9c0 100644
--- a/pypi/distances/.bumpversion.cfg
+++ b/pypi/distances/.bumpversion.cfg
@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 1.0.3
+current_version = 1.0.4
 commit = True
 tag = False
 parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\-(?P<release>[a-z]+)(?P<build>\d+))?
diff --git a/pypi/distances/Cargo.toml b/pypi/distances/Cargo.toml index da8b27367..a913fe75b 100644 --- a/pypi/distances/Cargo.toml +++ b/pypi/distances/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "abd-distances" -version = "1.0.3" +version = "1.0.4" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/pypi/distances/README.md b/pypi/distances/README.md index 533640640..90f62bec7 100644 --- a/pypi/distances/README.md +++ b/pypi/distances/README.md @@ -1,4 +1,4 @@ -# Algorithms for Big Data: Distances (v1.0.3) +# Algorithms for Big Data: Distances (v1.0.4) This package contains algorithms for computing distances between data points. It is a thin Python wrapper around the `distances` crate, in Rust. diff --git a/pypi/distances/VERSION b/pypi/distances/VERSION index 21e8796a0..ee90284c2 100644 --- a/pypi/distances/VERSION +++ b/pypi/distances/VERSION @@ -1 +1 @@ -1.0.3 +1.0.4 diff --git a/pypi/distances/pyproject.toml b/pypi/distances/pyproject.toml index c326d34d7..af36282f6 100644 --- a/pypi/distances/pyproject.toml +++ b/pypi/distances/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "maturin" [project] name = "abd-distances" description = "Distance functions: A drop-in replacement for, and a super-set of the scipy.spatial.distance module." -version = "1.0.3" +version = "1.0.4" requires-python = ">=3.9" keywords = ["distance", "metric", "simd"] classifiers = [ @@ -14,7 +14,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: PyPy", ] dependencies = [ - "numpy<=2.0.0", + "numpy", ] [tool.rye] diff --git a/pypi/distances/python/abd_distances/__init__.py b/pypi/distances/python/abd_distances/__init__.py index 112030a21..517e7d9fd 100644 --- a/pypi/distances/python/abd_distances/__init__.py +++ b/pypi/distances/python/abd_distances/__init__.py @@ -12,4 +12,4 @@ "vectors", ] -__version__ = "1.0.3" +__version__ = "1.0.4" diff --git a/pypi/distances/src/lib.rs b/pypi/distances/src/lib.rs index 8483d66ae..deb035e91 100644 --- a/pypi/distances/src/lib.rs +++ b/pypi/distances/src/lib.rs @@ -9,9 +9,9 @@ use pyo3::prelude::*; /// The `abd-distances` module implemented in Rust. 
#[pymodule] -fn abd_distances(py: Python, m: &PyModule) -> PyResult<()> { - simd::register(py, m)?; - strings::register(py, m)?; - vectors::register(py, m)?; +fn abd_distances(m: &Bound<'_, PyModule>) -> PyResult<()> { + simd::register(m)?; + strings::register(m)?; + vectors::register(m)?; Ok(()) } diff --git a/pypi/distances/src/simd.rs b/pypi/distances/src/simd.rs index 2e8d85018..34b7843f0 100644 --- a/pypi/distances/src/simd.rs +++ b/pypi/distances/src/simd.rs @@ -6,15 +6,14 @@ use pyo3::{exceptions::PyValueError, prelude::*}; use crate::utils::{Scalar, Vector1, Vector2, _cdist, _pdist}; -pub fn register(py: Python<'_>, parent_module: &PyModule) -> PyResult<()> { - let simd_module = PyModule::new(py, "simd")?; - simd_module.add_function(wrap_pyfunction!(euclidean, simd_module)?)?; - simd_module.add_function(wrap_pyfunction!(sqeuclidean, simd_module)?)?; - simd_module.add_function(wrap_pyfunction!(cosine, simd_module)?)?; - simd_module.add_function(wrap_pyfunction!(cdist, simd_module)?)?; - simd_module.add_function(wrap_pyfunction!(pdist, simd_module)?)?; - parent_module.add_submodule(simd_module)?; - Ok(()) +pub fn register(pm: &Bound<'_, PyModule>) -> PyResult<()> { + let simd_module = PyModule::new(pm.py(), "simd")?; + simd_module.add_function(wrap_pyfunction!(euclidean, &simd_module)?)?; + simd_module.add_function(wrap_pyfunction!(sqeuclidean, &simd_module)?)?; + simd_module.add_function(wrap_pyfunction!(cosine, &simd_module)?)?; + simd_module.add_function(wrap_pyfunction!(cdist, &simd_module)?)?; + simd_module.add_function(wrap_pyfunction!(pdist, &simd_module)?)?; + pm.add_submodule(&simd_module) } macro_rules! build_fn { @@ -140,7 +139,7 @@ build_fn!(sqeuclidean, euclidean_sq_f32, euclidean_sq_f64); build_fn!(cosine, cosine_f32, cosine_f64); #[pyfunction] -fn cdist(py: Python<'_>, a: Vector2, b: Vector2, metric: &str) -> PyResult>> { +fn cdist<'py>(py: Python<'py>, a: Vector2, b: Vector2, metric: &str) -> PyResult>> { match (&a, &b) { // The types are the same (Vector2::F32(a), Vector2::F32(b)) => { @@ -236,7 +235,7 @@ fn cdist(py: Python<'_>, a: Vector2, b: Vector2, metric: &str) -> PyResult, a: Vector2, metric: &str) -> PyResult>> { +fn pdist<'py>(py: Python<'py>, a: Vector2, metric: &str) -> PyResult>> { match &a { Vector2::F32(a) => { let metric = _parse_metric_f32(metric)?; diff --git a/pypi/distances/src/strings.rs b/pypi/distances/src/strings.rs index cefcfe4b8..c138fd550 100644 --- a/pypi/distances/src/strings.rs +++ b/pypi/distances/src/strings.rs @@ -3,13 +3,12 @@ use distances::strings::{hamming, levenshtein, nw_distance}; use pyo3::prelude::*; -pub fn register(py: Python<'_>, parent_module: &PyModule) -> PyResult<()> { - let strings_module = PyModule::new(py, "strings")?; - strings_module.add_function(wrap_pyfunction!(hamming_distance, strings_module)?)?; - strings_module.add_function(wrap_pyfunction!(levenshtein_distance, strings_module)?)?; - strings_module.add_function(wrap_pyfunction!(needleman_wunsch_distance, strings_module)?)?; - parent_module.add_submodule(strings_module)?; - Ok(()) +pub fn register(pm: &Bound<'_, PyModule>) -> PyResult<()> { + let strings_module = PyModule::new(pm.py(), "strings")?; + strings_module.add_function(wrap_pyfunction!(hamming_distance, &strings_module)?)?; + strings_module.add_function(wrap_pyfunction!(levenshtein_distance, &strings_module)?)?; + strings_module.add_function(wrap_pyfunction!(needleman_wunsch_distance, &strings_module)?)?; + pm.add_submodule(&strings_module) } /// Hamming distance for strings. 
diff --git a/pypi/distances/src/utils.rs b/pypi/distances/src/utils.rs index 056cd7757..183bde272 100644 --- a/pypi/distances/src/utils.rs +++ b/pypi/distances/src/utils.rs @@ -1,12 +1,14 @@ //! Helpers for the Python wrapper. +use std::convert::Infallible; + use distances::{vectors, Number}; -use ndarray::parallel::prelude::*; -use ndarray::{Array1, Array2}; +use ndarray::{parallel::prelude::*, Array1, Array2}; use numpy::{ndarray::Axis, PyArray1, PyReadonlyArray1, PyReadonlyArray2}; use pyo3::{ exceptions::{PyTypeError, PyValueError}, prelude::*, + types::PyFloat, }; pub enum Scalar { @@ -15,7 +17,7 @@ pub enum Scalar { } impl<'a> FromPyObject<'a> for Scalar { - fn extract(ob: &'a PyAny) -> PyResult { + fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult { if let Ok(a) = ob.extract::() { Ok(Scalar::F32(a)) } else if let Ok(a) = ob.extract::() { @@ -26,11 +28,17 @@ impl<'a> FromPyObject<'a> for Scalar { } } -impl IntoPy for Scalar { - fn into_py(self, py: Python<'_>) -> PyObject { +impl<'py> IntoPyObject<'py> for Scalar { + type Target = PyFloat; + + type Output = Bound<'py, Self::Target>; + + type Error = Infallible; + + fn into_pyobject(self, py: Python<'py>) -> Result { match self { - Scalar::F32(a) => a.into_py(py), - Scalar::F64(a) => a.into_py(py), + Scalar::F32(a) => a.into_pyobject(py), + Scalar::F64(a) => a.into_pyobject(py), } } } @@ -48,7 +56,7 @@ pub enum Vector1<'py> { I64(PyReadonlyArray1<'py, i64>), } -impl<'a> Vector1<'a> { +impl Vector1<'_> { pub fn cast(&self) -> Array1 { match self { Vector1::F32(a) => a.as_array().mapv(T::from), @@ -66,7 +74,7 @@ impl<'a> Vector1<'a> { } impl<'a> FromPyObject<'a> for Vector1<'a> { - fn extract(ob: &'a PyAny) -> PyResult { + fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult { if let Ok(a) = ob.extract::>() { Ok(Vector1::F32(a)) } else if let Ok(a) = ob.extract::>() { @@ -106,7 +114,7 @@ pub enum Vector2<'py> { I64(PyReadonlyArray2<'py, i64>), } -impl<'a> Vector2<'a> { +impl Vector2<'_> { pub fn cast(&self) -> Array2 { match self { Vector2::F32(a) => a.as_array().mapv(T::from), @@ -124,7 +132,7 @@ impl<'a> Vector2<'a> { } impl<'a> FromPyObject<'a> for Vector2<'a> { - fn extract(ob: &'a PyAny) -> PyResult { + fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult { if let Ok(a) = ob.extract::>() { Ok(Vector2::F32(a)) } else if let Ok(a) = ob.extract::>() { @@ -192,7 +200,7 @@ where .collect::>() } -pub fn _pdist(py: Python<'_>, a: ndarray::ArrayView2, metric: F) -> Py> +pub fn _pdist<'py, T, U, F>(py: Python<'py>, a: ndarray::ArrayView2, metric: F) -> Bound<'py, PyArray1> where T: Number + numpy::Element, U: Number + numpy::Element, @@ -209,5 +217,5 @@ where .map(move |col| metric(row.as_slice().unwrap(), col.as_slice().unwrap())) }) .collect::>(); - PyArray1::from_vec(py, result).to_owned() + PyArray1::from_vec(py, result) } diff --git a/pypi/distances/src/vectors.rs b/pypi/distances/src/vectors.rs index 36777fe19..5f78e7f57 100644 --- a/pypi/distances/src/vectors.rs +++ b/pypi/distances/src/vectors.rs @@ -8,20 +8,19 @@ use crate::utils::Scalar; use super::utils::{parse_metric, Vector1, Vector2, _cdist, _chebyshev, _manhattan, _pdist}; -pub fn register(py: Python<'_>, pm: &PyModule) -> PyResult<()> { - let m = PyModule::new(py, "vectors")?; - m.add_function(wrap_pyfunction!(braycurtis, m)?)?; - m.add_function(wrap_pyfunction!(canberra, m)?)?; - m.add_function(wrap_pyfunction!(chebyshev, m)?)?; - m.add_function(wrap_pyfunction!(euclidean, m)?)?; - m.add_function(wrap_pyfunction!(sqeuclidean, m)?)?; - 
m.add_function(wrap_pyfunction!(manhattan, m)?)?; - m.add_function(wrap_pyfunction!(minkowski, m)?)?; - m.add_function(wrap_pyfunction!(cosine, m)?)?; - m.add_function(wrap_pyfunction!(cdist, m)?)?; - m.add_function(wrap_pyfunction!(pdist, m)?)?; - pm.add_submodule(m)?; - Ok(()) +pub fn register(pm: &Bound<'_, PyModule>) -> PyResult<()> { + let m = PyModule::new(pm.py(), "vectors")?; + m.add_function(wrap_pyfunction!(braycurtis, &m)?)?; + m.add_function(wrap_pyfunction!(canberra, &m)?)?; + m.add_function(wrap_pyfunction!(chebyshev, &m)?)?; + m.add_function(wrap_pyfunction!(euclidean, &m)?)?; + m.add_function(wrap_pyfunction!(sqeuclidean, &m)?)?; + m.add_function(wrap_pyfunction!(manhattan, &m)?)?; + m.add_function(wrap_pyfunction!(minkowski, &m)?)?; + m.add_function(wrap_pyfunction!(cosine, &m)?)?; + m.add_function(wrap_pyfunction!(cdist, &m)?)?; + m.add_function(wrap_pyfunction!(pdist, &m)?)?; + pm.add_submodule(&m) } macro_rules! build_fn { @@ -216,7 +215,14 @@ fn minkowski(a: Vector1, b: Vector1, p: i32) -> PyResult { /// Compute the pairwise distances between two collections of vectors. #[pyfunction] -fn cdist(py: Python<'_>, a: Vector2, b: Vector2, metric: &str, p: Option) -> PyResult>> { +#[pyo3(signature = (a, b, metric, p=None))] +fn cdist<'py>( + py: Python<'py>, + a: Vector2, + b: Vector2, + metric: &str, + p: Option, +) -> PyResult>> { match p { Some(p) => { if metric.to_lowercase() != "minkowski" { @@ -392,7 +398,8 @@ fn cdist(py: Python<'_>, a: Vector2, b: Vector2, metric: &str, p: Option) - /// Compute the pairwise distances between all vectors in a collection. #[pyfunction] -fn pdist(py: Python<'_>, a: Vector2, metric: &str, p: Option) -> PyResult>> { +#[pyo3(signature = (a, metric, p=None))] +fn pdist<'py>(py: Python<'py>, a: Vector2, metric: &str, p: Option) -> PyResult>> { match p { Some(p) => { if metric.to_lowercase() != "minkowski" { diff --git a/pypi/results/cakes/.github/workflows/CI.yml b/pypi/results/cakes/.github/workflows/CI.yml deleted file mode 100644 index 1bae4be43..000000000 --- a/pypi/results/cakes/.github/workflows/CI.yml +++ /dev/null @@ -1,120 +0,0 @@ -# This file is autogenerated by maturin v1.4.0 -# To update, run -# -# maturin generate-ci github -# -name: CI - -on: - push: - branches: - - main - - master - tags: - - '*' - pull_request: - workflow_dispatch: - -permissions: - contents: read - -jobs: - linux: - runs-on: ubuntu-latest - strategy: - matrix: - target: [x86_64, x86, aarch64, armv7, s390x, ppc64le] - steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - with: - python-version: '3.10' - - name: Build wheels - uses: PyO3/maturin-action@v1 - with: - target: ${{ matrix.target }} - args: --release --out dist --find-interpreter - sccache: 'true' - manylinux: auto - - name: Upload wheels - uses: actions/upload-artifact@v3 - with: - name: wheels - path: dist - - windows: - runs-on: windows-latest - strategy: - matrix: - target: [x64, x86] - steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - with: - python-version: '3.10' - architecture: ${{ matrix.target }} - - name: Build wheels - uses: PyO3/maturin-action@v1 - with: - target: ${{ matrix.target }} - args: --release --out dist --find-interpreter - sccache: 'true' - - name: Upload wheels - uses: actions/upload-artifact@v3 - with: - name: wheels - path: dist - - macos: - runs-on: macos-latest - strategy: - matrix: - target: [x86_64, aarch64] - steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - with: - python-version: '3.10' - - name: Build 
wheels - uses: PyO3/maturin-action@v1 - with: - target: ${{ matrix.target }} - args: --release --out dist --find-interpreter - sccache: 'true' - - name: Upload wheels - uses: actions/upload-artifact@v3 - with: - name: wheels - path: dist - - sdist: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - name: Build sdist - uses: PyO3/maturin-action@v1 - with: - command: sdist - args: --out dist - - name: Upload sdist - uses: actions/upload-artifact@v3 - with: - name: wheels - path: dist - - release: - name: Release - runs-on: ubuntu-latest - if: "startsWith(github.ref, 'refs/tags/')" - needs: [linux, windows, macos, sdist] - steps: - - uses: actions/download-artifact@v3 - with: - name: wheels - - name: Publish to PyPI - uses: PyO3/maturin-action@v1 - env: - MATURIN_PYPI_TOKEN: ${{ secrets.PYPI_API_TOKEN }} - with: - command: upload - args: --non-interactive --skip-existing * diff --git a/pypi/results/cakes/.gitignore b/pypi/results/cakes/.gitignore deleted file mode 100644 index c8f044299..000000000 --- a/pypi/results/cakes/.gitignore +++ /dev/null @@ -1,72 +0,0 @@ -/target - -# Byte-compiled / optimized / DLL files -__pycache__/ -.pytest_cache/ -*.py[cod] - -# C extensions -*.so - -# Distribution / packaging -.Python -.venv/ -env/ -bin/ -build/ -develop-eggs/ -dist/ -eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -include/ -man/ -venv/ -*.egg-info/ -.installed.cfg -*.egg - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt -pip-selfcheck.json - -# Unit test / coverage reports -htmlcov/ -.tox/ -.coverage -.cache -nosetests.xml -coverage.xml - -# Translations -*.mo - -# Mr Developer -.mr.developer.cfg -.project -.pydevproject - -# Rope -.ropeproject - -# Django stuff: -*.log -*.pot - -.DS_Store - -# Sphinx documentation -docs/_build/ - -# PyCharm -.idea/ - -# VSCode -.vscode/ - -# Pyenv -.python-version diff --git a/pypi/results/cakes/Cargo.toml b/pypi/results/cakes/Cargo.toml deleted file mode 100644 index 4497198a4..000000000 --- a/pypi/results/cakes/Cargo.toml +++ /dev/null @@ -1,12 +0,0 @@ -[package] -name = "py-cakes" -version = "0.1.0" -edition = "2021" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html -[lib] -name = "py_cakes" -crate-type = ["cdylib"] - -[dependencies] -pyo3 = { workspace = true } diff --git a/pypi/results/cakes/pyproject.toml b/pypi/results/cakes/pyproject.toml deleted file mode 100644 index 4d766c3a2..000000000 --- a/pypi/results/cakes/pyproject.toml +++ /dev/null @@ -1,32 +0,0 @@ -[build-system] -requires = ["maturin>=1.4,<2.0"] -build-backend = "maturin" - -[project] -name = "py-cakes" -version = "0.1.0" -description = "A Python library for benchmarks with CAKES." -requires-python = ">=3.9" -classifiers = [ - "Programming Language :: Rust", - "Programming Language :: Python :: Implementation :: CPython", - "Programming Language :: Python :: Implementation :: PyPy", -] -dependencies = [ - "numpy<=2.0.0", - "typer>=0.12.4,<0.13", - "pandas>=2.2.2,<3.0", -] - -[tool.rye] -dev-dependencies = [ - "seaborn>=0.13.2", - "editdistance>=0.8.1", - "scipy>=1.13.1", - "tqdm>=4.66.4", -] - -[tool.maturin] -python-source = "python" -features = ["pyo3/extension-module"] -profile = "release" diff --git a/pypi/results/cakes/python/py_cakes/__init__.py b/pypi/results/cakes/python/py_cakes/__init__.py deleted file mode 100644 index e639e685d..000000000 --- a/pypi/results/cakes/python/py_cakes/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Wrangling the results for CAKES and PanCAKES.""" - -from . 
import tables -from . import wrangling_logs - -__all__ = ["tables", "wrangling_logs"] diff --git a/pypi/results/cakes/python/py_cakes/__main__.py b/pypi/results/cakes/python/py_cakes/__main__.py deleted file mode 100644 index 4b9c4e9e9..000000000 --- a/pypi/results/cakes/python/py_cakes/__main__.py +++ /dev/null @@ -1,52 +0,0 @@ -"""CLI for the package.""" - -import logging -import pathlib - -import typer - -from py_cakes import tables -from py_cakes.wrangling_logs import wrangle_logs - -logger = logging.getLogger("py_cakes") -logger.setLevel(logging.INFO) -logging.basicConfig( - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", -) - - -app = typer.Typer() - - -@app.command() -def main( - pre_trim_path: pathlib.Path = typer.Option( # noqa: B008 - ..., - help="Path to the file to analyze.", - exists=True, - file_okay=True, - dir_okay=False, - readable=True, - resolve_path=True, - ), - post_trim_path: pathlib.Path = typer.Option( # noqa: B008 - ..., - help="Path to the file to analyze.", - exists=True, - file_okay=True, - dir_okay=False, - readable=True, - resolve_path=True, - ), -) -> None: - """Main entry point.""" - if "logs" in str(pre_trim_path): - wrangle_logs(pre_trim_path) - - if "csv" in str(pre_trim_path) and "csv" in str(post_trim_path): - tables.draw_plots(pre_trim_path, post_trim_path) - - -if __name__ == "__main__": - app() diff --git a/pypi/results/cakes/python/py_cakes/tables.py b/pypi/results/cakes/python/py_cakes/tables.py deleted file mode 100644 index b2cf1bb88..000000000 --- a/pypi/results/cakes/python/py_cakes/tables.py +++ /dev/null @@ -1,96 +0,0 @@ -"""Reading the tables from Rust and making plots.""" - -import pathlib - -import matplotlib.pyplot as plt -import numpy -import pandas - -BALL_TYPES = { - "depth": numpy.uint32, - "cardinality": numpy.uint64, - "radius": numpy.uint32, - "lfd": numpy.float32, - "arg_center": numpy.uint64, - "arg_radial": numpy.uint64, -} - -SQUISHY_BALL_TYPES = { - **BALL_TYPES, - "offset": numpy.uint64, - "unitary_cost": numpy.uint32, - "recursive_cost": numpy.uint32, -} - - -def draw_plots( - pre_trim_path: pathlib.Path, - post_trim_path: pathlib.Path, -) -> None: - """Read the ball table from the csv file.""" - # Read the pre-trim and post-trim dataframes from the csv files. - pre_trim_df = pandas.read_csv(pre_trim_path, dtype=SQUISHY_BALL_TYPES) - post_trim_df = pandas.read_csv(post_trim_path, dtype=SQUISHY_BALL_TYPES) - - # Drop all rows where "recursive_cost" is not positive. - pre_trim_df = pre_trim_df[pre_trim_df["recursive_cost"] > 0] - post_trim_df = post_trim_df[post_trim_df["recursive_cost"] > 0] - - # Drop all rows where "cardinality" is too small. - min_cardinality = 1 - pre_trim_df = pre_trim_df[pre_trim_df["cardinality"] >= min_cardinality] - post_trim_df = post_trim_df[post_trim_df["cardinality"] >= min_cardinality] - - # Create a new column called "ratio" that is the ratio of "recursive_cost" to "unitary_cost" - pre_trim_df["ratio"] = pre_trim_df["recursive_cost"] / pre_trim_df["unitary_cost"] - post_trim_df["ratio"] = post_trim_df["recursive_cost"] / post_trim_df["unitary_cost"] - - # Calculate the maximum values of some columns. 
- max_ratio = numpy.ceil(max(pre_trim_df["ratio"].max(), post_trim_df["ratio"].max())) # noqa: F841 - max_lfd = max(pre_trim_df["lfd"].max(), post_trim_df["lfd"].max()) - max_depth = max(pre_trim_df["depth"].max(), post_trim_df["depth"].max()) - max_radius = max(pre_trim_df["radius"].max(), post_trim_df["radius"].max()) # noqa: F841 - - dfs = { - "pre_trim": pre_trim_df, - "post_trim": post_trim_df, - } - - cmap = "cool" - for name, df in dfs.items(): - # Make a scatter plot with "depth" on the x-axis, "ratio" on the y-axis, - # and "lfd" as the color. - df["color"], mean, std = normalized_color_scale(df["lfd"]) - ax = df.plot.scatter( - x="depth", - y="lfd", - s=0.2, - c="ratio", - cmap=cmap, - vmin=0, - vmax=numpy.ceil(max_lfd), - ) - # Set the minimum and maximum values of the y-axis. - ax.set_xlim(0, max_depth) - ax.set_ylim(1, numpy.ceil(max_lfd)) - - # Set the title of the plot to be the name of the dataframe. - title = f"Recursive / Unitary ratio for {pre_trim_path.stem}" - ax.set_title(title) - - # Save the plot to a file with the name of the dataframe. - plot_path = pre_trim_path.parent / f"{pre_trim_path.stem}-{name}-ratio.png" - plt.savefig(plot_path, dpi=300) - - # Close the plots. - plt.close("all") - - -def normalized_color_scale(values: numpy.ndarray) -> tuple[numpy.ndarray, float, float]: - """Apply Gaussian normalization to the values and return the result.""" - # Calculate the mean and standard deviation of the values. - mean = values.mean() - std = values.std() - - # Apply Gaussian normalization to the values. - return (values - mean) / std, mean, std diff --git a/pypi/results/cakes/python/py_cakes/wrangling_logs.py b/pypi/results/cakes/python/py_cakes/wrangling_logs.py deleted file mode 100644 index cd5747dc7..000000000 --- a/pypi/results/cakes/python/py_cakes/wrangling_logs.py +++ /dev/null @@ -1,72 +0,0 @@ -"""Exploring the logs of long Clustering runs.""" - -import logging -import pathlib -import re - -logger = logging.getLogger(__name__) - - -def count_clusters(file_path: pathlib.Path) -> list[tuple[bool, int, int]]: - """Count the number of lines in a file that contain information about clusters.""" - pattern = re.compile( - r"^(?PStarting|Finished) `par_partition` of a cluster at depth (?P\d+), with (?P\d+) instances\.$", # noqa: E501 - ) - - cluster_counts = [] - - with file_path.open("r") as file: - for line in file: - if match := pattern.match(line.strip()): - status = match.group("status") == "Finished" - depth = int(match.group("depth")) - cardinality = int(match.group("cardinality")) - cluster_counts.append((status, depth, cardinality)) - - msg = f"Found {len(cluster_counts)} clusters in {file_path}." - logger.info(msg) - - return cluster_counts - - -def clusters_by_depth( - clusters: list[tuple[bool, int, int]], -) -> list[tuple[int, tuple[tuple[int, int], tuple[int, int]]]]: - """Count the number of clusters by depth.""" - depth_counts: dict[int, tuple[tuple[int, int], tuple[int, int]]] = {} - - for status, depth, cardinality in clusters: - (s_freq, s_count), (f_freq, f_count) = depth_counts.get(depth, ((0, 0), (0, 0))) - if status: - f_freq += 1 - f_count += cardinality - else: - s_freq += 1 - s_count += cardinality - depth_counts[depth] = (s_freq, s_count), (f_freq, f_count) - - return sorted(depth_counts.items()) - - -def wrangle_logs(log_path: pathlib.Path) -> None: - """Wrangle the logs of long Clustering runs.""" - msg = f"Analyzing {log_path}..." 
- logger.info(msg) - - clusters = count_clusters(log_path) - progress = clusters_by_depth(clusters) - - gg_car = 989_002 - for depth, ((s_freq, s_card), (f_freq, f_card)) in progress: - if depth % 256 < 50: - lines = [ - "", - f"Depth {depth:4d}:", - f"Clusters: Started {s_freq:7d}, finished {f_freq:7d}. {100 * f_freq / s_freq:3.2f}%).", # noqa: E501 - f"Instances: Started {s_card:7d}, finished {f_card:7d}. {100 * f_card / s_card:3.2f}% of started, {100 * f_card / gg_car:3.2f}% of total.", # noqa: E501 - ] - msg = "\n".join(lines) - logger.info(msg) - - msg = f"Built (or building) tree with {len(progress)} depth." - logger.info(msg) diff --git a/pypi/results/cakes/python/tests/__init__.py b/pypi/results/cakes/python/tests/__init__.py deleted file mode 100644 index 005c965f7..000000000 --- a/pypi/results/cakes/python/tests/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Tests for the package.""" diff --git a/pypi/results/cakes/python/tests/test_all.py b/pypi/results/cakes/python/tests/test_all.py deleted file mode 100644 index a10761bfa..000000000 --- a/pypi/results/cakes/python/tests/test_all.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Tests for the package.""" - - -def test_hello(): - pass diff --git a/pypi/results/cakes/src/lib.rs b/pypi/results/cakes/src/lib.rs deleted file mode 100644 index 53c177853..000000000 --- a/pypi/results/cakes/src/lib.rs +++ /dev/null @@ -1,14 +0,0 @@ -use pyo3::prelude::*; - -/// Formats the sum of two numbers as string. -#[pyfunction] -fn sum_as_string(a: usize, b: usize) -> PyResult { - Ok((a + b).to_string()) -} - -/// A Python module implemented in Rust. -#[pymodule] -fn py_cakes(_py: Python, m: &PyModule) -> PyResult<()> { - m.add_function(wrap_pyfunction!(sum_as_string, m)?)?; - Ok(()) -} diff --git a/pyproject.toml b/pyproject.toml index 0f66bc343..878710022 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ authors = [ { name = "Tom", email = "info@tomhoward.codes" }, ] dependencies = [ - "numpy<=2.0.0", + "numpy", ] readme = "README.md" requires-python = ">= 3.9" @@ -28,5 +28,5 @@ dev-dependencies = [ [tool.rye.workspace] members = [ "pypi/distances", - "pypi/results/cakes", + "benches/py-cakes", ] diff --git a/requirements-dev.lock b/requirements-dev.lock index 91b2f3337..db73cc1b3 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -4,51 +4,52 @@ # last locked with the following flags: # pre: false # features: [] -# all-features: true +# all-features: false # with-sources: false -e file:pypi/distances --e file:pypi/results/cakes +-e file:benches/py-cakes bump2version==1.0.1 cfgv==3.4.0 # via pre-commit -click==8.1.7 +click==8.1.8 # via typer -contourpy==1.2.0 +contourpy==1.3.0 # via matplotlib -coverage==7.4.3 +coverage==7.6.8 # via coverage # via pytest-cov cycler==0.12.1 # via matplotlib -distlib==0.3.8 +distlib==0.3.9 # via virtualenv editdistance==0.8.1 -exceptiongroup==1.2.0 +exceptiongroup==1.2.2 # via pytest execnet==2.1.1 # via pytest-xdist -filelock==3.15.4 +filelock==3.16.1 # via virtualenv -fonttools==4.49.0 +fonttools==4.55.0 # via matplotlib -identify==2.6.0 +identify==2.6.3 # via pre-commit -importlib-resources==6.1.3 +importlib-resources==6.4.5 # via matplotlib iniconfig==2.0.0 # via pytest -kiwisolver==1.4.5 +kiwisolver==1.4.7 # via matplotlib markdown-it-py==3.0.0 # via rich -matplotlib==3.8.3 +matplotlib==3.9.4 + # via py-cakes # via seaborn mdurl==0.1.2 # via markdown-it-py nodeenv==1.9.1 # via pre-commit -numpy==1.26.4 +numpy==2.0.2 # via abd-distances # via contourpy # via matplotlib @@ -56,41 +57,41 @@ 
numpy==1.26.4 # via py-cakes # via scipy # via seaborn -packaging==23.2 +packaging==24.2 # via matplotlib # via pytest # via pytest-sugar -pandas==2.2.2 +pandas==2.2.3 # via py-cakes # via seaborn -pillow==10.2.0 +pillow==11.0.0 # via matplotlib -platformdirs==4.2.2 +platformdirs==4.3.6 # via virtualenv pluggy==1.5.0 # via pytest -pre-commit==3.7.1 -pygments==2.17.2 +pre-commit==4.0.1 +pygments==2.18.0 # via rich -pyinstrument==4.6.2 +pyinstrument==5.0.0 # via richbench -pyparsing==3.1.2 +pyparsing==3.2.0 # via matplotlib -pytest==8.2.2 +pytest==8.3.3 # via pytest-cov # via pytest-sugar # via pytest-xdist -pytest-cov==5.0.0 +pytest-cov==6.0.0 pytest-sugar==1.0.0 pytest-xdist==3.6.1 python-dateutil==2.9.0.post0 # via matplotlib # via pandas -pytz==2024.1 +pytz==2024.2 # via pandas -pyyaml==6.0.1 +pyyaml==6.0.2 # via pre-commit -rich==13.7.1 +rich==13.9.4 # via richbench # via typer richbench==1.0.3 @@ -100,19 +101,20 @@ shellingham==1.5.4 # via typer six==1.16.0 # via python-dateutil -termcolor==2.4.0 +termcolor==2.5.0 # via pytest-sugar -tomli==2.0.1 +tomli==2.2.1 # via coverage # via pytest -tqdm==4.66.4 -typer==0.12.4 +tqdm==4.67.1 +typer==0.15.1 # via py-cakes typing-extensions==4.12.2 + # via rich # via typer -tzdata==2024.1 +tzdata==2024.2 # via pandas -virtualenv==20.26.3 +virtualenv==20.28.0 # via pre-commit -zipp==3.17.0 +zipp==3.21.0 # via importlib-resources diff --git a/requirements.lock b/requirements.lock index 166a0c137..edf2b9e8d 100644 --- a/requirements.lock +++ b/requirements.lock @@ -4,38 +4,62 @@ # last locked with the following flags: # pre: false # features: [] -# all-features: true +# all-features: false # with-sources: false -e file:pypi/distances --e file:pypi/results/cakes -click==8.1.7 +-e file:benches/py-cakes +click==8.1.8 # via typer +contourpy==1.3.0 + # via matplotlib +cycler==0.12.1 + # via matplotlib +fonttools==4.55.3 + # via matplotlib +importlib-resources==6.4.5 + # via matplotlib +kiwisolver==1.4.7 + # via matplotlib markdown-it-py==3.0.0 # via rich +matplotlib==3.9.4 + # via py-cakes mdurl==0.1.2 # via markdown-it-py -numpy==2.0.0 +numpy==2.0.2 # via abd-distances + # via contourpy + # via matplotlib # via pandas # via py-cakes -pandas==2.2.2 +packaging==24.2 + # via matplotlib +pandas==2.2.3 # via py-cakes +pillow==11.0.0 + # via matplotlib pygments==2.18.0 # via rich +pyparsing==3.2.0 + # via matplotlib python-dateutil==2.9.0.post0 + # via matplotlib # via pandas -pytz==2024.1 +pytz==2024.2 # via pandas -rich==13.7.1 +rich==13.9.4 # via typer shellingham==1.5.4 # via typer -six==1.16.0 +six==1.17.0 # via python-dateutil -typer==0.12.4 +typer==0.15.1 # via py-cakes typing-extensions==4.12.2 + # via rich # via typer -tzdata==2024.1 +tzdata==2024.2 # via pandas +zipp==3.21.0 + # via importlib-resources diff --git a/ruff.toml b/ruff.toml index e81299046..b18a04bff 100644 --- a/ruff.toml +++ b/ruff.toml @@ -4,11 +4,11 @@ line-length = 100 [lint] select = [ "ALL" ] ignore = [ - "ANN101", # Missing type annotation for self in method - "ANN102", # Missing type annotation for cls in classmethod "PLR2004", # Use of magic value in comparison "ANN401", # Dynamically typed expressions are disallowed "ICN001", # {name} should be imported as {asname} # I like my numpy imports + "COM812", + "ISC001", ] unfixable = ["B"] # Avoid trying to fix flake8-bugbear violations.