Skip to content

Commit

Permalink
wip: fixing MSA
Browse files Browse the repository at this point in the history
  • Loading branch information
nishaq503 committed Jan 20, 2025
1 parent 75cf0c0 commit 21146d3
Show file tree
Hide file tree
Showing 7 changed files with 84 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,58 +11,48 @@ use crate::{cakes::PermutedBall, cluster::ParCluster, dataset::ParDataset, Clust

use super::super::Aligner;

/// A multiple sequence alignment (MSA) builder.
pub struct Columnar<T: Number> {
/// The Needleman-Wunsch aligner.
aligner: Aligner<T>,
/// The columns of the partial MSA.
columns: Vec<Vec<u8>>,
}
/// The columns of a partial MSA.
pub struct Columns(Vec<Vec<u8>>, u8);

impl<T: Number> Index<usize> for Columnar<T> {
impl Index<usize> for Columns {
type Output = Vec<u8>;

fn index(&self, index: usize) -> &Self::Output {
&self.columns[index]
&self.0[index]
}
}

impl<T: Number> Columnar<T> {
impl Columns {
/// Create a new MSA builder.
///
/// # Arguments
///
/// * `gap` - The character to use for the gap.
#[must_use]
pub fn new(aligner: &Aligner<T>) -> Self {
Self {
aligner: aligner.clone(),
columns: Vec::new(),
}
}

/// Get the gap character.
#[must_use]
pub const fn gap(&self) -> u8 {
self.aligner.gap()
pub const fn new(gap: u8) -> Self {
Self(Vec::new(), gap)
}

/// Add a binary tree of `Cluster`s to the MSA.
#[must_use]
pub fn with_binary_tree<I, D, C>(self, c: &PermutedBall<T, C>, data: &D) -> Self
pub fn with_binary_tree<I, T, D, C>(self, c: &PermutedBall<T, C>, data: &D, aligner: &Aligner<T>) -> Self
where
I: AsRef<[u8]>,
T: Number,
D: Dataset<I>,
C: Cluster<T>,
{
if c.children().is_empty() {
self.with_cluster(c, data)
self.with_cluster(c, data, aligner)
} else {
if c.children().len() != 2 {
unreachable!("Binary tree has more than two children.");
}
let aligner = self.aligner;
let left = c.children()[0];
let right = c.children()[1];

let l_msa = Self::new(&aligner).with_binary_tree(left, data);
let r_msa = Self::new(&aligner).with_binary_tree(right, data);
let l_msa = Self::new(self.1).with_binary_tree(left, data, aligner);
let r_msa = Self::new(self.1).with_binary_tree(right, data, aligner);

let l_center = left
.iter_indices()
Expand All @@ -73,20 +63,21 @@ impl<T: Number> Columnar<T> {
.position(|i| i == right.arg_center())
.unwrap_or_else(|| unreachable!("Right center not found"));

l_msa.merge(l_center, r_msa, r_center)
l_msa.merge(l_center, r_msa, r_center, aligner)
}
}

/// Add a tree of `Cluster`s to the MSA.
#[must_use]
pub fn with_tree<I, D, C>(self, c: &PermutedBall<T, C>, data: &D) -> Self
pub fn with_tree<I, T, D, C>(self, c: &PermutedBall<T, C>, data: &D, aligner: &Aligner<T>) -> Self
where
I: AsRef<[u8]>,
T: Number,
D: Dataset<I>,
C: Cluster<T>,
{
if c.children().is_empty() {
self.with_cluster(c, data)
self.with_cluster(c, data, aligner)
} else {
let children = c.children();
let (&first, rest) = children.split_first().unwrap_or_else(|| unreachable!("No children"));
Expand All @@ -95,7 +86,7 @@ impl<T: Number> Columnar<T> {
.iter_indices()
.position(|i| i == first.arg_center())
.unwrap_or_else(|| unreachable!("First center not found"));
let first = Self::new(&self.aligner).with_tree(first, data);
let first = Self::new(self.1).with_tree(first, data, aligner);

let (_, merged) = rest
.iter()
Expand All @@ -104,10 +95,10 @@ impl<T: Number> Columnar<T> {
.iter_indices()
.position(|i| i == o.arg_center())
.unwrap_or_else(|| unreachable!("Other center not found"));
(o_center, Self::new(&self.aligner).with_tree(o, data))
(o_center, Self::new(self.1).with_tree(o, data, aligner))
})
.fold((f_center, first), |(a_center, acc), (o_center, o)| {
(a_center, acc.merge(a_center, o, o_center))
(a_center, acc.merge(a_center, o, o_center, aligner))
});

merged
Expand All @@ -121,15 +112,16 @@ impl<T: Number> Columnar<T> {
/// * `sequence` - The sequence to add.
#[must_use]
pub fn with_sequence<I: AsRef<[u8]>>(mut self, sequence: &I) -> Self {
self.columns = sequence.as_ref().iter().map(|&c| vec![c]).collect();
self.0 = sequence.as_ref().iter().map(|&c| vec![c]).collect();
self
}

/// Adds sequences from a `Cluster` to the MSA.
#[must_use]
pub fn with_cluster<I, D, C>(self, c: &C, data: &D) -> Self
pub fn with_cluster<I, T, D, C>(self, c: &C, data: &D, aligner: &Aligner<T>) -> Self
where
I: AsRef<[u8]>,
T: Number,
D: Dataset<I>,
C: Cluster<T>,
{
Expand All @@ -140,41 +132,41 @@ impl<T: Number> Columnar<T> {
);
let indices = c.indices();
let (&first, rest) = indices.split_first().unwrap_or_else(|| unreachable!("No indices"));
let first = Self::new(&self.aligner).with_sequence(data.get(first));
let first = Self::new(self.1).with_sequence(data.get(first));
rest.iter()
.map(|&i| data.get(i))
.map(|s| Self::new(&self.aligner).with_sequence(s))
.fold(first, |acc, s| acc.merge(0, s, 0))
.map(|s| Self::new(self.1).with_sequence(s))
.fold(first, |acc, s| acc.merge(0, s, 0, aligner))
}

/// The number of sequences in the MSA.
pub fn len(&self) -> usize {
self.columns.first().map_or(0, Vec::len)
self.0.first().map_or(0, Vec::len)
}

/// The number of columns in the MSA.
///
/// If the MSA is empty, this will return 0.
#[must_use]
pub fn width(&self) -> usize {
self.columns.len()
self.0.len()
}

/// Whether the MSA is empty.
pub fn is_empty(&self) -> bool {
self.columns.is_empty() || self.columns.iter().all(Vec::is_empty)
self.0.is_empty() || self.0.iter().all(Vec::is_empty)
}

/// Get the columns of the MSA.
#[must_use]
pub fn columns(&self) -> &[Vec<u8>] {
&self.columns
&self.0
}

/// Get the sequence at the given index.
#[must_use]
pub fn get_sequence(&self, index: usize) -> Vec<u8> {
self.columns.iter().map(|col| col[index]).collect()
self.0.iter().map(|col| col[index]).collect()
}

/// Get the sequence at the given index.
Expand All @@ -194,7 +186,7 @@ impl<T: Number> Columnar<T> {

/// Merge two MSAs.
#[must_use]
pub fn merge(mut self, s_center: usize, mut other: Self, o_center: usize) -> Self {
pub fn merge<T: Number>(mut self, s_center: usize, mut other: Self, o_center: usize, aligner: &Aligner<T>) -> Self {
ftlog::trace!(
"Merging MSAs with cardinalities: {} and {}, and centers {s_center} and {o_center}",
self.len(),
Expand All @@ -203,8 +195,8 @@ impl<T: Number> Columnar<T> {
let s_center = self.get_sequence(s_center);
let o_center = other.get_sequence(o_center);

let table = self.aligner.dp_table(&s_center, &o_center);
let [s_to_o, o_to_s] = self.aligner.alignment_gaps(&s_center, &o_center, &table);
let table = aligner.dp_table(&s_center, &o_center);
let [s_to_o, o_to_s] = aligner.alignment_gaps(&s_center, &o_center, &table);

for i in s_to_o {
self.add_gap(i).unwrap_or_else(|e| unreachable!("{e}"));
Expand All @@ -214,18 +206,17 @@ impl<T: Number> Columnar<T> {
other.add_gap(i).unwrap_or_else(|e| unreachable!("{e}"));
}

let aligner = self.aligner;
let columns = self
.columns
.0
.into_iter()
.zip(other.columns)
.zip(other.0)
.map(|(mut x, mut y)| {
x.append(&mut y);
x
})
.collect();

Self { aligner, columns }
Self(columns, self.1)
}

/// Add a gap column to the MSA.
Expand All @@ -239,16 +230,16 @@ impl<T: Number> Columnar<T> {
/// - If the MSA is empty.
/// - If the index is greater than the number of columns.
pub fn add_gap(&mut self, index: usize) -> Result<(), String> {
if self.columns.is_empty() {
if self.0.is_empty() {
Err("MSA is empty.".to_string())
} else if index > self.width() {
Err(format!(
"Index is greater than the width of the MSA: {index} > {}",
self.width()
))
} else {
let gap_col = vec![self.gap(); self.columns[0].len()];
self.columns.insert(index, gap_col);
let gap_col = vec![self.1; self.0[0].len()];
self.0.insert(index, gap_col);
Ok(())
}
}
Expand All @@ -275,7 +266,7 @@ impl<T: Number> Columnar<T> {
/// Extract the columns as a `FlatVec`.
#[must_use]
pub fn to_flat_vec_columns(&self) -> FlatVec<Vec<u8>, usize> {
FlatVec::new(self.columns.clone())
FlatVec::new(self.0.clone())
.unwrap_or_else(|e| unreachable!("{e}"))
.with_dim_lower_bound(self.len())
.with_dim_upper_bound(self.len())
Expand All @@ -293,28 +284,28 @@ impl<T: Number> Columnar<T> {
}
}

impl<T: Number> Columnar<T> {
impl Columns {
/// Parallel version of [`Columnar::with_binary_tree`](crate::msa::dataset::columnar::Columnar::with_binary_tree).
#[must_use]
pub fn par_with_binary_tree<I, D, C>(self, c: &PermutedBall<T, C>, data: &D) -> Self
pub fn par_with_binary_tree<I, T, D, C>(self, c: &PermutedBall<T, C>, data: &D, aligner: &Aligner<T>) -> Self
where
I: AsRef<[u8]> + Send + Sync,
T: Number,
D: ParDataset<I>,
C: ParCluster<T>,
{
if c.children().is_empty() {
self.with_cluster(c, data)
self.with_cluster(c, data, aligner)
} else {
if c.children().len() != 2 {
unreachable!("Binary tree has more than two children.");
}
let aligner = self.aligner;
let left = c.children()[0];
let right = c.children()[1];

let (l_msa, r_msa) = rayon::join(
|| Self::new(&aligner).par_with_binary_tree(left, data),
|| Self::new(&aligner).par_with_binary_tree(right, data),
|| Self::new(self.1).par_with_binary_tree(left, data, aligner),
|| Self::new(self.1).par_with_binary_tree(right, data, aligner),
);

let l_center = left
Expand All @@ -326,20 +317,21 @@ impl<T: Number> Columnar<T> {
.position(|i| i == right.arg_center())
.unwrap_or_else(|| unreachable!("Right center not found"));

l_msa.par_merge(l_center, r_msa, r_center)
l_msa.par_merge(l_center, r_msa, r_center, aligner)
}
}

/// Parallel version of [`Columnar::with_tree`](crate::msa::dataset::columnar::Columnar::with_tree).
#[must_use]
pub fn par_with_tree<I, D, C>(self, c: &PermutedBall<T, C>, data: &D) -> Self
pub fn par_with_tree<I, T, D, C>(self, c: &PermutedBall<T, C>, data: &D, aligner: &Aligner<T>) -> Self
where
I: AsRef<[u8]> + Send + Sync,
T: Number,
D: ParDataset<I>,
C: ParCluster<T>,
{
if c.children().is_empty() {
self.with_cluster(c, data)
self.with_cluster(c, data, aligner)
} else {
let children = c.children();
let (&first, rest) = children.split_first().unwrap_or_else(|| unreachable!("No children"));
Expand All @@ -348,7 +340,7 @@ impl<T: Number> Columnar<T> {
.iter_indices()
.position(|i| i == first.arg_center())
.unwrap_or_else(|| unreachable!("First center not found"));
let first = Self::new(&self.aligner).with_tree(first, data);
let first = Self::new(self.1).with_tree(first, data, aligner);

let (_, merged) = rest
.par_iter()
Expand All @@ -357,12 +349,12 @@ impl<T: Number> Columnar<T> {
.iter_indices()
.position(|i| i == o.arg_center())
.unwrap_or_else(|| unreachable!("Other center not found"));
(o_center, Self::new(&self.aligner).with_tree(o, data))
(o_center, Self::new(self.1).with_tree(o, data, aligner))
})
.collect::<Vec<_>>()
.into_iter()
.fold((f_center, first), |(a_center, acc), (o_center, o)| {
(a_center, acc.par_merge(a_center, o, o_center))
(a_center, acc.par_merge(a_center, o, o_center, aligner))
});

merged
Expand All @@ -371,16 +363,22 @@ impl<T: Number> Columnar<T> {

/// Parallel version of [`Columnar::merge`](crate::msa::dataset::columnar::Columnar::merge).
#[must_use]
pub fn par_merge(mut self, s_center: usize, mut other: Self, o_center: usize) -> Self {
pub fn par_merge<T: Number>(
mut self,
s_center: usize,
mut other: Self,
o_center: usize,
aligner: &Aligner<T>,
) -> Self {
ftlog::trace!(
"Parallel Merging MSAs with cardinalities: {} and {}, and centers {s_center} and {o_center}",
self.len(),
other.len()
);
let s_center = self.get_sequence(s_center);
let o_center = other.get_sequence(o_center);
let table = self.aligner.dp_table(&s_center, &o_center);
let [s_to_o, o_to_s] = self.aligner.alignment_gaps(&s_center, &o_center, &table);
let table = aligner.dp_table(&s_center, &o_center);
let [s_to_o, o_to_s] = aligner.alignment_gaps(&s_center, &o_center, &table);

for i in s_to_o {
self.add_gap(i).unwrap_or_else(|e| unreachable!("{e}"));
Expand All @@ -390,18 +388,17 @@ impl<T: Number> Columnar<T> {
other.add_gap(i).unwrap_or_else(|e| unreachable!("{e}"));
}

let aligner = self.aligner;
let columns = self
.columns
.0
.into_par_iter()
.zip(other.columns)
.zip(other.0)
.map(|(mut x, mut y)| {
x.append(&mut y);
x
})
.collect();

Self { aligner, columns }
Self(columns, self.1)
}

/// Parallel version of [`Columnar::extract_msa`](crate::msa::dataset::columnar::Columnar::extract_msa).
Expand Down
Loading

0 comments on commit 21146d3

Please sign in to comment.