diff --git a/deepchem/data/tests/test_deepvariant_featurizer.py b/deepchem/data/tests/test_deepvariant_featurizer.py index 81c670ed9e..1e642aa61b 100644 --- a/deepchem/data/tests/test_deepvariant_featurizer.py +++ b/deepchem/data/tests/test_deepvariant_featurizer.py @@ -20,13 +20,10 @@ def test_candidate_windows(self): candidate_windows = self.featurizer._featurize(datapoint) # Assert the number of reads - self.assertEqual(len(candidate_windows), 15) - self.assertEqual(candidate_windows[13][0], 'chr2') - self.assertEqual(candidate_windows[13][1], 136) - self.assertEqual(candidate_windows[13][2], 137) - self.assertEqual(candidate_windows[13][3], 102) - self.assertEqual(candidate_windows[13][4], 21) + self.assertEqual(len(candidate_windows), 53) + self.assertEqual(candidate_windows[0]['span'], ('chr1', 3, 5)) + self.assertEqual(candidate_windows[1]['span'], ('chr1', 9, 20)) if __name__ == "__main__": - unittest.main() + unittest.main() \ No newline at end of file diff --git a/deepchem/feat/deepvariant_featurizer.py b/deepchem/feat/deepvariant_featurizer.py index fc751a3fb0..f86e2679e3 100644 --- a/deepchem/feat/deepvariant_featurizer.py +++ b/deepchem/feat/deepvariant_featurizer.py @@ -133,7 +133,8 @@ def generate_pileup_and_reads( """ Generate pileup and reads from BAM and reference FASTA files. This function generates pileup information and reads from the provided - BAM files and reference sequences, returning both allele counts and reads. + BAM files and reference sequences, returning both allele counts + and reads. Parameters ---------- @@ -157,6 +158,8 @@ def generate_pileup_and_reads( chrom = x[3] # Reference name pileup_info = x[7] if len(x) > 7 else None + if pileup_info is None: + continue for pileupcolumn in pileup_info: for read_seq in pileupcolumn['reads']: reads.append(read_seq) @@ -273,7 +276,7 @@ def select_candidate_regions( def fetchreads(self, bamfiles: List[Tuple[str, Any, int, str, int, Any, int, Any]], chrom: str, start: int, - end: int) -> List[str]: + end: int) -> List[Tuple[str, Any, int, str, int, Any, int]]: """ Fetch reads from BAM files for a specific chromosome and region. This function extracts reads from BAM files that overlap with the @@ -295,21 +298,23 @@ def fetchreads(self, bamfiles: List[Tuple[str, Any, int, str, int, Any, int, Returns ------- - List[str] - List of read sequences that fall within the specified region. + List[Tuple[str, Any, int, str, int, Any, int]] + List of reads that overlap with the specified chromosome + and region. """ - reads: List[str] = [] + reads: List[Tuple[str, Any, int, str, int, Any, int]] = [] for bamfile in bamfiles: refname = bamfile[3] refstart = bamfile[4] refend = refstart + bamfile[2] if refname == chrom and refstart < end and refend > start: - reads.append(bamfile[1]) + reads.append(bamfile[0:7]) return reads def build_debruijn_graph( - self, ref: str, reads: List[str], k: int + self, ref: str, reads: List[Tuple[str, Any, int, str, int, Any, + int]], k: int ) -> Tuple[Optional[Any], Optional[Dict[str, int]], Optional[Dict[int, str]]]: """ @@ -323,7 +328,7 @@ def build_debruijn_graph( ref : str Reference sequence as a string. - reads : List[str] + reads : List[Tuple[str, Any, int, str, int, Any, int]] List of read sequences. k : int Length of k-mers. @@ -331,9 +336,10 @@ def build_debruijn_graph( Returns ------- - Tuple[Optional[Any], Optional[Dict[str, int]], Optional[Dict[int, str]]] + Tuple[Optional[Any], Optional[Dict[str, int]], + Optional[Dict[int, str]]] A tuple containing: - - De Bruijn graph (dgl.DGLGraph) + - De Bruijn graph (Any) - Dictionary mapping k-mer strings to node IDs (Dict[str, int]) - Dictionary mapping node IDs to k-mer strings (Dict[int, str]) """ @@ -349,7 +355,7 @@ def get_kmers(sequence: str, k: int): # Count k-mers in reads for read in reads: - for kmer in get_kmers(read, k): + for kmer in get_kmers(read[1], k): kmer_counts[kmer] += 1 kmer_to_id: Dict[str, int] = {} @@ -428,7 +434,7 @@ def candidate_haplotypes(self, G: Any, k: int, Parameters ---------- - G : dgl.DGLGraph + G : Any The De Bruijn graph. k : int The k-mer length. @@ -468,12 +474,164 @@ def dfs(node: int, path: List[int]) -> None: return sorted(haplotypes) + def assign_reads_to_regions( + self, assembled_regions: List[Dict[str, Any]], + reads: List[Tuple[str, Any, int, str, int, Any, int]] + ) -> List[Tuple[str, Any, int, str, int, Any, int]]: + """ + Assign reads to regions based on maximum overlap with haplotypes. + + Parameters + ---------- + assembled_regions : List[Dict[str, Any]] + List of dictionaries, where each dictionary contains + information about a region, including its haplotypes + and reads. + reads : List[Tuple[str, Any, int, str, int, Any, int]] + List of reads. + + Returns + ------- + List[Tuple[str, Any, int, str, int, Any, int]] + List of reads that couldn't be assigned to any region. + + """ + regions = [(0, len(ar["haplotypes"][0])) for ar in assembled_regions] + unassigned_reads: List[Tuple[str, Any, int, str, int, Any, int]] = [] + for read in reads: + read_start = read[4] + read_end = read_start + read[2] + # to find maximum overlap + max_overlap = 0 + max_index = None + for i, region in enumerate(regions): + region_start, region_end = map( + int, region) # Ensure regions are integers + overlap = max( + 0, + min(read_end, region_end) - max(read_start, region_start)) + if overlap > max_overlap: + max_overlap = overlap + max_index = i + window_i = max_index + if window_i is not None: + assembled_regions[window_i]["reads"].append(read) + else: + unassigned_reads.append(read) + return unassigned_reads + + def align(self, + query_sequence: str, + ref_sequence: str, + match_score: int = 2, + mismatch_penalty: int = -1, + gap_open: int = -2, + gap_extend: int = -1) -> dict: + """ + SIMD-optimized Smith-Waterman alignment function using PyTorch. + + Parameters + ---------- + query_sequence : str + The query sequence (e.g., a read). + ref_sequence : str + The reference sequence (e.g., a haplotype). + match_score : int + Score for matching characters. + mismatch_penalty : int + Penalty for mismatching characters. + gap_open : int + Penalty for opening a gap. + gap_extend : int + Penalty for extending a gap. + + Returns + ------- + dict + A dictionary containing the alignment score, end positions of the + alignment, and matrices used in computation. + """ + + # Convert sequences to integer representation (A=0, C=1, G=2, T=3) + def seq_to_int(seq): + return torch.tensor([ord(c) % 4 for c in seq], dtype=torch.int32) + + query_len = len(query_sequence) + ref_len = len(ref_sequence) + + # Convert the sequences to integer indices for faster comparison + query_tensor = seq_to_int(query_sequence).to(torch.int32) + ref_tensor = seq_to_int(ref_sequence).to(torch.int32) + + # Initialize scoring matrices: H, E (gap extension matrix) + H = torch.zeros((query_len + 1, ref_len + 1), dtype=torch.int32) + E = torch.zeros((query_len + 1, ref_len + 1), dtype=torch.int32) + + # Track maximum score and coordinates of alignment endpoint + max_score = 0 + end_query, end_ref = 0, 0 + + # Iterate over each position in the query sequence + for i in range(1, query_len + 1): + # Vectorize over all positions in the reference sequence + # Substitution score: +match_score for match, + # -mismatch_penalty for mismatch + match_mask = (query_tensor[i - 1] == ref_tensor).to(torch.int32) + sub_scores = match_mask * match_score + ( + ~match_mask) * mismatch_penalty + + # Compute the matrix values for H and E in a SIMD-like fashion + H_diag = H[i - 1, :-1] + sub_scores + E[:, 1:] = torch.max(H[i - 1, :-1] + gap_open, + E[:, 1:] + gap_extend) + + H[i, 1:] = torch.max(torch.tensor([0]), torch.max(H_diag, E[i, 1:])) + + # Track the max score and its position + if H[i, 1:].max().item() > max_score: + max_score = int(H[i, 1:].max().item()) + end_ref = int(torch.argmax(H[i, 1:]).item()) + end_query = i + + return { + 'score': max_score, + 'end_query': end_query, + 'end_ref': end_ref, + 'H': H, + 'E': E + } + + def fast_pass_aligner(self, assembled_region: Dict[str, Any]) -> List[Any]: + """ + Align reads to the haplotype of the assembled region using Striped Smith + Waterman algorithm. + + Parameters + ---------- + assembled_region : Dict[str, Any] + Dictionary containing the haplotype information and reads + for a given region. + + Returns + ------- + List[Any] + List of alignments returned by the aligner. + """ + aligned_reads: List[Any] = [] + ref_sequence = assembled_region["haplotypes"][0] + for read in assembled_region["reads"]: + query_sequence = read[1] + alignment = self.align(query_sequence, ref_sequence) + aligned_reads.append(alignment) + return aligned_reads + def process_candidate_windows( - self, candidate_regions: List[Tuple[str, int, int, int]], - bamfiles: List[Any], reference_seq_dict: Dict[str, str] - ) -> List[Tuple[str, int, int, int, int, List[str]]]: + self, candidate_regions: List[Tuple[str, int, int, + int]], bamfiles: List[Any], + reference_seq_dict: Dict[str, str]) -> List[Dict[str, Any]]: """ - Process candidate regions to generate candidate windows with haplotypes. + Process candidate regions to generate window haplotyples with + realigned reads. Parameters ---------- @@ -483,15 +641,20 @@ def process_candidate_windows( bamfiles : List[Any] List of BAM file data. reference_seq_dict : Dict[str, str] - Dictionary with chromosome names as keys and reference sequences as values. + Dictionary with chromosome names as keys and reference + sequences as values. Returns ------- - List[Tuple[str, int, int, int, int, List[str]]] - List of candidate windows with haplotypes. + List[Dict[str, Any]] + List of dictionaries, where each dictionary represents a + candidate window and contains: + - 'span' : Tuple of (chromosome, start, end) + - 'haplotypes' : List of haplotypes (List[str]) + - 'realigned_reads' : List of realigned reads (List[Any]) """ - candidate_windows: List[Tuple[str, int, int, int, int, List[str]]] = [] + windows_haplotypes = [] for chrom, start, end, count in candidate_regions: window_reads = self.fetchreads(bamfiles, chrom, start, end) @@ -522,15 +685,29 @@ def process_candidate_windows( if candidate_haplotypes_list and candidate_haplotypes_list != [ ref_sequence ]: - candidate_windows.append((chrom, start, end, count, k, - candidate_haplotypes_list)) break if not found_graph: - candidate_windows.append( - (chrom, start, end, count, k, [ref_sequence])) + candidate_haplotypes_list = [ref_sequence] + + assembled_regions = [{ + "haplotypes": haplotypes, + "reads": [] + } for haplotypes in candidate_haplotypes_list] + realigned_reads = self.assign_reads_to_regions( + assembled_regions, window_reads) + + for assembled_region in assembled_regions: + aligned_reads = self.fast_pass_aligner(assembled_region) + realigned_reads.extend(aligned_reads) + + windows_haplotypes.append({ + 'span': (chrom, start, end), + 'haplotypes': candidate_haplotypes_list, + 'realigned_reads': realigned_reads + }) - return candidate_windows + return windows_haplotypes class RealignerFeaturizer(Featurizer): @@ -605,7 +782,7 @@ def _featurize(self, datapoint): fasta_dataset = fasta_loader.create_dataset(reference_file_path) one_hot_encoded_sequences = fasta_dataset.X - decoded_sequences: List[str] = [] + decoded_sequences = [] # Convert the one-hot encoded sequences to strings for seq in one_hot_encoded_sequences: @@ -615,7 +792,7 @@ def _featurize(self, datapoint): # Map the sequences to chrom names chrom_names = ["chr1", "chr2"] - reference_seq_dict: Dict[str, str] = { + reference_seq_dict = { chrom_names[i]: seq for i, seq in enumerate(decoded_sequences) } @@ -623,6 +800,6 @@ def _featurize(self, datapoint): bamfiles, reference_seq_dict) candidate_regions = self.realigner.select_candidate_regions( allele_counts) - candidate_windows = self.realigner.process_candidate_windows( + windows_haplotypes = self.realigner.process_candidate_windows( candidate_regions, bamfiles, reference_seq_dict) - return candidate_windows + return windows_haplotypes