Skip to content

Commit

Permalink
handle split reads in reform
Browse files Browse the repository at this point in the history
  • Loading branch information
hiruna534 committed Apr 25, 2024
1 parent 0649f1d commit 9bb5154
Show file tree
Hide file tree
Showing 5 changed files with 43,956 additions and 151 deletions.
298 changes: 151 additions & 147 deletions src/reform.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,157 +78,161 @@ def run(args):
stride = RNA_STRIDE
print("Info: Default stride: {}".format(stride))

samfile = pysam.AlignmentFile(args.bam, mode='r', check_sq=False)
count_split_reads = 0
count_duplex_reads = 0

fout = open(args.output, "w")
processed_sam_record_count = 0

for sam_record in samfile:

len_seq = len(sam_record.get_forward_sequence()) - kmer_length + 1 # to get the number of kmers

if not sam_record.has_tag("ns"):
raise Exception("Error: tag '{}' is not found. Please check your input SAM/BAM file.".format("ns"))
if not sam_record.has_tag("ts"):
raise Exception("Error: tag '{}' is not found. Please check your input SAM/BAM file.".format("ts"))
if not sam_record.has_tag("mv"):
raise Exception("Error: tag '{}' is not found. Please check your input SAM/BAM file.".format("mv"))

ns = int(sam_record.get_tag("ns"))
ts = int(sam_record.get_tag("ts"))
mv = sam_record.get_tag("mv")

# print("ns: " + str(ns))
# print("ts: " + str(ts))
# print(mv[:5])

len_mv = len(mv)
if len_mv == 0:
raise Exception("Error: mv array length is 0.")

if mv[0] != stride:
print("Info: Found stride to be {}, this value will be used.".format(mv[0]))
stride = int(mv[0])
if not args.c:
move_count = 0
i = 1
while move_count < sig_move_offset + 1:
value = mv[i]
if value == 1:
move_count += 1
i += 1
end_idx = ts + (i - 1) * stride
start_idx = end_idx - stride
kmer_idx = 0
if args.rna:
kmer_idx = len_seq - 1

while i < len_mv:
value = mv[i]
if len_seq > 0 and value == 1:
fout.write("{}\t".format(sam_record.query_name))
fout.write("{}\t".format(kmer_idx))
fout.write("{}\t".format(start_idx))
fout.write("{}\n".format(end_idx))
start_idx = end_idx
if args.rna:
kmer_idx -= 1
else:
kmer_idx += 1

len_seq -= 1
if len_seq > 0 and i == len_mv-1:
fout.write("{}\t".format(sam_record.query_name))
fout.write("{}\t".format(kmer_idx))
fout.write("{}\t".format(start_idx))
fout.write("{}\n".format(ns))
start_idx = ns
if args.rna:
kmer_idx -= 1
else:
kmer_idx += 1
len_seq -= 1

end_idx = end_idx + stride
i += 1
if len_seq != 0:
raise Exception("Error: error in the implementation. Please report the command with minimal reproducible data. Read_id: {}".format(sam_record.query_name));

# write paf format
else:
fout.write("{}\t".format(sam_record.query_name)) #1 read_id
fout.write("{}\t".format(ns)) #2 Raw signal length (number of samples)
move_count = 0
i = 1
start_idx = 0
kmer_idx = 0

while move_count < sig_move_offset + 1:
if i >= len(mv):
print(sam_record.query_name)
print("i >= len(mv)")
value = mv[i]
if value == 1:
move_count += 1
start_idx = i
i += 1

fout.write("{}\t".format(ts + (i - 2) * stride)) #3 Raw signal start index (0-based; BED-like; closed)

j = 1
l_end_raw = 0
len_seq_1 = len_seq + sig_move_offset + 1
end_idx = j + 1
while j < len_mv:
value = mv[j]
if len_seq_1 > 0 and value == 1:
len_seq_1 -= 1
end_idx = j
j += 1
if len_seq_1 > 0 and j == len_mv:
l_end_raw = ns
else:
l_end_raw = ts + (end_idx-1) * stride

fout.write("{}\t".format(l_end_raw)) #4 Raw signal end index (0-based; BED-like; open)
fout.write("+\t") #5 Relative strand: "+" or "-"
fout.write("{}\t".format(sam_record.query_name)) #6 Same as column 1
fout.write("{}\t".format(len_seq)) #7 base-called sequence length (no. of k-mers)

if args.rna:
fout.write("{}\t".format(len_seq)) # 8 k-mer start index on basecalled sequence (0-based)
fout.write("{}\t".format(kmer_idx)) # 9 k-mer end index on basecalled sequence (0-based)
else:
fout.write("{}\t".format(kmer_idx)) # 8 k-mer start index on basecalled sequence (0-based)
fout.write("{}\t".format(len_seq)) # 9 k-mer end index on basecalled sequence (0-based)
fout.write("{}\t".format(len_seq - kmer_idx)) #10 Number of k-mers matched on basecalled sequence
fout.write("{}\t".format(len_seq)) #11 Same as column 7
fout.write("{}\t".format("255")) #12 Mapping quality (0-255; 255 for missing)
fout.write("{}".format("ss:Z:")) #12 Mapping quality (0-255; 255 for missing)

while i < len_mv:
value = mv[i]
# print("{}\t{}".format(i, value))
if len_seq > 0 and value == 1:
fout.write("{},".format((i-start_idx) * stride)) # ss
start_idx = i
len_seq -= 1
if len_seq > 0 and i == len_mv-1:
if (ns - ((i-1) * stride + ts)) < 0:
raise Exception("Error: error in calcuation. (ns - ((i-1)*EXPECTED_STRIDE + ts)) > 0 is not valid")
len_seq -= 1
l_duration = ((i-start_idx) * stride) + (ns - ((i - 1) * stride + ts))
fout.write("{},".format(l_duration)) # ss
i += 1

if len_seq != 0:
raise Exception("Error: error in the implementation. Please report the command with minimal reproducible data. Read_id: {}".format(sam_record.query_name));

fout.write("{}".format("\n")) # newline
processed_sam_record_count += 1

samfile.close()
with pysam.AlignmentFile(args.bam, "rb", check_sq=False) as bam:
for sam_record in bam:
len_seq = len(sam_record.get_forward_sequence()) - kmer_length + 1 # to get the number of kmers
if sam_record.has_tag("sp"):
count_split_reads += 1
continue
if sam_record.has_tag("dx") and int(sam_record.get_tag("dx")) == 1:
count_duplex_reads += 1
continue

if not sam_record.has_tag("ns"):
raise Exception("Error: tag '{}' is not found. Please check your input SAM/BAM file.".format("ns"))
if not sam_record.has_tag("ts"):
raise Exception("Error: tag '{}' is not found. Please check your input SAM/BAM file.".format("ts"))
if not sam_record.has_tag("mv"):
raise Exception("Error: tag '{}' is not found. Please check your input SAM/BAM file.".format("mv"))

ns = int(sam_record.get_tag("ns"))
ts = int(sam_record.get_tag("ts"))
mv = sam_record.get_tag("mv")

# print("ns: " + str(ns))
# print("ts: " + str(ts))
# print(mv[:5])

len_mv = len(mv)
if len_mv == 0:
raise Exception("Error: mv array length is 0.")

if mv[0] != stride:
print("Info: Found stride to be {}, this value will be used.".format(mv[0]))
stride = int(mv[0])
if not args.c:
move_count = 0
i = 1
while move_count < sig_move_offset + 1:
value = mv[i]
if value == 1:
move_count += 1
i += 1
end_idx = ts + (i - 1) * stride
start_idx = end_idx - stride
kmer_idx = 0
if args.rna:
kmer_idx = len_seq - 1

while i < len_mv:
value = mv[i]
if len_seq > 0 and value == 1:
fout.write("{}\t".format(sam_record.query_name))
fout.write("{}\t".format(kmer_idx))
fout.write("{}\t".format(start_idx))
fout.write("{}\n".format(end_idx))
start_idx = end_idx
if args.rna:
kmer_idx -= 1
else:
kmer_idx += 1

len_seq -= 1
if len_seq > 0 and i == len_mv-1:
fout.write("{}\t".format(sam_record.query_name))
fout.write("{}\t".format(kmer_idx))
fout.write("{}\t".format(start_idx))
fout.write("{}\n".format(ns))
start_idx = ns
if args.rna:
kmer_idx -= 1
else:
kmer_idx += 1
len_seq -= 1

end_idx = end_idx + stride
i += 1
if len_seq != 0:
raise Exception("Error: error in the implementation. Please report the command with minimal reproducible data. Read_id: {}".format(sam_record.query_name));
else: # write paf format
fout.write("{}\t".format(sam_record.query_name)) #1 read_id
fout.write("{}\t".format(ns)) #2 Raw signal length (number of samples)
move_count = 0
i = 1
start_idx = 0
kmer_idx = 0

while move_count < sig_move_offset + 1:
if i >= len(mv):
print(sam_record.query_name)
print("i >= len(mv)")
value = mv[i]
if value == 1:
move_count += 1
start_idx = i
i += 1

fout.write("{}\t".format(ts + (i - 2) * stride)) #3 Raw signal start index (0-based; BED-like; closed)

j = 1
l_end_raw = 0
len_seq_1 = len_seq + sig_move_offset + 1
end_idx = j + 1
while j < len_mv:
value = mv[j]
if len_seq_1 > 0 and value == 1:
len_seq_1 -= 1
end_idx = j
j += 1
if len_seq_1 > 0 and j == len_mv:
l_end_raw = ns
else:
l_end_raw = ts + (end_idx-1) * stride

fout.write("{}\t".format(l_end_raw)) #4 Raw signal end index (0-based; BED-like; open)
fout.write("+\t") #5 Relative strand: "+" or "-"
fout.write("{}\t".format(sam_record.query_name)) #6 Same as column 1
fout.write("{}\t".format(len_seq)) #7 base-called sequence length (no. of k-mers)

if args.rna:
fout.write("{}\t".format(len_seq)) # 8 k-mer start index on basecalled sequence (0-based)
fout.write("{}\t".format(kmer_idx)) # 9 k-mer end index on basecalled sequence (0-based)
else:
fout.write("{}\t".format(kmer_idx)) # 8 k-mer start index on basecalled sequence (0-based)
fout.write("{}\t".format(len_seq)) # 9 k-mer end index on basecalled sequence (0-based)
fout.write("{}\t".format(len_seq - kmer_idx)) #10 Number of k-mers matched on basecalled sequence
fout.write("{}\t".format(len_seq)) #11 Same as column 7
fout.write("{}\t".format("255")) #12 Mapping quality (0-255; 255 for missing)
fout.write("{}".format("ss:Z:")) #12 Mapping quality (0-255; 255 for missing)

while i < len_mv:
value = mv[i]
# print("{}\t{}".format(i, value))
if len_seq > 0 and value == 1:
fout.write("{},".format((i-start_idx) * stride)) # ss
start_idx = i
len_seq -= 1
if len_seq > 0 and i == len_mv-1:
if (ns - ((i-1) * stride + ts)) < 0:
raise Exception("Error: error in calcuation. (ns - ((i-1)*EXPECTED_STRIDE + ts)) > 0 is not valid")
len_seq -= 1
l_duration = ((i-start_idx) * stride) + (ns - ((i - 1) * stride + ts))
fout.write("{},".format(l_duration)) # ss
i += 1

if len_seq != 0:
raise Exception("Error: error in the implementation. Please report the command with minimal reproducible data. Read_id: {}".format(sam_record.query_name));

fout.write("{}".format("\n")) # newline
processed_sam_record_count += 1
fout.close()
print("processed_sam_record_count: " + str(processed_sam_record_count))
print("skipped: split reads ({}), duplex reads ({})".format(count_split_reads, count_duplex_reads))


def argparser():
Expand Down
Loading

0 comments on commit 9bb5154

Please sign in to comment.