From 2e759c961345e2d5480cd323cc20af607896fb0d Mon Sep 17 00:00:00 2001
From: Najib Ishaq
Date: Sun, 12 Nov 2023 21:41:06 -0500
Subject: [PATCH] feat: use IUPAC codes and elaborate penalties for the Silva
 dataset

---
 abd-clam/src/cakes/knn/mod.rs            |   4 +-
 cakes-results/Cargo.toml                 |   1 +
 cakes-results/src/genomic/main.rs        | 246 +++++++++++++---------
 cakes-results/src/genomic/metrics.rs     |  76 +++++++
 cakes-results/src/genomic/nucleotides.rs | 227 +++++++++++++++++++++
 cakes-results/src/genomic/read_silva.rs  |  23 ++-
 cakes-results/src/genomic/sequence.rs    |  72 +++++++
 7 files changed, 536 insertions(+), 113 deletions(-)
 create mode 100644 cakes-results/src/genomic/metrics.rs
 create mode 100644 cakes-results/src/genomic/nucleotides.rs
 create mode 100644 cakes-results/src/genomic/sequence.rs

diff --git a/abd-clam/src/cakes/knn/mod.rs b/abd-clam/src/cakes/knn/mod.rs
index f4b197ec0..4f46185cb 100644
--- a/abd-clam/src/cakes/knn/mod.rs
+++ b/abd-clam/src/cakes/knn/mod.rs
@@ -162,10 +162,10 @@ impl Algorithm {
         }
     }
 
-    /// Returns a list of all the algorithms, excluding Linear.
+    /// Returns a list of all the algorithms, excluding `Linear` and `RepeatedRnn`.
     #[must_use]
     pub const fn variants<'a>() -> &'a [Self] {
-        &[Self::RepeatedRnn, Self::GreedySieve, Self::Sieve, Self::SieveSepCenter]
+        &[Self::GreedySieve, Self::Sieve, Self::SieveSepCenter]
     }
 }
 
diff --git a/cakes-results/Cargo.toml b/cakes-results/Cargo.toml
index ea3e57dd2..3a1f6df17 100644
--- a/cakes-results/Cargo.toml
+++ b/cakes-results/Cargo.toml
@@ -22,6 +22,7 @@ env_logger = "0.10.0"
 num-format = "0.4.4"
 csv = "1.2.2"
 rand = "0.8.5"
+rayon = "1.8.0"
 
 [[bin]]
 name = "knn-results"
diff --git a/cakes-results/src/genomic/main.rs b/cakes-results/src/genomic/main.rs
index 39b4d5731..fba2096b6 100644
--- a/cakes-results/src/genomic/main.rs
+++ b/cakes-results/src/genomic/main.rs
@@ -15,10 +15,14 @@
     clippy::cast_lossless
 )]
 #![allow(unused_imports)]
+#![allow(clippy::cast_possible_truncation)]
 
 //! Cakes benchmarks on genomic datasets.
 
+mod metrics;
+mod nucleotides;
 mod read_silva;
+mod sequence;
 
 use core::cmp::Ordering;
 use std::{
@@ -26,12 +30,15 @@ use std::{
     time::Instant,
 };
 
-use abd_clam::{Cakes, Dataset, Instance, VecDataset};
+use abd_clam::{knn, rnn, Cakes, Dataset, Instance, VecDataset};
 use clap::Parser;
 use distances::Number;
 use log::{debug, info, warn};
+use sequence::Sequence;
 use serde::{Deserialize, Serialize};
 
+use crate::nucleotides::Nucleotide;
+
 fn main() -> Result<(), String> {
     env_logger::Builder::from_default_env()
         .filter_level(log::LevelFilter::Info)
@@ -68,7 +75,7 @@ fn main() -> Result<(), String> {
         return Err(format!("Output directory {output_dir:?} does not exist."));
     }
 
-    let [train, queries, train_headers, query_headers] =
+    let [(train, train_headers), (queries, query_headers)] =
        read_silva::silva_to_dataset(&unaligned_path, &headers_path, metric, is_expensive)?;
     info!(
         "Read {} data set. Cardinality: {}",
@@ -163,104 +170,125 @@
 /// # Errors
 ///
 /// * If the `output_dir` does not exist or cannot be written to.
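+///
+/// Throughput is reported as queries-per-second; a sketch of the arithmetic
+/// used throughout (names here are illustrative):
+///
+/// ```ignore
+/// let throughput = queries.len().as_f32() / start.elapsed().as_secs_f32();
+/// ```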
-#[allow(clippy::too_many_arguments)]
+#[allow(clippy::too_many_arguments, clippy::too_many_lines)]
 fn measure_throughput(
-    cakes: &Cakes<String, u32, VecDataset<String, u32>>,
+    cakes: &Cakes<Sequence, u32, VecDataset<Sequence, u32>>,
     built: bool,
     build_time: f32,
     args: &Args,
-    queries: &[&String],
+    queries: &[&Sequence],
     query_headers: &[&String],
     train_headers: &VecDataset<String, u32>,
     stem: &str,
     metric_name: &str,
     output_dir: &Path,
 ) -> Result<(), String> {
-    let tuned_algorithm = cakes.tuned_knn_algorithm();
-    let tuned_algorithm = tuned_algorithm.name();
-
     let train = cakes.shards()[0];
+    let num_linear_queries = 10;
+
+    // Run the linear knn search algorithm on a sample of the queries.
+    let k = args
+        .ks
+        .iter()
+        .max()
+        .copied()
+        .unwrap_or_else(|| unreachable!("No ks!"));
+    let start = Instant::now();
+    let linear_results = cakes.batch_linear_knn_search(&queries[..num_linear_queries], k);
+    let linear_search_time = start.elapsed().as_secs_f32();
+    let linear_throughput = num_linear_queries.as_f32() / linear_search_time;
+    info!("With k = {k}, achieved linear search throughput of {linear_throughput:.2e} QPS.",);
+
     // Perform knn-search for each value of k on all queries.
     for &k in &args.ks {
         info!("Starting knn-search with k = {k} ...");
 
-        // Run the tuned algorithm.
-        let start = Instant::now();
-        let results = cakes.batch_tuned_knn_search(queries, k);
-        let search_time = start.elapsed().as_secs_f32();
-        let throughput = queries.len().as_f32() / search_time;
-        info!("With k = {k}, achieved throughput of {throughput:.2e} QPS.");
-
-        // Run the linear search algorithm.
-        let start = Instant::now();
-        let linear_results = cakes.batch_linear_knn_search(queries, k);
-        let linear_search_time = start.elapsed().as_secs_f32();
-        let linear_throughput = queries.len().as_f32() / linear_search_time;
-        info!("With k = {k}, achieved linear search throughput of {linear_throughput:.2e} QPS.",);
-
-        let (mean_recall, hits) = recall_and_hits(
-            results,
-            linear_results,
-            queries,
-            query_headers,
-            train_headers,
-            train,
-        );
-        info!("With k = {k}, achieved mean recall of {mean_recall:.3}.");
-
-        // Create the report.
-        let report = Report {
-            dataset: stem,
-            metric: metric_name,
-            cardinality: train.cardinality(),
-            built,
-            build_time,
-            shard_sizes: cakes.shard_cardinalities(),
-            num_queries: queries.len(),
-            kind: "knn",
-            val: k,
-            tuned_algorithm,
-            throughput,
-            linear_throughput,
-            hits,
-            mean_recall,
-        };
-
-        // Save the report.
-        report.save(output_dir)?;
+        // Run each algorithm.
+        for &algorithm in knn::Algorithm::variants() {
+            let start = Instant::now();
+            let results = cakes.batch_knn_search(queries, k, algorithm);
+            let search_time = start.elapsed().as_secs_f32();
+            let throughput = queries.len().as_f32() / search_time;
+            info!(
+                "With k = {k}, {} achieved throughput of {throughput:.2e} QPS.",
+                algorithm.name()
+            );
+
+            let (mean_recall, _) = recall_and_hits(
+                results,
+                &linear_results,
+                query_headers,
+                train_headers,
+                train,
+            );
+            debug!("With k = {k}, achieved mean recall of {mean_recall:.3}.");
+
+            // Create the report.
+            let report = Report {
+                dataset: stem,
+                metric: metric_name,
+                cardinality: train.cardinality(),
+                built,
+                build_time,
+                shard_sizes: cakes.shard_cardinalities(),
+                num_queries: queries.len(),
+                kind: "knn",
+                val: k,
+                algorithm: algorithm.name(),
+                throughput,
+                linear_throughput,
+                // hits,
+                // mean_recall,
+            };
+
+            // Save the report.
+            report.save(output_dir)?;
+        }
     }
 
+    // Run the linear rnn search algorithm on a sample of the queries.
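+    // The CLI radii are given in whole-gap units; scale them by the gap penalty
+    // so they are comparable to the weighted edit distances computed in `metrics`.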
+    let radius = args
+        .rs
+        .iter()
+        .max()
+        .copied()
+        .unwrap_or_else(|| unreachable!("No radii!")) as u32
+        * Nucleotide::gap_penalty();
+    let start = Instant::now();
+    let linear_results = cakes.batch_linear_rnn_search(&queries[..num_linear_queries], radius);
+    let linear_search_time = start.elapsed().as_secs_f32();
+    let linear_throughput = num_linear_queries.as_f32() / linear_search_time;
+    info!("Linear rnn search throughput of {linear_throughput:.2e} QPS.",);
+
     // Perform range search for each value of r on all queries.
-    for &r in &args.rs {
+    for r in args
+        .rs
+        .iter()
+        .map(|&r| r as u32 * Nucleotide::gap_penalty())
+    {
         info!("Starting range search with r = {r} ...");
 
-        #[allow(clippy::cast_possible_truncation)]
-        let radius = r as u32;
-
-        // Run the tuned algorithm.
         let start = Instant::now();
-        let results = cakes.batch_tuned_rnn_search(queries, radius);
+        let results = cakes.batch_rnn_search(queries, r, rnn::Algorithm::Clustered);
         let search_time = start.elapsed().as_secs_f32();
         let throughput = queries.len().as_f32() / search_time;
-        info!("With r = {r}, achieved throughput of {throughput:.2e} QPS.");
-
-        // Run the linear search algorithm.
-        let start = Instant::now();
-        let linear_results = cakes.batch_linear_rnn_search(queries, radius);
-        let linear_search_time = start.elapsed().as_secs_f32();
-        let linear_throughput = queries.len().as_f32() / linear_search_time;
-        info!("With r = {r}, achieved linear search throughput of {linear_throughput:.2e} QPS.",);
+        info!(
+            "With r = {r}, {} achieved throughput of {throughput:.2e} QPS.",
+            rnn::Algorithm::Clustered.name()
+        );
 
-        let (mean_recall, hits) = recall_and_hits(
+        let (mean_recall, _) = recall_and_hits(
             results,
-            linear_results,
-            queries,
+            &linear_results,
             query_headers,
             train_headers,
             train,
         );
-        info!("With r = {r}, achieved mean recall of {mean_recall:.3}.");
+        debug!("With r = {r}, achieved mean recall of {mean_recall:.3}.");
 
         // Create the report.
         let report = Report {
@@ -272,12 +300,12 @@ fn measure_throughput(
             shard_sizes: cakes.shard_cardinalities(),
             num_queries: queries.len(),
             kind: "rnn",
-            val: r,
-            tuned_algorithm,
+            val: r as usize,
+            algorithm: rnn::Algorithm::Clustered.name(),
             throughput,
             linear_throughput,
-            hits,
-            mean_recall,
+            // hits,
+            // mean_recall,
         };
 
         // Save the report.
@@ -305,19 +333,18 @@
 #[allow(clippy::type_complexity)]
 fn recall_and_hits(
     results: Vec<Vec<(usize, u32)>>,
-    linear_results: Vec<Vec<(usize, u32)>>,
-    queries: &[&String],
+    linear_results: &[Vec<(usize, u32)>],
     query_headers: &[&String],
     train_headers: &VecDataset<String, u32>,
-    train: &VecDataset<String, u32>,
+    train: &VecDataset<Sequence, u32>,
 ) -> (f32, Vec<(String, Vec<(String, u32)>)>) {
     // Compute the recall of the tuned algorithm.
     let mean_recall = results
         .iter()
         .zip(linear_results)
-        .map(|(hits, linear_hits)| compute_recall(hits.clone(), linear_hits))
+        .map(|(hits, linear_hits)| compute_recall(hits.clone(), linear_hits.clone()))
         .sum::<f32>()
-        / queries.len().as_f32();
+        / linear_results.len().as_f32();
 
     // Convert results to original indices.
     let hits = results.into_iter().map(|hits| {
@@ -408,11 +435,11 @@ impl Metric {
     /// Return the metric function.
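+    ///
+    /// A sketch of how the returned function pointer is used (names here are
+    /// illustrative):
+    ///
+    /// ```ignore
+    /// let distance = Metric::Levenshtein.metric();
+    /// let d: u32 = distance(&x, &y); // `x` and `y` are `Sequence`s
+    /// ```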
     #[allow(clippy::ptr_arg)]
-    fn metric(&self) -> fn(&String, &String) -> u32 {
+    fn metric(&self) -> fn(&Sequence, &Sequence) -> u32 {
         match self {
-            Self::Hamming => hamming,
-            Self::Levenshtein => levenshtein,
-            Self::NeedlemanWunsch => needleman_wunsch,
+            Self::Hamming => metrics::hamming,
+            Self::Levenshtein => metrics::levenshtein,
+            Self::NeedlemanWunsch => metrics::needleman_wunsch,
         }
     }
 
@@ -425,23 +452,23 @@ impl Metric {
     }
 }
 
-/// Compute the Hamming distance between two strings.
-#[allow(clippy::ptr_arg)]
-fn hamming(x: &String, y: &String) -> u32 {
-    distances::strings::hamming(x, y)
-}
+// /// Compute the Hamming distance between two strings.
+// #[allow(clippy::ptr_arg)]
+// fn hamming(x: &String, y: &String) -> u32 {
+//     distances::strings::hamming(x, y)
+// }
 
-/// Compute the Levenshtein distance between two strings.
-#[allow(clippy::ptr_arg)]
-fn levenshtein(x: &String, y: &String) -> u32 {
-    distances::strings::levenshtein(x, y)
-}
+// /// Compute the Levenshtein distance between two strings.
+// #[allow(clippy::ptr_arg)]
+// fn levenshtein(x: &String, y: &String) -> u32 {
+//     distances::strings::levenshtein(x, y)
+// }
 
-/// Compute the Needleman-Wunsch distance between two strings.
-#[allow(clippy::ptr_arg)]
-fn needleman_wunsch(x: &String, y: &String) -> u32 {
-    distances::strings::needleman_wunsch::nw_distance(x, y)
-}
+// /// Compute the Needleman-Wunsch distance between two strings.
+// #[allow(clippy::ptr_arg)]
+// fn needleman_wunsch(x: &String, y: &String) -> u32 {
+//     distances::strings::needleman_wunsch::nw_distance(x, y)
+// }
 
 /// A report of the results of an ANN benchmark.
 #[derive(Debug, Serialize, Deserialize)]
 struct Report<'a> {
@@ -464,16 +491,16 @@ struct Report<'a> {
     kind: &'a str,
     /// Value of k used for knn-search or value of r used for range search.
     val: usize,
-    /// Name of the algorithm used after auto-tuning.
-    tuned_algorithm: &'a str,
+    /// Name of the algorithm used.
+    algorithm: &'a str,
     /// Throughput of the tuned algorithm.
     throughput: f32,
     /// Throughput of linear search.
     linear_throughput: f32,
-    /// Hits for each query.
-    hits: Vec<(String, Vec<(String, u32)>)>,
-    /// Mean recall of the tuned algorithm.
-    mean_recall: f32,
+    // /// Hits for each query.
+    // hits: Vec<(String, Vec<(String, u32)>)>,
+    // /// Mean recall of the tuned algorithm.
+    // mean_recall: f32,
 }
 
 impl Report<'_> {
@@ -496,7 +523,10 @@ impl Report<'_> {
             return Err(format!("{dir:?} is not a directory."));
         }
 
-        let path = dir.join(format!("{}_{}_{}.json", self.dataset, self.kind, self.val));
+        let path = dir.join(format!(
+            "{}_{}_{}_{}.json",
+            self.dataset, self.kind, self.algorithm, self.val
+        ));
         let report = serde_json::to_string_pretty(&self).map_err(|e| e.to_string())?;
         std::fs::write(path, report).map_err(|e| e.to_string())?;
         Ok(())
@@ -518,7 +548,9 @@ pub fn compute_recall<U: Number>(
     mut hits: Vec<(usize, U)>,
     mut linear_hits: Vec<(usize, U)>,
 ) -> f32 {
-    if linear_hits.is_empty() {
+    // TODO: Remove the empty check on `hits` and only keep the empty check on
+    // `linear_hits`?
+    if linear_hits.is_empty() || hits.is_empty() {
         1.0
     } else {
         let (num_hits, num_linear_hits) = (hits.len(), linear_hits.len());
@@ -542,7 +574,9 @@
                 linear_hits.next();
             }
         }
-        let recall = num_common.as_f32() / num_linear_hits.as_f32();
+
+        // TODO: divide by the number of linear hits?
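+        // (Dividing by `num_hits` measures precision against the exact results;
+        // dividing by `num_linear_hits` would measure recall.)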
+        let recall = num_common.as_f32() / num_hits.as_f32();
         debug!("Recall: {recall:.3}, num_common: {num_common}");
 
         recall
diff --git a/cakes-results/src/genomic/metrics.rs b/cakes-results/src/genomic/metrics.rs
new file mode 100644
index 000000000..4ed6ff84d
--- /dev/null
+++ b/cakes-results/src/genomic/metrics.rs
@@ -0,0 +1,76 @@
+//! Genomic metrics for Silva-18S sequences.
+
+use super::{nucleotides::Nucleotide, sequence::Sequence};
+
+/// Returns the Levenshtein distance between the two given sequences.
+///
+/// # Arguments
+///
+/// * `x` - The first sequence.
+/// * `y` - The second sequence.
+pub fn levenshtein(x: &Sequence, y: &Sequence) -> u32 {
+    if x.is_empty() {
+        // aligning against an empty sequence costs one gap per nucleotide
+        (y.len() as u32) * Nucleotide::gap_penalty()
+    } else if y.is_empty() {
+        // aligning against an empty sequence costs one gap per nucleotide
+        (x.len() as u32) * Nucleotide::gap_penalty()
+    } else if x.len() < y.len() {
+        // require that `x` is no shorter than `y`
+        _levenshtein(y, x)
+    } else {
+        _levenshtein(x, y)
+    }
+}
+
+/// Returns the Levenshtein distance, assuming that `x` is no shorter than `y`.
+fn _levenshtein(x: &Sequence, y: &Sequence) -> u32 {
+    let gap_penalty = Nucleotide::gap_penalty();
+
+    // initialize the DP table for y
+    let mut cur = (0..=y.len())
+        .map(|i| (i as u32) * gap_penalty)
+        .collect::<Vec<_>>();
+
+    // calculate edit distance
+    for (i, &c_x) in x.nucleotides().iter().enumerate() {
+        // get first column for this row
+        let mut pre = cur[0];
+        cur[0] = (i as u32 + 1) * gap_penalty;
+
+        // calculate rest of row
+        for (j, &c_y) in y.nucleotides().iter().enumerate() {
+            let tmp = cur[j + 1];
+
+            let del = tmp + gap_penalty;
+            let ins = cur[j] + gap_penalty;
+            let sub = pre + c_x.penalty(c_y);
+            cur[j + 1] = ins.min(del).min(sub);
+
+            pre = tmp;
+        }
+    }
+
+    cur[y.len()]
+}
+
+/// Returns the Hamming distance between the two given sequences.
+///
+/// # Arguments
+///
+/// * `x` - The first sequence.
+/// * `y` - The second sequence.
+#[allow(unused_variables)]
+pub fn hamming(x: &Sequence, y: &Sequence) -> u32 {
+    todo!()
+}
+
+/// Returns the Needleman-Wunsch distance between the two given sequences.
+///
+/// # Arguments
+///
+/// * `x` - The first sequence.
+/// * `y` - The second sequence.
+#[allow(unused_variables)]
+pub fn needleman_wunsch(x: &Sequence, y: &Sequence) -> u32 {
+    todo!()
+}
diff --git a/cakes-results/src/genomic/nucleotides.rs b/cakes-results/src/genomic/nucleotides.rs
new file mode 100644
index 000000000..180440135
--- /dev/null
+++ b/cakes-results/src/genomic/nucleotides.rs
@@ -0,0 +1,227 @@
+//! Nucleotides for Silva-18S sequences.
+
+/// A nucleotide for Silva-18S sequences.
+///
+/// These are as defined in the
+/// [IUPAC nucleotide code](https://www.bioinformatics.org/sms/iupac.html).
+#[derive(Debug, Clone, Copy)]
+pub enum Nucleotide {
+    /// Adenine
+    A,
+    /// Cytosine
+    C,
+    /// Guanine
+    G,
+    /// Thymine (or Uracil)
+    U,
+    /// A or G
+    R,
+    /// C or T
+    Y,
+    /// G or C
+    S,
+    /// A or T
+    W,
+    /// G or T
+    K,
+    /// A or C
+    M,
+    /// C or G or T
+    B,
+    /// A or G or T
+    D,
+    /// A or C or T
+    H,
+    /// A or C or G
+    V,
+    /// Any base
+    N,
+    /// Gap
+    Gap,
+}
+
+impl Nucleotide {
+    /// Returns the nucleotide corresponding to the given character.
+    ///
+    /// # Arguments
+    ///
+    /// * `c` - The character to convert.
+    ///
+    /// # Errors
+    ///
+    /// If the given character is not a valid nucleotide.
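+    ///
+    /// # Examples
+    ///
+    /// A minimal sketch of the expected behavior:
+    ///
+    /// ```ignore
+    /// assert!(matches!(Nucleotide::from_char('A'), Ok(Nucleotide::A)));
+    /// assert!(Nucleotide::from_char('t').is_err()); // lower-case is rejected
+    /// ```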
+    pub fn from_char(c: char) -> Result<Self, String> {
+        match c {
+            'A' => Ok(Self::A),
+            'C' => Ok(Self::C),
+            'G' => Ok(Self::G),
+            'U' | 'T' => Ok(Self::U),
+            'R' => Ok(Self::R),
+            'Y' => Ok(Self::Y),
+            'S' => Ok(Self::S),
+            'W' => Ok(Self::W),
+            'K' => Ok(Self::K),
+            'M' => Ok(Self::M),
+            'B' => Ok(Self::B),
+            'D' => Ok(Self::D),
+            'H' => Ok(Self::H),
+            'V' => Ok(Self::V),
+            'N' => Ok(Self::N),
+            '-' | '.' => Ok(Self::Gap),
+            _ => Err(format!("Invalid nucleotide: {c}")),
+        }
+    }
+
+    /// Returns the character corresponding to the given nucleotide.
+    pub const fn to_char(self) -> char {
+        match self {
+            Self::A => 'A',
+            Self::C => 'C',
+            Self::G => 'G',
+            Self::U => 'U',
+            Self::R => 'R',
+            Self::Y => 'Y',
+            Self::S => 'S',
+            Self::W => 'W',
+            Self::K => 'K',
+            Self::M => 'M',
+            Self::B => 'B',
+            Self::D => 'D',
+            Self::H => 'H',
+            Self::V => 'V',
+            Self::N => 'N',
+            Self::Gap => '-',
+        }
+    }
+
+    /// Returns the byte corresponding to the given nucleotide.
+    pub const fn to_byte(self) -> u8 {
+        self.to_char() as u8
+    }
+
+    /// Returns the nucleotide corresponding to the given byte.
+    pub fn from_byte(b: u8) -> Result<Self, String> {
+        Self::from_char(b as char)
+    }
+
+    /// Returns the gap penalty.
+    pub const fn gap_penalty() -> u32 {
+        180
+    }
+
+    /// Returns the penalty for aligning the given nucleotides.
+    ///
+    /// # Arguments
+    ///
+    /// * `other` - The other nucleotide.
+    #[allow(clippy::too_many_lines)]
+    pub const fn penalty(self, other: Self) -> u32 {
+        match self {
+            Self::A => match other {
+                Self::A => 0,
+                Self::R | Self::W | Self::M => 90,
+                Self::D | Self::H | Self::V => 120,
+                Self::N => 135,
+                _ => Self::gap_penalty(),
+            },
+            Self::C => match other {
+                Self::C => 0,
+                Self::Y | Self::S | Self::M => 90,
+                Self::B | Self::H | Self::V => 120,
+                Self::N => 135,
+                _ => Self::gap_penalty(),
+            },
+            Self::G => match other {
+                Self::G => 0,
+                Self::R | Self::S | Self::K => 90,
+                Self::B | Self::D | Self::V => 120,
+                Self::N => 135,
+                _ => Self::gap_penalty(),
+            },
+            Self::U => match other {
+                Self::U => 0,
+                Self::Y | Self::W | Self::K => 90,
+                Self::B | Self::D | Self::H => 120,
+                Self::N => 135,
+                _ => Self::gap_penalty(),
+            },
+            Self::R => match other {
+                Self::A | Self::G | Self::R => 90,
+                Self::D | Self::V => 120,
+                Self::S | Self::W | Self::K | Self::M | Self::N => 135,
+                Self::B | Self::H => 150,
+                _ => Self::gap_penalty(),
+            },
+            Self::Y => match other {
+                Self::C | Self::U | Self::Y => 90,
+                Self::B | Self::H => 120,
+                Self::S | Self::W | Self::K | Self::M | Self::N => 135,
+                Self::D | Self::V => 150,
+                _ => Self::gap_penalty(),
+            },
+            Self::S => match other {
+                Self::C | Self::G | Self::S => 90,
+                Self::B | Self::V => 120,
+                Self::R | Self::Y | Self::K | Self::M | Self::N => 135,
+                Self::D | Self::H => 150,
+                _ => Self::gap_penalty(),
+            },
+            Self::W => match other {
+                Self::A | Self::U | Self::W => 90,
+                Self::D | Self::H => 120,
+                Self::R | Self::Y | Self::K | Self::M | Self::N => 135,
+                Self::B | Self::V => 150,
+                _ => Self::gap_penalty(),
+            },
+            Self::K => match other {
+                Self::G | Self::U | Self::K => 90,
+                Self::B | Self::D => 120,
+                Self::R | Self::Y | Self::S | Self::W | Self::N => 135,
+                Self::H | Self::V => 150,
+                _ => Self::gap_penalty(),
+            },
+            Self::M => match other {
+                Self::A | Self::C | Self::M => 90,
+                Self::H | Self::V => 120,
+                Self::R | Self::Y | Self::S | Self::W | Self::N => 135,
+                Self::B | Self::D => 150,
+                _ => Self::gap_penalty(),
+            },
+            Self::B => match other {
+                Self::C | Self::G | Self::U | Self::Y | Self::S | Self::K | Self::B => 120,
+                Self::N => 135,
+                Self::D | Self::H | Self::V => 140,
+                Self::R | Self::W | Self::M => 150,
+                _ => Self::gap_penalty(),
+            },
+            Self::D => match other {
+                Self::A | Self::G | Self::U | Self::R | Self::W | Self::K | Self::D => 120,
+                Self::N => 135,
+                Self::B | Self::H | Self::V => 140,
+                Self::Y | Self::S | Self::M => 150,
+                _ => Self::gap_penalty(),
+            },
+            Self::H => match other {
+                Self::A | Self::C | Self::U | Self::Y | Self::W | Self::M | Self::H => 120,
+                Self::N => 135,
+                Self::B | Self::D | Self::V => 140,
+                Self::R | Self::S | Self::K => 150,
+                _ => Self::gap_penalty(),
+            },
+            Self::V => match other {
+                Self::A | Self::C | Self::G | Self::R | Self::S | Self::M | Self::V => 120,
+                Self::N => 135,
+                Self::B | Self::D | Self::H => 140,
+                Self::Y | Self::W | Self::K => 150,
+                _ => Self::gap_penalty(),
+            },
+            Self::N => match other {
+                Self::Gap => Self::gap_penalty(),
+                _ => 135,
+            },
+            Self::Gap => match other {
+                Self::Gap => 0,
+                _ => Self::gap_penalty(),
+            },
+        }
+    }
+}
diff --git a/cakes-results/src/genomic/read_silva.rs b/cakes-results/src/genomic/read_silva.rs
index 704d0eecf..a2a469d24 100644
--- a/cakes-results/src/genomic/read_silva.rs
+++ b/cakes-results/src/genomic/read_silva.rs
@@ -10,6 +10,13 @@ use abd_clam::{Dataset, VecDataset};
 use distances::Number;
 use log::info;
 use rand::prelude::*;
+use rayon::prelude::*;
+
+use crate::sequence::Sequence;
+
+/// A pair of datasets. The first contains the sequences, the second contains
+/// the headers.
+type DataPair = (VecDataset<Sequence, u32>, VecDataset<String, u32>);
 
 /// Read the Silva-18S dataset from the given path.
 ///
@@ -41,9 +48,9 @@ use rand::prelude::*;
 pub fn silva_to_dataset(
     unaligned_path: &Path,
     headers_path: &Path,
-    metric: fn(&String, &String) -> u32,
+    metric: fn(&Sequence, &Sequence) -> u32,
     is_expensive: bool,
-) -> Result<[VecDataset<String, u32>; 4], String> {
+) -> Result<[DataPair; 2], String> {
     // Get the stem of the file name.
     let stem = unaligned_path
         .file_stem()
@@ -100,6 +107,12 @@
         .collect::<Result<Vec<_>, _>>()?;
     info!("Read {} headers from {headers_path:?}.", headers.len());
 
+    // Convert the lines into sequences.
+    let sequences = sequences
+        .par_iter()
+        .map(|s| Sequence::from_str(s))
+        .collect::<Result<Vec<_>, _>>()?;
+
     // join the lines and headers into a single vector of (line, header) pairs.
     let mut sequences = sequences.into_iter().zip(headers).collect::<Vec<_>>();
 
@@ -121,7 +134,7 @@
     let train_headers = VecDataset::new(
         format!("{stem}-train-headers"),
         train_headers,
-        metric,
+        |_, _| 0,
         is_expensive,
     );
     info!(
@@ -134,7 +147,7 @@
     let query_headers = VecDataset::new(
         format!("{stem}-query-headers"),
         query_headers,
-        metric,
+        |_, _| 0,
         is_expensive,
     );
     info!(
@@ -145,5 +158,5 @@
     assert_eq!(train.cardinality(), train_headers.cardinality());
     assert_eq!(queries.cardinality(), query_headers.cardinality());
 
-    Ok([train, queries, train_headers, query_headers])
+    Ok([(train, train_headers), (queries, query_headers)])
 }
diff --git a/cakes-results/src/genomic/sequence.rs b/cakes-results/src/genomic/sequence.rs
new file mode 100644
index 000000000..783b3cf2f
--- /dev/null
+++ b/cakes-results/src/genomic/sequence.rs
@@ -0,0 +1,72 @@
+//! Sequences of nucleotides for the Silva-18S dataset.
+
+use core::fmt::Display;
+
+use abd_clam::Instance;
+
+use super::nucleotides::Nucleotide;
+
+/// A genomic sequence is a list of nucleotides.
+#[derive(Debug, Clone)]
+pub struct Sequence(Vec<Nucleotide>);
+
+impl Sequence {
+    /// Parses the given string into a sequence.
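+    ///
+    /// A minimal usage sketch (the sample string is illustrative; note that `T`
+    /// is read as `U` and `.` as a gap):
+    ///
+    /// ```ignore
+    /// let seq = Sequence::from_str("ACGT.").unwrap();
+    /// assert_eq!(seq.to_string(), "ACGU-");
+    /// ```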
+    pub fn from_str(s: &str) -> Result<Self, String> {
+        let seq = s
+            .chars()
+            .map(Nucleotide::from_char)
+            .collect::<Result<Vec<_>, _>>()?;
+        Ok(Self(seq))
+    }
+
+    /// The number of nucleotides in the sequence.
+    pub fn len(&self) -> usize {
+        self.0.len()
+    }
+
+    /// Whether the sequence is empty.
+    pub fn is_empty(&self) -> bool {
+        self.0.is_empty()
+    }
+
+    /// Returns the underlying nucleotide list.
+    pub fn nucleotides(&self) -> &[Nucleotide] {
+        &self.0
+    }
+}
+
+impl Display for Sequence {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "{}",
+            self.0
+                .iter()
+                .map(|&b| Nucleotide::to_char(b))
+                .collect::<String>()
+        )
+    }
+}
+
+impl Instance for Sequence {
+    fn to_bytes(&self) -> Vec<u8> {
+        self.0.iter().map(|&c| Nucleotide::to_byte(c)).collect()
+    }
+
+    fn from_bytes(bytes: &[u8]) -> Result<Self, String>
+    where
+        Self: Sized,
+    {
+        let seq = bytes
+            .iter()
+            .map(|b| Nucleotide::from_byte(*b))
+            .collect::<Result<Vec<_>, _>>()?;
+        Ok(Self(seq))
+    }
+
+    fn type_name() -> String {
+        "Sequence".to_string()
+    }
+}
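+
+// A minimal round-trip sanity check for the `Instance` impl; the sample
+// sequence below is illustrative.
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn bytes_round_trip() {
+        let seq = Sequence::from_str("ACGUN-").unwrap();
+        let bytes = seq.to_bytes();
+        let back = Sequence::from_bytes(&bytes).unwrap();
+        assert_eq!(seq.to_string(), back.to_string());
+    }
+}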