Skip to content

Commit

Permalink
added functionality to align reads using ssw algorithm (deepchem#4142)
Browse files Browse the repository at this point in the history
* added functionality to align reads using ssw algorithm

* mypy fixes
  • Loading branch information
KitVB authored Oct 16, 2024
1 parent d826c66 commit ec22084
Show file tree
Hide file tree
Showing 2 changed files with 210 additions and 36 deletions.
11 changes: 4 additions & 7 deletions deepchem/data/tests/test_deepvariant_featurizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,10 @@ def test_candidate_windows(self):
candidate_windows = self.featurizer._featurize(datapoint)

# Assert the number of reads
self.assertEqual(len(candidate_windows), 15)
self.assertEqual(candidate_windows[13][0], 'chr2')
self.assertEqual(candidate_windows[13][1], 136)
self.assertEqual(candidate_windows[13][2], 137)
self.assertEqual(candidate_windows[13][3], 102)
self.assertEqual(candidate_windows[13][4], 21)
self.assertEqual(len(candidate_windows), 53)
self.assertEqual(candidate_windows[0]['span'], ('chr1', 3, 5))
self.assertEqual(candidate_windows[1]['span'], ('chr1', 9, 20))


if __name__ == "__main__":
unittest.main()
unittest.main()
235 changes: 206 additions & 29 deletions deepchem/feat/deepvariant_featurizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,8 @@ def generate_pileup_and_reads(
"""
Generate pileup and reads from BAM and reference FASTA files. This
function generates pileup information and reads from the provided
BAM files and reference sequences, returning both allele counts and reads.
BAM files and reference sequences, returning both allele counts
and reads.
Parameters
----------
Expand All @@ -157,6 +158,8 @@ def generate_pileup_and_reads(
chrom = x[3] # Reference name

pileup_info = x[7] if len(x) > 7 else None
if pileup_info is None:
continue
for pileupcolumn in pileup_info:
for read_seq in pileupcolumn['reads']:
reads.append(read_seq)
Expand Down Expand Up @@ -273,7 +276,7 @@ def select_candidate_regions(

def fetchreads(self, bamfiles: List[Tuple[str, Any, int, str, int, Any, int,
Any]], chrom: str, start: int,
end: int) -> List[str]:
end: int) -> List[Tuple[str, Any, int, str, int, Any, int]]:
"""
Fetch reads from BAM files for a specific chromosome and region.
This function extracts reads from BAM files that overlap with the
Expand All @@ -295,21 +298,23 @@ def fetchreads(self, bamfiles: List[Tuple[str, Any, int, str, int, Any, int,
Returns
-------
List[str]
List of read sequences that fall within the specified region.
List[Tuple[str, Any, int, str, int, Any, int]]
List of reads that overlap with the specified chromosome
and region.
"""
reads: List[str] = []
reads: List[Tuple[str, Any, int, str, int, Any, int]] = []
for bamfile in bamfiles:
refname = bamfile[3]
refstart = bamfile[4]
refend = refstart + bamfile[2]

if refname == chrom and refstart < end and refend > start:
reads.append(bamfile[1])
reads.append(bamfile[0:7])
return reads

def build_debruijn_graph(
self, ref: str, reads: List[str], k: int
self, ref: str, reads: List[Tuple[str, Any, int, str, int, Any,
int]], k: int
) -> Tuple[Optional[Any], Optional[Dict[str, int]], Optional[Dict[int,
str]]]:
"""
Expand All @@ -323,17 +328,18 @@ def build_debruijn_graph(
ref : str
Reference sequence as a string.
reads : List[str]
reads : List[Tuple[str, Any, int, str, int, Any, int]]
List of read sequences.
k : int
Length of k-mers.
Returns
-------
Tuple[Optional[Any], Optional[Dict[str, int]], Optional[Dict[int, str]]]
Tuple[Optional[Any], Optional[Dict[str, int]],
Optional[Dict[int, str]]]
A tuple containing:
- De Bruijn graph (dgl.DGLGraph)
- De Bruijn graph (Any)
- Dictionary mapping k-mer strings to node IDs (Dict[str, int])
- Dictionary mapping node IDs to k-mer strings (Dict[int, str])
"""
Expand All @@ -349,7 +355,7 @@ def get_kmers(sequence: str, k: int):

# Count k-mers in reads
for read in reads:
for kmer in get_kmers(read, k):
for kmer in get_kmers(read[1], k):
kmer_counts[kmer] += 1

kmer_to_id: Dict[str, int] = {}
Expand Down Expand Up @@ -428,7 +434,7 @@ def candidate_haplotypes(self, G: Any, k: int,
Parameters
----------
G : dgl.DGLGraph
G : Any
The De Bruijn graph.
k : int
The k-mer length.
Expand Down Expand Up @@ -468,12 +474,164 @@ def dfs(node: int, path: List[int]) -> None:

return sorted(haplotypes)

def assign_reads_to_regions(
self, assembled_regions: List[Dict[str, Any]],
reads: List[Tuple[str, Any, int, str, int, Any, int]]
) -> List[Tuple[str, Any, int, str, int, Any, int]]:
"""
Assign reads to regions based on maximum overlap with haplotypes.
Parameters
----------
assembled_regions : List[Dict[str, Any]]
List of dictionaries, where each dictionary contains
information about a region, including its haplotypes
and reads.
reads : List[Tuple[str, Any, int, str, int, Any, int]]
List of reads.
Returns
-------
List[Tuple[str, Any, int, str, int, Any, int]]
List of reads that couldn't be assigned to any region.
"""
regions = [(0, len(ar["haplotypes"][0])) for ar in assembled_regions]
unassigned_reads: List[Tuple[str, Any, int, str, int, Any, int]] = []
for read in reads:
read_start = read[4]
read_end = read_start + read[2]
# to find maximum overlap
max_overlap = 0
max_index = None
for i, region in enumerate(regions):
region_start, region_end = map(
int, region) # Ensure regions are integers
overlap = max(
0,
min(read_end, region_end) - max(read_start, region_start))
if overlap > max_overlap:
max_overlap = overlap
max_index = i
window_i = max_index
if window_i is not None:
assembled_regions[window_i]["reads"].append(read)
else:
unassigned_reads.append(read)
return unassigned_reads

def align(self,
query_sequence: str,
ref_sequence: str,
match_score: int = 2,
mismatch_penalty: int = -1,
gap_open: int = -2,
gap_extend: int = -1) -> dict:
"""
SIMD-optimized Smith-Waterman alignment function using PyTorch.
Parameters
----------
query_sequence : str
The query sequence (e.g., a read).
ref_sequence : str
The reference sequence (e.g., a haplotype).
match_score : int
Score for matching characters.
mismatch_penalty : int
Penalty for mismatching characters.
gap_open : int
Penalty for opening a gap.
gap_extend : int
Penalty for extending a gap.
Returns
-------
dict
A dictionary containing the alignment score, end positions of the
alignment, and matrices used in computation.
"""

# Convert sequences to integer representation (A=0, C=1, G=2, T=3)
def seq_to_int(seq):
return torch.tensor([ord(c) % 4 for c in seq], dtype=torch.int32)

query_len = len(query_sequence)
ref_len = len(ref_sequence)

# Convert the sequences to integer indices for faster comparison
query_tensor = seq_to_int(query_sequence).to(torch.int32)
ref_tensor = seq_to_int(ref_sequence).to(torch.int32)

# Initialize scoring matrices: H, E (gap extension matrix)
H = torch.zeros((query_len + 1, ref_len + 1), dtype=torch.int32)
E = torch.zeros((query_len + 1, ref_len + 1), dtype=torch.int32)

# Track maximum score and coordinates of alignment endpoint
max_score = 0
end_query, end_ref = 0, 0

# Iterate over each position in the query sequence
for i in range(1, query_len + 1):
# Vectorize over all positions in the reference sequence
# Substitution score: +match_score for match,
# -mismatch_penalty for mismatch
match_mask = (query_tensor[i - 1] == ref_tensor).to(torch.int32)
sub_scores = match_mask * match_score + (
~match_mask) * mismatch_penalty

# Compute the matrix values for H and E in a SIMD-like fashion
H_diag = H[i - 1, :-1] + sub_scores
E[:, 1:] = torch.max(H[i - 1, :-1] + gap_open,
E[:, 1:] + gap_extend)

H[i, 1:] = torch.max(torch.tensor([0]), torch.max(H_diag, E[i, 1:]))

# Track the max score and its position
if H[i, 1:].max().item() > max_score:
max_score = int(H[i, 1:].max().item())
end_ref = int(torch.argmax(H[i, 1:]).item())
end_query = i

return {
'score': max_score,
'end_query': end_query,
'end_ref': end_ref,
'H': H,
'E': E
}

def fast_pass_aligner(self, assembled_region: Dict[str, Any]) -> List[Any]:
"""
Align reads to the haplotype of the assembled region using Striped Smith
Waterman algorithm.
Parameters
----------
assembled_region : Dict[str, Any]
Dictionary containing the haplotype information and reads
for a given region.
Returns
-------
List[Any]
List of alignments returned by the aligner.
"""
aligned_reads: List[Any] = []
ref_sequence = assembled_region["haplotypes"][0]
for read in assembled_region["reads"]:
query_sequence = read[1]
alignment = self.align(query_sequence, ref_sequence)
aligned_reads.append(alignment)
return aligned_reads

def process_candidate_windows(
self, candidate_regions: List[Tuple[str, int, int, int]],
bamfiles: List[Any], reference_seq_dict: Dict[str, str]
) -> List[Tuple[str, int, int, int, int, List[str]]]:
self, candidate_regions: List[Tuple[str, int, int,
int]], bamfiles: List[Any],
reference_seq_dict: Dict[str, str]) -> List[Dict[str, Any]]:
"""
Process candidate regions to generate candidate windows with haplotypes.
Process candidate regions to generate window haplotyples with
realigned reads.
Parameters
----------
Expand All @@ -483,15 +641,20 @@ def process_candidate_windows(
bamfiles : List[Any]
List of BAM file data.
reference_seq_dict : Dict[str, str]
Dictionary with chromosome names as keys and reference sequences as values.
Dictionary with chromosome names as keys and reference
sequences as values.
Returns
-------
List[Tuple[str, int, int, int, int, List[str]]]
List of candidate windows with haplotypes.
List[Dict[str, Any]]
List of dictionaries, where each dictionary represents a
candidate window and contains:
- 'span' : Tuple of (chromosome, start, end)
- 'haplotypes' : List of haplotypes (List[str])
- 'realigned_reads' : List of realigned reads (List[Any])
"""
candidate_windows: List[Tuple[str, int, int, int, int, List[str]]] = []
windows_haplotypes = []

for chrom, start, end, count in candidate_regions:
window_reads = self.fetchreads(bamfiles, chrom, start, end)
Expand Down Expand Up @@ -522,15 +685,29 @@ def process_candidate_windows(
if candidate_haplotypes_list and candidate_haplotypes_list != [
ref_sequence
]:
candidate_windows.append((chrom, start, end, count, k,
candidate_haplotypes_list))
break

if not found_graph:
candidate_windows.append(
(chrom, start, end, count, k, [ref_sequence]))
candidate_haplotypes_list = [ref_sequence]

assembled_regions = [{
"haplotypes": haplotypes,
"reads": []
} for haplotypes in candidate_haplotypes_list]
realigned_reads = self.assign_reads_to_regions(
assembled_regions, window_reads)

for assembled_region in assembled_regions:
aligned_reads = self.fast_pass_aligner(assembled_region)
realigned_reads.extend(aligned_reads)

windows_haplotypes.append({
'span': (chrom, start, end),
'haplotypes': candidate_haplotypes_list,
'realigned_reads': realigned_reads
})

return candidate_windows
return windows_haplotypes


class RealignerFeaturizer(Featurizer):
Expand Down Expand Up @@ -605,7 +782,7 @@ def _featurize(self, datapoint):
fasta_dataset = fasta_loader.create_dataset(reference_file_path)

one_hot_encoded_sequences = fasta_dataset.X
decoded_sequences: List[str] = []
decoded_sequences = []

# Convert the one-hot encoded sequences to strings
for seq in one_hot_encoded_sequences:
Expand All @@ -615,14 +792,14 @@ def _featurize(self, datapoint):
# Map the sequences to chrom names
chrom_names = ["chr1", "chr2"]

reference_seq_dict: Dict[str, str] = {
reference_seq_dict = {
chrom_names[i]: seq for i, seq in enumerate(decoded_sequences)
}

allele_counts, reads = self.realigner.generate_pileup_and_reads(
bamfiles, reference_seq_dict)
candidate_regions = self.realigner.select_candidate_regions(
allele_counts)
candidate_windows = self.realigner.process_candidate_windows(
windows_haplotypes = self.realigner.process_candidate_windows(
candidate_regions, bamfiles, reference_seq_dict)
return candidate_windows
return windows_haplotypes

0 comments on commit ec22084

Please sign in to comment.