diff --git a/transcript_transformer/__init__.py b/transcript_transformer/__init__.py
index e69de29..f15a403 100644
--- a/transcript_transformer/__init__.py
+++ b/transcript_transformer/__init__.py
@@ -0,0 +1,276 @@
+# Define global variables
+
+CDN_PROT_DICT = {
+    "ATA": "I",
+    "ATC": "I",
+    "ATT": "I",
+    "ATG": "M",
+    "ACA": "T",
+    "ACC": "T",
+    "ACG": "T",
+    "ACT": "T",
+    "AAC": "N",
+    "AAT": "N",
+    "AAA": "K",
+    "AAG": "K",
+    "AGC": "S",
+    "AGT": "S",
+    "AGA": "R",
+    "AGG": "R",
+    "CTA": "L",
+    "CTC": "L",
+    "CTG": "L",
+    "CTT": "L",
+    "CCA": "P",
+    "CCC": "P",
+    "CCG": "P",
+    "CCT": "P",
+    "CAC": "H",
+    "CAT": "H",
+    "CAA": "Q",
+    "CAG": "Q",
+    "CGA": "R",
+    "CGC": "R",
+    "CGG": "R",
+    "CGT": "R",
+    "GTA": "V",
+    "GTC": "V",
+    "GTG": "V",
+    "GTT": "V",
+    "GCA": "A",
+    "GCC": "A",
+    "GCG": "A",
+    "GCT": "A",
+    "GAC": "D",
+    "GAT": "D",
+    "GAA": "E",
+    "GAG": "E",
+    "GGA": "G",
+    "GGC": "G",
+    "GGG": "G",
+    "GGT": "G",
+    "TCA": "S",
+    "TCC": "S",
+    "TCG": "S",
+    "TCT": "S",
+    "TTC": "F",
+    "TTT": "F",
+    "TTA": "L",
+    "TTG": "L",
+    "TAC": "Y",
+    "TAT": "Y",
+    "TAA": "_",
+    "TAG": "_",
+    "TGC": "C",
+    "TGT": "C",
+    "TGA": "_",
+    "TGG": "W",
+    "NNN": "_",
+}
+
+PROT_IDX_DICT = {
+    "A": 0,
+    "R": 1,
+    "N": 2,
+    "D": 3,
+    "C": 4,
+    "E": 5,
+    "Q": 6,
+    "G": 7,
+    "H": 8,
+    "O": 9,
+    "I": 10,
+    "L": 11,
+    "K": 12,
+    "M": 13,
+    "F": 14,
+    "P": 15,
+    "S": 16,
+    "T": 17,
+    "W": 18,
+    "Y": 19,
+    "V": 20,
+}
+
+DNA_IDX_DICT = {
+    "A": 0,
+    "T": 1,
+    "C": 2,
+    "G": 3,
+    "N": 4,
+}
+
+IDX_PROT_DICT = {v: k for k, v in PROT_IDX_DICT.items()}
+IDX_DNA_DICT = {v: k for k, v in DNA_IDX_DICT.items()}
+
+STANDARD_HEADERS = [
+    "CDS_coords",
+    "CDS_idxs",
+    "canonical_TIS_coord",
+    "canonical_TIS_exon",
+    "canonical_TIS_idx",
+    "canonical_LTS_coord",
+    "canonical_LTS_idx",
+    "canonical_TTS_coord",
+    "canonical_TTS_idx",
+    "canonical_protein_seq",
+    "exon_coords",
+    "exon_idxs",
+    "gene_id",
+    "has_annotated_start_codon",
+    "has_annotated_stop_codon",
+    "seq",
+    "seqname",
+    "source",
+    "strand",
+    "transcript_id",
+    "transcript_len",
+]
+
+RENAME_HEADERS = {
+    "has_annotated_start_codon": "CDS_has_annotated_start_codon",
+    "has_annotated_stop_codon": "CDS_has_annotated_stop_codon",
+}
+
+STANDARD_OUT_HEADERS = [
+    "seqname",
+    "ORF_id",
+    "ORF_len",
+    "transcript_id",
+    "transcript_len",
+    "start_codon",
+    "stop_codon",
+    "strand",
+    "ORF_type",
+    "TIS_pos",
+    "TTS_pos",
+    "has_CDS_clones",
+    "has_CDS_TIS",
+    "has_CDS_TTS",
+    "shared_in_frame_CDS_frac",
+    "dist_from_canonical_TIS",
+    "frame_wrt_canonical_TIS",
+    "TTS_on_transcript",
+    "TIS_coord",
+    "TIS_exon",
+    "TTS_coord",
+    "TTS_exon",
+    "LTS_coord",
+    "LTS_exon",
+    "gene_id",
+    "canonical_TIS_coord",
+    "canonical_TIS_pos",
+    "canonical_LTS_coord",
+    "canonical_LTS_pos",
+    "canonical_TTS_coord",
+    "canonical_TTS_pos",
+    "has_annotated_start_codon",
+    "has_annotated_stop_codon",
+    "protein_seq",
+]
+
+RIBO_OUT_HEADERS = [
+    "correction",
+    "reads_in_transcript",
+    "reads_in_ORF",
+    "reads_in_frame_frac",
+    "reads_5UTR",
+    "reads_3UTR",
+    "reads_coverage_frac",
+    "reads_entropy",
+    "reads_skew",
+]
+
+
+RIBOTIE_MQC_HEADER = """
+# parent_id: 'ribotie'
+# parent_name: "RiboTIE"
+# parent_description: "Overview of open reading frames called as translating by RiboTIE"
+# """
+
+START_CODON_MQC_HEADER = """
+# id: 'ribotie_start_codon_counts'
+# section_name: 'Start Codon'
+# description: "Start codon counts of all open reading frames called by RiboTIE"
+# plot_type: 'bargraph'
+# anchor: 'orf_start_codon_counts'
+# pconfig:
+#    id: "orf_start_codon_counts_plot"
+#    title: "RiboTIE: Start Codons"
+#    colors:
+#        ATG : "#f8d7da"
+#    xlab: "# ORFs"
+#    cpswitch_counts_label: "Number of ORFs"
+"""
+
+BIOTYPE_VARIANT_MQC_HEADER = """
+# id: 'ribotie_biotype_counts_variant'
+# section_name: 'Transcript Biotypes (varRNA-ORF)'
+# description: "Transcript biotypes of varRNA-ORFs called by RiboTIE"
+# plot_type: 'bargraph'
+# anchor: 'transcript_biotype_variant_counts'
+# pconfig:
+#    id: "transcript_biotype_counts_variant_plot"
+#    title: "RiboTIE: varRNA-ORFs Transcript Biotypes"
+#    xlab: "# ORFs"
+#    cpswitch_counts_label: "Number of ORFs"
+"""
+
+ORF_TYPE_MQC_HEADER = """
+# id: 'ribotie_orftype_counts'
+# section_name: 'ORF types'
+# description: "ORF types of all open reading frames called by RiboTIE"
+# plot_type: 'bargraph'
+# anchor: 'transcript_orftype_counts'
+# pconfig:
+#    id: "transcript_orftype_counts_plot"
+#    title: "RiboTIE: ORF Types"
+#    xlab: "# ORFs"
+#    cpswitch_counts_label: "Number of ORFs"
+"""
+
+ORF_LEN_MQC_HEADER = """
+# id: 'ribotie_orflen_hist'
+# section_name: 'ORF lengths'
+# description: "ORF lengths of all open reading frames called by RiboTIE"
+# plot_type: 'linegraph'
+# anchor: 'transcript_orflength_hist'
+# pconfig:
+#    id: "transcript_orflength_hist_plot"
+#    title: "RiboTIE: ORF lengths"
+#    xlab: "Length"
+#    xLog: "True"
+"""
+
+ORF_TYPE_ORDER = [
+    "annotated CDS",
+    "N-terminal truncation",
+    "N-terminal extension",
+    "C-terminal truncation",
+    "C-terminal extension",
+    "uORF",
+    "uoORF",
+    "dORF",
+    "doORF",
+    "intORF",
+    "lncRNA-ORF",
+    "varRNA-ORF",
+]
+
+ORF_BIOTYPE_ORDER = [
+    "retained_intron",
+    "protein_coding",
+    "protein_coding_CDS_not_defined",
+    "nonsense_mediated_decay",
+    "processed_pseudogene",
+    "unprocessed_pseudogene",
+    "transcribed_unprocessed_pseudogene",
+    "transcribed_processed_pseudogene",
+    "translated_processed_pseudogene",
+    "transcribed_unitary_pseudogene",
+    "processed_transcript",
+    "TEC",
+    "artifact",
+    "non_stop_decay",
+    "misc_RNA",
+]
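Annotation, not part of the patch: a minimal sketch of how the new module-level lookup tables compose — integer-encoded sequences map back to DNA via `IDX_DNA_DICT`, codons translate via `CDN_PROT_DICT`. The helper names `decode_dna` and `translate` are hypothetical, for illustration only.

```python
# Illustrative only -- not part of this diff.
from transcript_transformer import CDN_PROT_DICT, IDX_DNA_DICT

def decode_dna(vec):
    # hypothetical helper: int-encoded sequence (0..4) -> DNA string
    return "".join(IDX_DNA_DICT[i] for i in vec)

def translate(dna):
    # hypothetical helper: translate codons until the first stop ("_");
    # unknown codons fall back to "_" as a hedge (the table only has "NNN")
    prot = []
    for i in range(0, len(dna) - 2, 3):
        aa = CDN_PROT_DICT.get(dna[i : i + 3], "_")
        if aa == "_":
            break
        prot.append(aa)
    return "".join(prot)

assert translate(decode_dna([0, 1, 3, 3, 3, 1, 1, 0, 0])) == "MG"  # ATG GGT TAA
```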
diff --git a/transcript_transformer/argparser.py b/transcript_transformer/argparser.py
index 1ca4e32..8bdeccb 100644
--- a/transcript_transformer/argparser.py
+++ b/transcript_transformer/argparser.py
@@ -271,7 +271,7 @@ def add_train_loading_args(self, pretrain=False, auto=False):
             nargs="*",
             default=[],
             help="chromosomes used for training. If not specified, "
-            "training is performed on all available chromosomes excluding val/test contigs",
+            "training is performed on all available chromosomes excluding val/test seqnames",
         )
         dl_parse.add_argument(
             "--val",
@@ -483,7 +483,7 @@ def parse_arguments(self, argv, configs=[]):
             args.model_dir = model_dir
         # create output dir if non-existent
         if args.out_prefix:
-            os.makedirs(os.path.dirname(args.out_prefix), exist_ok=True)
+            os.makedirs(os.path.dirname(os.path.abspath(args.out_prefix)), exist_ok=True)
         # backward compatibility
         if "seq" in args:
             args.use_seq = args.seq
@@ -517,7 +517,7 @@ def parse_arguments(self, argv, configs=[]):
         # conditions used to remove transcripts from training/validation data
         conds = {"global": {}, "grouped": [{} for l in range(len(args.ribo_ids))]}
-        conds["global"]["tr_len"] = lambda x: np.logical_and(
+        conds["global"]["transcript_len"] = lambda x: np.logical_and(
            x > args.min_seq_len, x < args.max_seq_len
         )
         if args.cond is not None:
@@ -557,7 +557,7 @@ def parse_arguments(self, argv, configs=[]):
         # Default values
         args.exp_path = "transcript"
         args.y_path = "tis"
-        args.seqn_path = "contig"
-        args.id_path = "id"
+        args.seqn_path = "seqname"
+        args.id_path = "transcript_id"
         print(args)
         return args
diff --git a/transcript_transformer/configs/riboformer_defaults.yml b/transcript_transformer/configs/ribotie_defaults.yml
similarity index 92%
rename from transcript_transformer/configs/riboformer_defaults.yml
rename to transcript_transformer/configs/ribotie_defaults.yml
index d2f4694..16009a4 100644
--- a/transcript_transformer/configs/riboformer_defaults.yml
+++ b/transcript_transformer/configs/ribotie_defaults.yml
@@ -23,4 +23,5 @@ patience: 8
 cond :
   ribo:
     num_reads : "x > 6"
+  has_annotated_start_codon: "x"
diff --git a/transcript_transformer/configs/tis_transformer_defaults.yml b/transcript_transformer/configs/tis_transformer_defaults.yml
index 2252a96..61d71e0 100644
--- a/transcript_transformer/configs/tis_transformer_defaults.yml
+++ b/transcript_transformer/configs/tis_transformer_defaults.yml
@@ -15,4 +15,6 @@ ff_glu: false
 emb_dropout: 0.1
 ff_dropout: 0.1
 attn_dropout: 0.1
-local_window_size: 256
\ No newline at end of file
+local_window_size: 256
+cond:
+  has_annotated_start_codon: "x"
\ No newline at end of file
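Annotation, not part of the patch: the `cond` entries express per-column filters on the dataset, with `x` standing for the column named by the key (so `"x > 6"` keeps transcripts with more than six reads, and `"x"` keeps those where `has_annotated_start_codon` is true). This mirrors the lambda-based `transcript_len` condition in argparser.py; whether the package applies the strings via `eval` exactly as below is an assumption.

```python
# Sketch under the assumption that a cond string is evaluated with `x`
# bound to the corresponding dataset column.
import numpy as np

num_reads = np.array([0, 3, 12, 50])        # toy per-transcript read counts
mask = eval("x > 6", {}, {"x": num_reads})  # from `num_reads : "x > 6"`
print(mask)                                 # [False False  True  True]
```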
"has_annotated_start_codon", + "has_annotated_stop_codon", + "canonical_TIS_idx", + "canonical_TIS_coord", + "canonical_TTS_idx", + "canonical_TTS_coord", + "canonical_LTS_idx", + "canonical_LTS_coord", + "transcript_len", + "canonical_protein_seq", +] +DROPPED_HEADERS = [ + "end", + "exon_id", + "exon_version", + "exon_number", + "feature", + "frame", + "score", + "start", +] def process_seq_data(h5_path, gtf_path, fa_path, backup_path, backup=True): pulled = False if not backup_path: backup_path = os.path.splitext(gtf_path)[0] + ".h5" - if os.path.abspath(backup_path) == os.path.abspath(os.path.dirname(h5_path)): - print(f"!-> Backup path identical to h5 output path, disabling backup...") + if os.path.abspath(backup_path) == os.path.abspath(h5_path): + print( + f"!-> Backup path identical to h5 output path, no database copy will be created..." + ) backup = False elif not os.path.isfile(h5_path) and os.path.isfile(backup_path): print(f"--> Processed assembly data restored ({backup_path})") @@ -70,7 +82,7 @@ def process_seq_data(h5_path, gtf_path, fa_path, backup_path, backup=True): ) f.close() else: - data_dict = parse_transcriptome(gtf_path, fa_path) + db = parse_transcriptome(gtf_path, fa_path) no_handle = True max_wait = 900 waited = 0 @@ -84,7 +96,7 @@ def process_seq_data(h5_path, gtf_path, fa_path, backup_path, backup=True): waited += 120 if not no_handle: try: - f = save_transcriptome_to_h5(f, data_dict) + f = save_transcriptome_to_h5(f, db) f.close() if backup and (not pulled): shutil.copy(h5_path, backup_path) @@ -99,9 +111,11 @@ def process_ribo_data( h5_path, ribo_paths, overwrite=False, parallel=False, low_memory=False ): f = h5py.File(h5_path, "r") - tr_ids = pl.from_numpy(np.array(f["transcript/id"])).to_series().cast(pl.Utf8) - tr_lens = pl.from_numpy(np.array(f["transcript/tr_len"])).to_series() - header_dict = {2: "tr_ID", 3: "pos", 9: "read"} + tr_ids = ( + pl.from_numpy(np.array(f["transcript/transcript_id"])).to_series().cast(pl.Utf8) + ) + tr_lens = pl.from_numpy(np.array(f["transcript/transcript_len"])).to_series() + header_dict = {2: "transcript_id", 3: "pos", 9: "read"} ribo_to_parse = deepcopy(ribo_paths) for experiment, path in ribo_paths.items(): cond_1 = ( @@ -166,229 +180,195 @@ def process_ribo_data( del f[f"transcript/riboseq/{experiment}"] -def save_transcriptome_to_h5(f, data_dict): +def save_transcriptome_to_h5(f, db): print("Save data in hdf5 files...") dt8 = h5py.vlen_dtype(np.dtype("int8")) dt = h5py.vlen_dtype(np.dtype("int")) grp = f.create_group("transcript") - for key, array in data_dict.items(): - if key in [ - "id", - "contig", - "gene_id", - "gene_name", - "strand", - "biotype", - "tag", - "support_lvl", - "canonical_protein_id", - ]: - if key != "id": - array = [a if a != None else "" for a in array] - grp.create_dataset( - key, data=array, dtype=f" 0: + grp.create_dataset(key, data=array, dtype=f" Loading assembly data...") genome = pyfaidx.Fasta(fa_path) contig_list = pl.Series(genome.keys()) gtf = read_gtf(gtf_path, result_type="polars") - gtf = gtf.with_columns(pl.col("exon_number").cast(pl.Int32, strict=False)) - headers = { - "transcript_id": "id", - "gene_id": "gene_id", - "gene_name": "gene_name", - "strand": "strand", - "transcript_biotype": "biotype", - "tag": "tag", - "transcript_support_level": "support_lvl", - } - - print("Extracting transcripts and metadata...") - data_dict = { - "id": [], - "seq": [], - "tis": [], - "gene_id": [], - "gene_name": [], - "contig": [], - "strand": [], - "biotype": [], - "tag": [], - "support_lvl": 
     genome = pyfaidx.Fasta(fa_path)
     contig_list = pl.Series(genome.keys())
     gtf = read_gtf(gtf_path, result_type="polars")
-    gtf = gtf.with_columns(pl.col("exon_number").cast(pl.Int32, strict=False))
-    headers = {
-        "transcript_id": "id",
-        "gene_id": "gene_id",
-        "gene_name": "gene_name",
-        "strand": "strand",
-        "transcript_biotype": "biotype",
-        "tag": "tag",
-        "transcript_support_level": "support_lvl",
-    }
-
-    print("Extracting transcripts and metadata...")
-    data_dict = {
-        "id": [],
-        "seq": [],
-        "tis": [],
-        "gene_id": [],
-        "gene_name": [],
-        "contig": [],
-        "strand": [],
-        "biotype": [],
-        "tag": [],
-        "support_lvl": [],
-        "canonical_TIS_exon_idx": [],
-        "exon_idxs": [],
-        "exon_coords": [],
-        "CDS_idxs": [],
-        "CDS_coords": [],
-        "canonical_TIS_idx": [],
-        "canonical_TIS_coord": [],
-        "canonical_TTS_idx": [],
-        "canonical_TTS_coord": [],
-        "tr_len": [],
-        "canonical_protein_id": [],
-        "canonical_protein_seq": [],
-    }
-    assert "transcript_id" in gtf.columns, "transcript_id column missing in gtf file"
-    assert "strand" in gtf.columns, "strand column missing in gtf file"
-    for key, _ in headers.items():
-        if key not in gtf.columns:
-            gtf = gtf.with_columns(pl.lit("").alias(key))
+    # use biobear instead
+    # session = bb.connect()
+    # gtf = session.sql(f"SELECT * FROM gtf_scan('{gtf_path}')").to_polars()
 
-    for contig in contig_list:
-        print(f"{contig}...")
-        gtf_set = gtf.filter(pl.col("seqname") == str(contig))
-        tr_set = gtf_set["transcript_id"].unique()
-        tr_set = tr_set.filter(tr_set != "")
-        for i, id in tqdm(enumerate(tr_set), total=len(tr_set)):
-            # obtain transcript information
-            gtf_tr = gtf_set.filter(pl.col("transcript_id") == id).sort(
-                by="exon_number"
-            )
-            keys, values = (
-                (
-                    gtf_tr.filter(pl.col("feature") == "transcript").select(
-                        list(headers.keys())
-                    )
-                )
-                .melt()
-                .to_dict()
-                .values()
-            )
-            for k, i in zip(list(headers.values()), values):
-                data_dict[k].append(i)
-
-            # obtain and sort exon information (strings have wrong sortin (e.g. 10, 11, 2, 3, ...))
-            exons = gtf_tr.filter(pl.col("feature") == "exon")
-            exon_lens = (abs(exons["start"] - exons["end"]) + 1).to_numpy()
-            cum_exon_lens = np.insert(np.cumsum(exon_lens), 0, 0)
-            cdss = gtf_tr.filter(pl.col("feature") == "CDS")
-            if len(cdss) > 0:
-                cds_lens = (abs(cdss["start"] - cdss["end"]) + 1).to_numpy()
-                cum_cds_lens = np.insert(np.cumsum(cds_lens), 0, 0)
-                cds_idxs = np.vstack((cum_cds_lens[:-1], cum_cds_lens[1:])).T.ravel()
-
-            strand_is_pos = (exons["strand"] == "+").any()
-            data_dict["tr_len"].append(exon_lens.sum())
+    # import exon number as int (strings have wrong sorting (e.g. 10, 11, 2,...))
+    gtf = gtf.with_columns(pl.col("exon_number").cast(pl.Int32, strict=False))
+    # ensure all required fields are listed
+    assert np.isin(
+        REQ_HEADERS, gtf.columns
+    ).all(), f"Not all required properties in gtf file: {REQ_HEADERS}"
+    # evaluate extra columns
+    xtr_cols = np.array(gtf.columns)[
+        ~pl.Series(gtf.columns).is_in(REQ_HEADERS).to_numpy()
+    ]
+    data_dict_keys = np.array(REQ_HEADERS + CUSTOM_HEADERS + list(xtr_cols))
+    data_dict = {k: [] for k in CUSTOM_HEADERS}
+    data_cols_in_gtf = data_dict_keys[np.isin(data_dict_keys, gtf.columns)]
+
+    print("--> Importing transcripts and metadata...")
+    gtf_set = gtf.filter(
+        # exclude transcript ids that are empty
+        pl.col("transcript_id") != "",
+        pl.col("feature").is_in(
+            ["transcript", "exon", "CDS", "start_codon", "stop_codon"]
+        ),
+    ).sort(["seqname", "transcript_id", "exon_number"])
+    gtf_set = gtf_set.with_columns(
+        (abs(pl.col("start") - pl.col("end")) + 1).alias("feature_length")
+    )
+    trs = gtf_set["transcript_id"].unique(maintain_order=True)
Please ensure" - "exons are marked with the correct transcript id" - ) + db = pl.DataFrame(data={"transcript_id": trs}) + db = db.join(gtf.filter(pl.col("feature") == "transcript"), on="transcript_id") - # obtain TISs, select first in case of split (intron) start codon - # TODO: when multiple TISs are supported, code needs update - start_codon = ( - gtf_tr.filter(pl.col("feature") == "start_codon").slice(0, 1).to_dicts() - ) - stop_codon = ( - gtf_tr.filter(pl.col("feature") == "stop_codon").slice(0, 1).to_dicts() + for tr_id, gtf_tr in tqdm( + gtf_set.group_by("transcript_id", maintain_order=True), total=len(db) + ): + is_pos_strand = (gtf_tr["strand"] == "+").any() + ftrs = {} + ftr_cum_lens = {} + ftr_idxs = {} + for feature, feature_df in gtf_tr.group_by("feature", maintain_order=True): + ftrs[feature[0]] = feature_df + ftr_lens = feature_df["feature_length"].drop_nulls().to_numpy() + cum_lens = np.insert(np.cumsum(ftr_lens), 0, 0) + ftr_cum_lens[feature[0]] = cum_lens + # feature boundaries; tuples flattened into single vector (e.g. [0,10,10,12,12,20]) + ftr_idxs[feature[0]] = np.vstack((cum_lens[:-1], cum_lens[1:])).T.ravel() + + data_dict["transcript_len"].append(ftr_cum_lens["exon"].max()) + if ftr_cum_lens["exon"].max() == 0: + print( + "WARNING: No exons found for transcript. This should not happen. Please ensure" + "exons are marked with the correct transcript id" ) - - target_seq = np.full(exon_lens.sum(), False) - - if len(start_codon) > 0: - # use as index for sorted dfs - start_codon = start_codon[0] - exon_i = start_codon["exon_number"] - 1 - exon = exons[exon_i].to_dicts()[0] - if strand_is_pos: - tis = start_codon["start"] - tis_idx = cum_exon_lens[exon_i] + tis - exon["start"] - if len(stop_codon) > 0: - tts = stop_codon[0]["start"] - else: - tts = exons[-1].to_dicts()[0]["end"] + # TODO: when multiple TISs are supported, code needs update + # init empty boolean to denote TIS locations + target_seq = np.full(ftr_cum_lens["exon"].max(), False) + + exon_coords = [] + exon_seqs = [] + for exon_i, exon in enumerate(ftrs["exon"].iter_rows(named=True)): + # get sequence + exon_seq = slice_gen( + genome[exon["seqname"]], + exon["start"], + exon["end"], + exon["strand"], + to_vec=True, + ).astype(np.int16) + exon_coords.append(exon["start"]) + exon_coords.append(exon["end"]) + exon_seqs.append(exon_seq) + seq = np.concatenate(exon_seqs) + + if "CDS" in ftrs: + # select first in case of split (intron) start codon + first_cds = ftrs["CDS"][0].to_dicts()[0] + exon_i = first_cds["exon_number"] - 1 + exon = ftrs["exon"][exon_i].to_dicts()[0] + # shift CDS transcript idxs based on start exon + exon_shift = ftr_cum_lens["exon"][exon_i] + if is_pos_strand: + # shift CDS transcript idxs based on cds start in exon + in_exon_shift = ftrs["CDS"][0, "start"] - exon["start"] + tis = first_cds["start"] + tis_idx = ftr_cum_lens["exon"][exon_i] + tis - exon["start"] + lts = ftrs["CDS"][-1].to_dicts()[0]["end"] + if "stop_codon" in ftrs: + tts = ftrs["stop_codon"][0][0, "start"] else: - tis = start_codon["end"] - tis_idx = cum_exon_lens[exon_i] + exon["end"] - tis - if len(stop_codon) > 0: - tts = stop_codon[0]["end"] - else: - tts = exons[-1].to_dicts()[0]["start"] - - target_seq[tis_idx] = 1 - data_dict["canonical_TIS_exon_idx"].append(exon_i) - data_dict["canonical_TIS_idx"].append(tis_idx) - data_dict["canonical_TTS_idx"].append(tis_idx + sum(cds_lens)) - # remove potential empty entries from protein_id ("") - prot_ids = ( - gtf_tr["protein_id"].to_frame().filter(pl.all() != "").to_series() - 
-                )
-                prot_id = prot_ids.unique(maintain_order=True)[0]
-                data_dict["canonical_protein_id"].append(prot_id)
-                data_dict["canonical_TIS_coord"].append(tis)
-                data_dict["canonical_TTS_coord"].append(tts)
-
+                    tts = -1
             else:
-                data_dict["canonical_TIS_exon_idx"].append(-1)
-                data_dict["canonical_TIS_idx"].append(-1)
-                data_dict["canonical_TTS_idx"].append(-1)
-                data_dict["canonical_protein_id"].append("")
-                data_dict["canonical_TIS_coord"].append(-1)
-                data_dict["canonical_TTS_coord"].append(-1)
-            # some transcripts have CDSs but no start codons...
-
-            if len(cdss) > 0:
-                if strand_is_pos:
-                    exon_i = cdss[0, "exon_number"] - 1
-                    exon_shift = cum_exon_lens[exon_i]
-                    cds_offset = exon_shift + cdss[0, "start"] - exons[exon_i, "start"]
+                # shift CDS transcript idxs based on cds start in exon
+                in_exon_shift = exon["end"] - ftrs["CDS"][0, "end"]
+                tis = first_cds["end"]
+                tis_idx = ftr_cum_lens["exon"][exon_i] + exon["end"] - tis
+                lts = ftrs["CDS"][-1].to_dicts()[0]["start"]
+                if "stop_codon" in ftrs:
+                    tts = ftrs["stop_codon"][0][0, "end"]
                 else:
-                    exon_i = cdss[-1, "exon_number"] - 1
-                    exon_shift = cum_exon_lens[exon_i]
-                    cds_offset = exon_shift + exons[exon_i, "end"] - cdss[-1, "end"]
+                    tts = -1
+            target_seq[tis_idx] = 1
+            DNA_frag = vec2DNA(seq[tis_idx:])
+            prot, _, _ = construct_prot(DNA_frag)
+            data_dict["has_annotated_stop_codon"].append("stop_codon" in ftrs)
+            data_dict["has_annotated_start_codon"].append("start_codon" in ftrs)
+            data_dict["CDS_idxs"].append(ftr_idxs["CDS"] + exon_shift + in_exon_shift)
+            data_dict["CDS_coords"].append(
+                ftrs["CDS"][:, ["start", "end"]]
+                .transpose()
+                .unpivot()["value"]
+                .to_numpy()
+            )
+            data_dict["canonical_TIS_exon"].append(exon_i + 1)
+            data_dict["canonical_TIS_idx"].append(tis_idx)
+            # LTS: Last Translation Site; 1 nucleotide upstream of TTS
+            tts_idx = tis_idx + ftr_cum_lens["CDS"].max()
+            data_dict["canonical_TTS_idx"].append(tts_idx)
+            data_dict["canonical_LTS_idx"].append(tts_idx - 1)
+            data_dict["canonical_TIS_coord"].append(tis)
+            data_dict["canonical_TTS_coord"].append(tts)
+            data_dict["canonical_LTS_coord"].append(lts)
             data_dict["canonical_protein_seq"].append(prot)
-            data_dict["contig"].append(contig)
+        else:
+            data_dict["has_annotated_stop_codon"].append(False)
+            data_dict["has_annotated_start_codon"].append(False)
+            data_dict["CDS_idxs"].append(np.empty(0, dtype=int))
+            data_dict["CDS_coords"].append(np.empty(0, dtype=int))
+            data_dict["canonical_TIS_exon"].append(-1)
data_dict["canonical_TIS_idx"].append(-1) + data_dict["canonical_TTS_idx"].append(-1) + data_dict["canonical_LTS_idx"].append(-1) + data_dict["canonical_TIS_coord"].append(-1) + data_dict["canonical_TTS_coord"].append(-1) + data_dict["canonical_LTS_coord"].append(-1) + data_dict["canonical_protein_seq"].append(None) + data_dict["exon_idxs"].append(ftr_idxs["exon"]) + data_dict["exon_coords"].append(np.array(exon_coords)) + data_dict["seq"].append(seq) + data_dict["tis"].append(target_seq) + data_dict["transcript_id"].append(gtf_tr["transcript_id"].unique()[0]) + + db_ext = pl.from_dict(data_dict) + db = db_ext.join(db, on="transcript_id", how="left") + # drop exon info that is not correct at transcript-level + db = db.drop(DROPPED_HEADERS, strict=False) + # vectorize protein sequences (less storage) + db = db.with_columns( + pl.col("canonical_protein_seq") + .fill_null("") + .map_elements( + prot2vec, + pl.List(pl.Int64), + ) + .cast(pl.List(pl.Int8)) + ) - return data_dict + return db def parse_ribo_reads(df, read_lens, f_ids, f_lens): @@ -397,8 +377,8 @@ def parse_ribo_reads(df, read_lens, f_ids, f_lens): print("Filtering on read lens...") df = df.with_columns(pl.col("read").str.len_chars().alias("read_len")) df = df.filter(pl.col("read_len").is_in(list(read_lens))) - df = df.sort("tr_ID") - id_lib = df["tr_ID"].unique(maintain_order=True) + df = df.sort("transcript_id") + id_lib = df["transcript_id"].unique(maintain_order=True) mask_f = f_ids.is_in(id_lib) print("Constructing empty datasets...") @@ -419,7 +399,7 @@ def parse_ribo_reads(df, read_lens, f_ids, f_lens): arg_sort = f_ids.arg_sort() h5_idxs = arg_sort[f_ids[arg_sort].search_sorted(id_lib)] for idx, (_, group) in tqdm( - zip(h5_idxs, df.group_by("tr_ID", maintain_order=True)), + zip(h5_idxs, df.group_by("transcript_id", maintain_order=True)), total=len(id_lib), ): tr_reads = np.zeros((num_read_lens, f_lens[idx]), dtype=np.int32) diff --git a/transcript_transformer/pretrained/riboformer_models/50perc_06_23.yml b/transcript_transformer/pretrained/ribotie_models/50perc_06_23.yml similarity index 100% rename from transcript_transformer/pretrained/riboformer_models/50perc_06_23.yml rename to transcript_transformer/pretrained/ribotie_models/50perc_06_23.yml diff --git a/transcript_transformer/pretrained/riboformer_models/50perc_06_23_f1.ckpt b/transcript_transformer/pretrained/ribotie_models/50perc_06_23_f1.ckpt similarity index 100% rename from transcript_transformer/pretrained/riboformer_models/50perc_06_23_f1.ckpt rename to transcript_transformer/pretrained/ribotie_models/50perc_06_23_f1.ckpt diff --git a/transcript_transformer/pretrained/riboformer_models/50perc_06_23_f2.ckpt b/transcript_transformer/pretrained/ribotie_models/50perc_06_23_f2.ckpt similarity index 100% rename from transcript_transformer/pretrained/riboformer_models/50perc_06_23_f2.ckpt rename to transcript_transformer/pretrained/ribotie_models/50perc_06_23_f2.ckpt diff --git a/transcript_transformer/processing.py b/transcript_transformer/processing.py index 7faa002..01eb81f 100644 --- a/transcript_transformer/processing.py +++ b/transcript_transformer/processing.py @@ -1,185 +1,78 @@ import numpy as np from tqdm import tqdm import h5py -import h5max import pandas as pd import polars as pl - +from scipy.stats import entropy +from scipy.sparse import csr_matrix + + +from transcript_transformer import ( + RIBOTIE_MQC_HEADER, + START_CODON_MQC_HEADER, + BIOTYPE_VARIANT_MQC_HEADER, + ORF_TYPE_MQC_HEADER, + ORF_TYPE_ORDER, + ORF_BIOTYPE_ORDER, + IDX_PROT_DICT, + 
diff --git a/transcript_transformer/pretrained/riboformer_models/50perc_06_23.yml b/transcript_transformer/pretrained/ribotie_models/50perc_06_23.yml
similarity index 100%
rename from transcript_transformer/pretrained/riboformer_models/50perc_06_23.yml
rename to transcript_transformer/pretrained/ribotie_models/50perc_06_23.yml
diff --git a/transcript_transformer/pretrained/riboformer_models/50perc_06_23_f1.ckpt b/transcript_transformer/pretrained/ribotie_models/50perc_06_23_f1.ckpt
similarity index 100%
rename from transcript_transformer/pretrained/riboformer_models/50perc_06_23_f1.ckpt
rename to transcript_transformer/pretrained/ribotie_models/50perc_06_23_f1.ckpt
diff --git a/transcript_transformer/pretrained/riboformer_models/50perc_06_23_f2.ckpt b/transcript_transformer/pretrained/ribotie_models/50perc_06_23_f2.ckpt
similarity index 100%
rename from transcript_transformer/pretrained/riboformer_models/50perc_06_23_f2.ckpt
rename to transcript_transformer/pretrained/ribotie_models/50perc_06_23_f2.ckpt
diff --git a/transcript_transformer/processing.py b/transcript_transformer/processing.py
index 7faa002..01eb81f 100644
--- a/transcript_transformer/processing.py
+++ b/transcript_transformer/processing.py
@@ -1,185 +1,78 @@
 import numpy as np
 from tqdm import tqdm
 import h5py
-import h5max
 import pandas as pd
 import polars as pl
-
+from scipy.stats import entropy
+from scipy.sparse import csr_matrix
+
+
+from transcript_transformer import (
+    RIBOTIE_MQC_HEADER,
+    START_CODON_MQC_HEADER,
+    BIOTYPE_VARIANT_MQC_HEADER,
+    ORF_TYPE_MQC_HEADER,
+    ORF_TYPE_ORDER,
+    ORF_BIOTYPE_ORDER,
+    IDX_PROT_DICT,
+    IDX_DNA_DICT,
+    STANDARD_HEADERS,
+    RENAME_HEADERS,
+    STANDARD_OUT_HEADERS,
+    RIBO_OUT_HEADERS,
+)
 from .util_functions import (
     construct_prot,
     time,
-    vec2DNA,
     find_distant_exon_coord,
     transcript_region_to_exons,
+    get_str2str_idx_map,
 )
 
-HEADERS = [
-    "id",
-    "contig",
-    "biotype",
-    "strand",
-    "canonical_TIS_coord",
-    "canonical_TIS_exon_idx",
-    "canonical_TIS_idx",
-    "canonical_TTS_coord",
-    "canonical_TTS_idx",
-    "canonical_protein_id",
-    "exon_coords",
-    "exon_idxs",
-    "gene_id",
-    "gene_name",
-    "support_lvl",
-    "tag",
-    "tr_len",
-]
-
-
-OUT_HEADERS = [
-    "seqname",
-    "ORF_id",
-    "tr_id",
-    "TIS_pos",
-    "output",
-    "output_rank",
-    "seq_output",
-    "start_codon",
-    "stop_codon",
-    "ORF_len",
-    "TTS_pos",
-    "TTS_on_transcript",
-    "reads_in_tr",
-    "reads_in_ORF",
-    "reads_out_ORF",
-    "in_frame_read_perc",
-    "ORF_type",
-    "ORF_equals_CDS",
-    "tr_biotype",
-    "tr_support_lvl",
-    "tr_tag",
-    "tr_len",
-    "dist_from_canonical_TIS",
-    "frame_wrt_canonical_TIS",
-    "correction",
-    "TIS_coord",
-    "TIS_exon",
-    "TTS_coord",
-    "TTS_exon",
-    "strand",
-    "gene_id",
-    "gene_name",
-    "canonical_TIS_coord",
-    "canonical_TIS_exon_idx",
-    "canonical_TIS_idx",
-    "canonical_TTS_coord",
-    "canonical_TTS_idx",
-    "canonical_protein_id",
-    "protein_seq",
-]
-
-DECODE = [
-    "seqname",
-    "tr_id",
-    "tr_biotype",
-    "tr_support_lvl",
-    "tr_tag",
-    "strand",
-    "gene_id",
-    "gene_name",
-    "canonical_protein_id",
-]
-
-RIBOTIE_MQC_HEADER = """
-# parent_id: 'ribotie'
-# parent_name: "RiboTIE"
-# parent_description: "Overview of open reading frames called as translating by RiboTIE"
-# """
-
-START_CODON_MQC_HEADER = """
-# id: 'ribotie_start_codon_counts'
-# section_name: 'Start Codon'
-# description: "Start codon counts of all open reading frames called by RiboTIE"
-# plot_type: 'bargraph'
-# anchor: 'orf_start_codon_counts'
-# pconfig:
-#    id: "orf_start_codon_counts_plot"
-#    title: "RiboTIE: Start Codons"
-#    colors:
-#        ATG : "#f8d7da"
-#    xlab: "# ORFs"
-#    cpswitch_counts_label: "Number of ORFs"
-"""
-
-BIOTYPE_VARIANT_MQC_HEADER = """
-# id: 'ribotie_biotype_counts_variant'
-# section_name: 'Transcript Biotypes (CDS Variant)'
-# description: "Transcript biotypes of 'CDS variant' (see ORF types) open reading frames called by RiboTIE"
-# plot_type: 'bargraph'
-# anchor: 'transcript_biotype_variant_counts'
-# pconfig:
-#    id: "transcript_biotype_counts_variant_plot"
-#    title: "RiboTIE: Transcript Biotypes (CDS Variant)"
-#    xlab: "# ORFs"
-#    cpswitch_counts_label: "Number of ORFs"
-"""
-ORF_TYPE_MQC_HEADER = """
-# id: 'ribotie_orftype_counts'
-# section_name: 'ORF types'
-# description: "ORF types of all open reading frames called by RiboTIE"
-# plot_type: 'bargraph'
-# anchor: 'transcript_orftype_counts'
-# pconfig:
-#    id: "transcript_orftype_counts_plot"
-#    title: "RiboTIE: ORF Types"
-#    xlab: "# ORFs"
-#    cpswitch_counts_label: "Number of ORFs"
-"""
-
-ORF_LEN_MQC_HEADER = """
-# id: 'ribotie_orflen_hist'
-# section_name: 'ORF lengths'
-# description: "ORF lengths of all open reading frames called by RiboTIE"
-# plot_type: 'linegraph'
-# anchor: 'transcript_orflength_hist'
-# pconfig:
-#    id: "transcript_orflength_hist_plot"
-#    title: "RiboTIE: ORF lengths"
-#    xlab: "Length"
-#    xLog: "True"
-"""
-
-ORF_TYPE_ORDER = [
-    "annotated CDS",
-    "N-terminal truncation",
-    "N-terminal extension",
-    "CDS variant",
-    "uORF",
-    "uoORF",
-    "dORF",
-    "doORF",
-    "intORF",
-    "lncRNA-ORF",
-    "other",
-]
-
-ORF_BIOTYPE_ORDER = [
-    "retained_intron",
-    "protein_coding",
-    "protein_coding_CDS_not_defined",
-    "nonsense_mediated_decay",
- "processed_pseudogene", - "unprocessed_pseudogene", - "transcribed_unprocessed_pseudogene", - "transcribed_processed_pseudogene", - "translated_processed_pseudogene", - "transcribed_unitary_pseudogene", - "processed_transcript", - "TEC", - "artifact", - "non_stop_decay", - "misc_RNA", -] +def eval_overlap(ORF_id, CDS_exon_start, CDS_exon_end, ORF_exon_start, ORF_exon_end): + df = ( + pl.DataFrame( + [ORF_id, CDS_exon_start, CDS_exon_end, ORF_exon_start, ORF_exon_end] + ) + .with_columns( + overlap=(pl.col("ORF_exon_start") - pl.col("ORF_exon_end")).abs() + + 1 + - (pl.col("CDS_exon_start") - pl.col("ORF_exon_start")).clip(0) + - (pl.col("ORF_exon_end") - pl.col("CDS_exon_end")).clip(0), + ORF_exon_len=(pl.col("ORF_exon_end") - pl.col("ORF_exon_start")).abs() + 1, + ) + .group_by("ORF_exon_start", "ORF_exon_end") + .agg(pl.all().get(pl.col("overlap").arg_max())) + .with_columns( + ORF_coords_no_CDS=pl.when( + (pl.col("overlap") > 0) & (pl.col("overlap") < pl.col("ORF_exon_len")) + ) + .then( + pl.concat_list( + pl.when(pl.col("ORF_exon_start") < pl.col("CDS_exon_start")) + .then(pl.col("ORF_exon_start")) + .otherwise(pl.col("CDS_exon_end")), + pl.when(pl.col("ORF_exon_start") < pl.col("CDS_exon_start")) + .then(pl.col("CDS_exon_start")) + .otherwise(pl.col("ORF_exon_end")), + ) + ) + .otherwise(pl.lit([])) + ) + .group_by("ORF_id") + .agg( + pl.col("overlap").sum(), pl.col("ORF_coords_no_CDS").flatten().drop_nulls() + ) + ) + if len(df) > 0: + non_CDS_coords = df[0, "ORF_coords_no_CDS"].to_list() + overlap = df[0, "overlap"] + else: + non_CDS_coords = [] + overlap = 0 + return overlap, non_CDS_coords def construct_output_table( @@ -194,252 +87,614 @@ def construct_output_table( exclude_invalid_TTS=True, ribo=None, parallel=False, + unfiltered=False, ): - f = h5py.File(h5_path, "r")["transcript"] - f_tr_ids = np.array(f["id"]) - has_seq_output = "seq_output" in f.keys() + f = h5py.File(h5_path, "r") + f_tr_ids = np.array(f["transcript/transcript_id"]) + f_headers = pl.Series(f["transcript"].keys()) + f_headers = f_headers.filter(~f_headers.is_in(["riboseq", "tis"])) + xtr_heads = ( + pl.Series(f_headers) + .filter(~pl.Series(f_headers).is_in(STANDARD_HEADERS)) + .to_list() + ) + + has_tis_transformer_score = "tis_transformer_score" in f_headers has_ribo_output = ribo is not None - assert has_seq_output or has_ribo_output, "no model predictions found" + assert has_tis_transformer_score or has_ribo_output, "no model predictions found" print(f"--> Processing {out_prefix}...") if has_ribo_output: + prefix = "ribotie_" + tool_headers = ["ribotie_score", "ribotie_rank"] tr_ids = np.array([o[0].split(b"|")[1] for o in ribo]) ribo_id = ribo[0][0].split(b"|")[0] - xsorted = np.argsort(f_tr_ids) - pred_to_h5_args = xsorted[np.searchsorted(f_tr_ids[xsorted], tr_ids)] - preds = np.hstack([o[1] for o in ribo]) + pred_to_h5_args = get_str2str_idx_map(tr_ids, f_tr_ids) + preds = [o[1] for o in ribo] + df = pl.DataFrame( + { + "transcript_id": tr_ids, + "h5_idx": pred_to_h5_args, + f"{prefix}score": preds, + }, + strict=False, + ) + out_headers = tool_headers + STANDARD_OUT_HEADERS + RIBO_OUT_HEADERS + xtr_heads else: - mask = [len(o) > 0 for o in np.array(f["seq_output"])] - preds = np.hstack(np.array(f["seq_output"])) - pred_to_h5_args = np.where(mask)[0] - # map pred ids to database id - k = (preds > prob_cutoff).sum() - if k == 0: - print( - f"!-> No predictions with an output probability higher than {prob_cutoff}" + prefix = "tis_transformer_" + # bring columns to forefront in final output table + 
tool_headers = ["tis_transformer_score", "tis_transformer_rank"] + xtr_heads.remove("tis_transformer_score") + mask = [len(o) > 0 for o in np.array(f["transcript/tis_transformer_score"])] + df = pl.DataFrame( + { + "transcript_id": f["transcript/transcript_id"][:], + "h5_idx": np.where(mask)[0], + f"{prefix}score": f["transcript/tis_transformer_score"][:], + } + ).with_columns( + pl.col(f"{prefix}score").map_elements(list, pl.List(pl.Float64)), ) - return None - lens = np.array(f["tr_len"])[pred_to_h5_args] - cum_lens = np.cumsum(np.insert(lens, 0, 0)) - idxs = np.argpartition(preds, -k)[-k:] - - orf_dict = {"f_idx": [], "TIS_idx": []} - for idx in idxs: - idx_tr = np.where(cum_lens > idx)[0][0] - 1 - orf_dict["TIS_idx"].append(idx - cum_lens[idx_tr]) - orf_dict["f_idx"].append(pred_to_h5_args[idx_tr]) - if has_seq_output: - seq_out = [ - f["seq_output"][i][j] - for i, j in zip(orf_dict["f_idx"], orf_dict["TIS_idx"]) + out_headers = tool_headers + STANDARD_OUT_HEADERS + xtr_heads + print(f"{time()}: Parsing ORF information...") + df = df.with_columns( + TIS_idx=( + pl.col(f"{prefix}score") + .list.eval((pl.element() > prob_cutoff).arg_true()) + .cast(pl.List(pl.Int32)) + ), + corr_dist=pl.lit(dist * 3), + ).sort("h5_idx") + # filter transcripts with zero predictions + tr_mask = df["TIS_idx"].list.len() > 0 + assert tr_mask.any(), f"!-> No predictions higher than {prob_cutoff}" + df = df.filter(tr_mask) + df = df.with_columns( + [ + pl.Series(name=h, values=f[f"transcript/{h}"][:][df["h5_idx"]]) + for h in f_headers ] - orf_dict.update({"seq_output": seq_out}) - orf_dict.update( - {f"{header}": np.array(f[f"{header}"])[orf_dict["f_idx"]] for header in HEADERS} ) - orf_dict.update(orf_dict) - df_out = pd.DataFrame(data=orf_dict) - df_out = df_out.rename( - columns={ - "id": "tr_id", - "support_lvl": "tr_support_lvl", - "biotype": "tr_biotype", - "tag": "tr_tag", - } + # devectorize and Listify numpy arrays + df = df.with_columns( + ( + pl.col("canonical_protein_seq") + .map_elements(list, pl.List(pl.Int8)) + .list.eval(pl.element().replace_strict(IDX_PROT_DICT)) + .list.join("") + ), + ( + pl.col("seq") + .map_elements(list, pl.List(pl.Int8)) + .list.eval(pl.element().replace_strict(IDX_DNA_DICT)) + .list.join("") + ), + ).with_columns( + pl.col("exon_idxs").map_elements(list, pl.List(pl.Int64)), + pl.col("exon_coords").map_elements(list, pl.List(pl.Int64)), + pl.col("CDS_coords").map_elements(list, pl.List(pl.Int64)), + pl.col("CDS_idxs").map_elements(list, pl.List(pl.Int64)), + pl.col(f"{prefix}score").map_elements(list, pl.List(pl.Float64)), + pl.col(pl.Binary).cast(pl.String), ) - df_out["correction"] = np.nan - - df_dict = { - "start_codon": [], - "stop_codon": [], - "prot": [], - "TTS_exon": [], - "TTS_on_transcript": [], - "TIS_coord": [], - "TIS_exon": [], - "TTS_coord": [], - "TTS_coord": [], - "TTS_pos": [], - "ORF_len": [], + # filter transcripts with zero predictions + df = ( + df.with_columns( + (pl.col(f"{prefix}score").list.gather(pl.col("TIS_idx"))).alias( + f"{prefix}score" + ), + ) + .explode([f"{prefix}score", "TIS_idx"]) + .sort("h5_idx") + ) + # if non-canonical ATG, find in-frame ATGs to 'correct' prediction to + if correction: + df = ( + # 1. 
+            # 1. Calc upstream cut as multiple of 3 and not lower than 0
+            df.with_columns(
+                upstr_cut=pl.col("corr_dist").clip(
+                    0, pl.col("TIS_idx") - pl.col("TIS_idx").mod(3)
+                )
+            )
+            # Create list of codons and check for in-frame ATG
+            .with_columns(
+                corrects=(
+                    pl.col("seq")
+                    .str.slice(
+                        pl.col("TIS_idx") - pl.col("upstr_cut"),
+                        pl.col("upstr_cut") + 3 + dist * 3,
+                    )
+                    .map_elements(
+                        lambda x: [x[i : i + 3] for i in range(0, len(x), 3)],
+                        return_dtype=pl.List(pl.String),
+                    )
+                    .list.eval(
+                        pl.element().str.contains("ATG").arg_true().cast(pl.Int8)
+                    )
+                )
+            )
+            .explode("corrects")
+            .with_columns(corrects=pl.col("corrects") * 3 - pl.col("upstr_cut"))
+            .group_by(pl.exclude("corrects"))
+            .all()
+            # 2. From distances, select ATG closest to pred TIS
+            .with_columns(
+                correction=(
+                    pl.col("corrects")
+                    .list.get(
+                        pl.col("corrects").list.eval(pl.element().abs()).list.arg_min()
+                    )
+                    .fill_null(0)
+                )
+            )
+            # 3. Apply corrects and keep unique matches
+            .with_columns(TIS_idx=pl.col("TIS_idx") + pl.col("correction"))
+        )
+        if remove_duplicates:
+            df = df.sort("ribotie_score", descending=True).unique(
+                ["transcript_id", "TIS_idx"], keep="first"
+            )
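Annotation, not part of the patch: a plain-Python equivalent of the correction step above (hypothetical helper name) — pick the in-frame ATG closest to the predicted TIS within ± `dist` codons, never indexing below zero.

```python
def correct_tis(seq, tis_idx, dist):
    upstr_cut = min(dist * 3, tis_idx - tis_idx % 3)  # stay >= 0, in frame
    window = seq[tis_idx - upstr_cut : tis_idx + 3 + dist * 3]
    codons = [window[i : i + 3] for i in range(0, len(window), 3)]
    offsets = [i * 3 - upstr_cut for i, c in enumerate(codons) if c == "ATG"]
    return tis_idx + min(offsets, key=abs) if offsets else tis_idx

print(correct_tis("GGGATGCCCTTTAAA", 9, dist=3))  # 3 -> shifted to the ATG
```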
(pl.col(f"{poi}_exon") - 1) * 2 + # polars cannot use pl.col in list.eval function, hence explode - groupby + df_tmp = ( + df_tmp.join( + ( + df_tmp.explode("exon_idxs") + .with_columns( + (pl.col("exon_idxs") <= pl.col(f"{poi}_idx")).alias( + f"{poi}_exon" + ) + ) + .group_by("ORF_id") + .agg(pl.col(f"{poi}_exon").sum() // 2 + 1) + ), + on="ORF_id", + how="left", + ) + .with_columns( + ( + pl.when(pl.col(f"{poi}_idx") == -1) + .then(pl.lit(-1)) + .otherwise(pl.col(f"{poi}_exon")) + .alias(f"{poi}_exon") + ), + ( + pl.when(pl.col(f"{poi}_idx") == -1) + .then(pl.lit(-1)) + .otherwise( + pl.col(f"{poi}_idx") + - pl.col("exon_idxs").list.get(exon_start_idx) + ) + .alias(f"{poi}_idx_on_exon") + ), + ) + .with_columns( + (pl.col(f"{poi}_idx_on_exon") + 1).alias(f"{poi}_pos_on_exon"), + pl.when((pl.col(f"{poi}_idx") != -1) & (pl.col("strand") == "+")) + .then( + pl.col("exon_coords").list.get(exon_start_idx, null_on_oob=True) + + pl.col(f"{poi}_idx_on_exon") + ) + .when((pl.col(f"{poi}_idx") != -1) & (pl.col("strand") == "-")) + .then( + pl.col("exon_coords").list.get(exon_start_idx + 1, null_on_oob=True) + - pl.col(f"{poi}_idx_on_exon") + ) + .otherwise(pl.lit(-1)) + .alias(f"{poi}_coord"), + ) + ) + df = df.join(df_tmp.drop(sel_cols[1:]), on="ORF_id", how="left") + if has_ribo_output: + print(f"{time()}: Parsing ribo-seq information...") + df_ribo = df.select("ORF_id", "h5_idx", "TIS_idx", "TTS_idx", "ORF_len") + # multiple sets in case of merged data sets + ribo_subsets = np.array(ribo_id.decode().split("&")) + ribo_paths = [ + [ + f"{h5_path.split('.h5')[0]}_{subset}.h5", + f"transcript/riboseq/{subset}/5/", ] - matches = np.array(matches) - min(TIS_idx, dist * 3) - matches = matches[matches % 3 == 0] - if len(matches) > 0: - match = matches[np.argmin(abs(matches))] - corrections[row.name] = match - TIS_idx = TIS_idx + match - TIS_idxs[row.name] = TIS_idx - DNA_frag = vec2DNA(tr_seq[TIS_idx:]) - df_dict["start_codon"].append(DNA_frag[:3]) - prot, has_stop, stop_codon = construct_prot(DNA_frag) - df_dict["stop_codon"].append(stop_codon) - df_dict["prot"].append(prot) - df_dict["TTS_on_transcript"].append(has_stop) - df_dict["ORF_len"].append(len(prot) * 3) - TIS_exon = np.sum(TIS_idx >= row.exon_idxs) // 2 + 1 - TIS_exon_idx = TIS_idx - row.exon_idxs[(TIS_exon - 1) * 2] - if row.strand == b"+": - TIS_coord = row.exon_coords[(TIS_exon - 1) * 2] + TIS_exon_idx - else: - TIS_coord = row.exon_coords[(TIS_exon - 1) * 2 + 1] - TIS_exon_idx - if has_stop: - TTS_idx = TIS_idx + df_dict["ORF_len"][-1] - TTS_pos = TTS_idx + 1 - TTS_exon = np.sum(TTS_idx >= row.exon_idxs) // 2 + 1 - TTS_exon_idx = TTS_idx - row.exon_idxs[(TTS_exon - 1) * 2] - if row.strand == b"+": - TTS_coord = row.exon_coords[(TTS_exon - 1) * 2] + TTS_exon_idx - else: - TTS_coord = row.exon_coords[(TTS_exon - 1) * 2 + 1] - TTS_exon_idx - else: - TTS_coord, TTS_exon, TTS_pos = -1, -1, -1 - - df_dict["TIS_coord"].append(TIS_coord) - df_dict["TIS_exon"].append(TIS_exon) - df_dict["TTS_pos"].append(TTS_pos) - df_dict["TTS_exon"].append(TTS_exon) - df_dict["TTS_coord"].append(TTS_coord) - - df_out = df_out.assign(**df_dict) - df_out["TIS_idx"] = TIS_idxs - df_out["correction"] = corrections - df_out["seqname"] = df_out["contig"] - df_out["TIS_pos"] = df_out["TIS_idx"] + 1 - df_out["output"] = preds[idxs] - df_out = df_out.sort_values("output", ascending=False) - df_out["output_rank"] = np.arange(len(df_out)) - - df_out["dist_from_canonical_TIS"] = df_out["TIS_idx"] - df_out["canonical_TIS_idx"] - df_out.loc[df_out["canonical_TIS_idx"] == 
-1, "dist_from_canonical_TIS"] = np.nan - df_out["frame_wrt_canonical_TIS"] = df_out["dist_from_canonical_TIS"] % 3 - - if has_seq_output: - seq_out = [ - f["seq_output"][i][j] - for i, j in zip(orf_dict["f_idx"], orf_dict["TIS_idx"]) + for subset in ribo_subsets ] - orf_dict.update({"seq_output": seq_out}) - - if has_ribo_output: - ribo_subsets = np.array(ribo_id.split(b"&")) - sparse_reads_set = [] - for subset in ribo_subsets: + # only data and indices of sparse object are required (all counts are summed over read lengths) + for h in ["data", "indices", "indptr", "shape"]: if parallel: - r = h5py.File(f"{h5_path.split('.h5')[0]}_{subset.decode()}.h5")[ - "transcript" - ] - sparse_reads = h5max.load_sparse( - r[f"riboseq/{subset.decode()}/5/"], df_out["f_idx"], to_numpy=False + ribo_data = np.add.reduce( + [ + np.array(h5py.File(a)[f"{p}/{h}"])[df_ribo["h5_idx"]] + for a, p in ribo_paths + ] ) - r.file.close() else: - sparse_reads = h5max.load_sparse( - f[f"riboseq/{subset.decode()}/5/"], df_out["f_idx"], to_numpy=False + counts = np.add.reduce( + [np.array(f[f"{p}/{h}"])[df_ribo["h5_idx"]] for _, p in ribo_paths] ) - sparse_reads_set.append(sparse_reads) - sparse_reads = np.add.reduce(sparse_reads_set) - df_out["reads_in_tr"] = np.array([s.sum() for s in sparse_reads]) - reads_in = [] - reads_out = [] - in_frame_read_perc = [] - for i, (_, row) in tqdm( - enumerate(df_out.iterrows()), - total=len(df_out), - desc=f"{time()}: parsing ribo-seq information ", - ): - end_of_ORF_idx = row.TIS_pos + row.ORF_len - 1 - reads_in_ORF = sparse_reads[i][:, row.TIS_pos - 1 : end_of_ORF_idx].sum() - reads_out_ORF = sparse_reads[i].sum() - reads_in_ORF - in_frame_reads = sparse_reads[i][ - :, np.arange(row["TIS_pos"] - 1, end_of_ORF_idx, 3) - ].sum() - reads_in.append(reads_in_ORF) - reads_out.append(reads_out_ORF) - - in_frame_read_perc.append(in_frame_reads / max(reads_in_ORF, 1)) - - df_out["reads_in_ORF"] = reads_in - df_out["reads_out_ORF"] = reads_out - df_out["in_frame_read_perc"] = in_frame_read_perc - - TIS_coords = np.array(f["canonical_TIS_coord"]) - TTS_coords = np.array(f["canonical_TTS_coord"]) - cds_lens = np.array(f["canonical_TTS_idx"]) - np.array(f["canonical_TIS_idx"]) - orf_type = [] - is_cds = [] - for i, row in tqdm( - df_out.iterrows(), - total=len(df_out), - desc=f"{time()}: parsing ORF type information ", + df_ribo = df_ribo.with_columns( + pl.Series(name=h, values=list(counts), dtype=pl.List(pl.Int64)), + ) + # get in-ORF reads and properties that cannot be retrieved using polars API + csr_cols = ["data", "indices", "indptr", "shape"] + csr_f = ( + lambda x: csr_matrix( + (x["data"], x["indices"], x["indptr"]), shape=x["shape"] + ) + .sum(axis=0) + .tolist()[0] + ) + df_ribo = df_ribo.with_columns( + pl.struct(*csr_cols) + .map_elements( + csr_f, + return_dtype=pl.List(pl.Int32), + ) + .alias("reads") + ) + # get ribo properties supported by polars API + df_ribo = df_ribo.with_columns( + reads_ORF=( + pl.col("reads").list.slice(pl.col("TIS_idx"), pl.col("ORF_len")) + ), + reads_in_transcript=pl.col("data").list.sum(), + ).with_columns( + reads_in_ORF=pl.col("reads_ORF").list.sum(), + reads_in_frame_frac=( + pl.col("reads_ORF") + .list.gather_every(3) + .list.sum() + .truediv(pl.col("reads_ORF").list.sum()) + ), + reads_5UTR=(pl.col("reads").list.slice(0, pl.col("TIS_idx")).list.sum()), + reads_3UTR=( + pl.when(pl.col("TTS_idx") != -1) + .then(pl.col("reads").list.slice(pl.col("TTS_idx")).list.sum()) + .otherwise(pl.lit(0)) + ), + reads_skew=( + pl.col("reads_ORF") + 
.list.slice(offset=pl.col("reads_ORF").list.len().truediv(2)) + .list.sum() + .truediv(pl.col("reads_ORF").list.sum()) + .sub(0.5) + .mul(2) + ), + reads_coverage_frac=( + pl.col("reads_ORF") + .list.eval((pl.element() > 0)) + .list.sum() + .truediv(pl.col("reads_ORF").list.len()) + ), + reads_entropy=( + pl.col("reads_ORF").map_elements( + lambda x: entropy(x, np.full(len(x), 1) / len(x)), + return_dtype=pl.Float32, + ) + ), + ) + df = df.join(df_ribo[:, [0, *range(11, 19)]], on="ORF_id", how="left") + # detect ORF biotypes, evaluate whether transcript biotype is given + print(f"{time()}: Parsing ORF type information...") + if "transcript_biotype" in df.columns: + biotype_expr = pl.col("transcript_biotype") == "lncRNA" + else: + biotype_expr = pl.lit(False) + df = df.with_columns( + ORF_type=pl.when(pl.col("canonical_TIS_idx") != -1) + .then( + pl.when(pl.col("canonical_TIS_idx") == pl.col("TIS_idx")) + .then( + pl.when(pl.col("canonical_LTS_idx") == pl.col("LTS_idx")) + .then(pl.lit("annotated CDS")) + .when(pl.col("canonical_TTS_idx") < pl.col("TTS_idx")) + .then(pl.lit("C-terminal extension")) + .otherwise(pl.lit("C-terminal truncation")) + ) + .when(pl.col("canonical_TTS_idx") < pl.col("TIS_idx")) + .then(pl.lit("dORF")) + .when(pl.col("canonical_TIS_idx") > pl.col("TTS_idx")) + .then(pl.lit("uORF")) + .when(pl.col("canonical_TIS_idx") > pl.col("TIS_idx")) + .then( + pl.when(pl.col("canonical_TTS_idx") == pl.col("TTS_idx")) + .then(pl.lit("N-terminal extension")) + .otherwise(pl.lit("uoORF")) + ) + .when(pl.col("canonical_TTS_idx") < pl.col("TTS_idx")) + .then(pl.lit("doORF")) + .otherwise( + pl.when(pl.col("canonical_TTS_idx") == pl.col("TTS_idx")) + .then(pl.lit("N-terminal truncation")) + .otherwise(pl.lit("intORF")) + ) + ) + .otherwise( + pl.when(biotype_expr) + .then(pl.lit("lncRNA-ORF")) + .otherwise(pl.lit("varRNA-ORF")) + ) + ) + print(f"{time()}: Detecting CDS variants...") + out_cols = ["ORF_id", "ORF_coords", "ORF_exons"] + df = ( + df.join( + ( + df.select("ORF_id", "TIS_coord", "LTS_coord", "strand", "exon_coords") + .map_rows(lambda x: (x[0], *transcript_region_to_exons(*x[1:]))) + .rename({f"column_{i}": n for i, n in enumerate(out_cols)}) + ), + on="ORF_id", + ) + .with_columns( + ORF_exon_start=pl.col("ORF_coords").list.gather_every(2, 0), + ORF_exon_end=pl.col("ORF_coords").list.gather_every(2, 1), + ) + .with_columns( + ORF_exon_len=( + (pl.col("ORF_exon_start") - pl.col("ORF_exon_end")).list.eval( + pl.element().abs() + 1 + ) + ) + ) + ) + # load in all CDS properties in h5 db + h5_cols = [ + "transcript_id", + "seqname", + "strand", + "CDS_coords", + "canonical_TIS_coord", + "canonical_LTS_coord", + ] + mask = pl.Series(list(f[f"transcript/canonical_TIS_idx"])) != -1 + df_CDS = ( + pl.DataFrame( + {f"{h}": np.array(f[f"transcript/{h}"])[mask.arg_true()] for h in h5_cols} + ) + .with_columns( + pl.col("CDS_coords").map_elements(list, pl.List(pl.Int64)), + pl.col(pl.Binary).cast(pl.String), + ) + .with_columns( + CDS_exon_start=pl.col("CDS_coords").list.gather_every(2, 0), + CDS_exon_end=pl.col("CDS_coords").list.gather_every(2, 1), + CDS_start_range=( + pl.when(pl.col("strand") == "+") + .then(pl.col("CDS_coords").list.get(0)) + .otherwise(pl.col("CDS_coords").list.get(-2)) + ), + CDS_end_range=( + pl.when(pl.col("strand") == "+") + .then(pl.col("CDS_coords").list.get(-1)) + .otherwise(pl.col("CDS_coords").list.get(1)) + ), + ) + .drop("CDS_coords") + ) + # close h5 db handle + f.file.close() + # To evaluate CDS variants, filter df and df_CDS by seqname (to 
+    print(f"{time()}: Detecting CDS variants...")
+    out_cols = ["ORF_id", "ORF_coords", "ORF_exons"]
+    df = (
+        df.join(
+            (
+                df.select("ORF_id", "TIS_coord", "LTS_coord", "strand", "exon_coords")
+                .map_rows(lambda x: (x[0], *transcript_region_to_exons(*x[1:])))
+                .rename({f"column_{i}": n for i, n in enumerate(out_cols)})
+            ),
+            on="ORF_id",
+        )
+        .with_columns(
+            ORF_exon_start=pl.col("ORF_coords").list.gather_every(2, 0),
+            ORF_exon_end=pl.col("ORF_coords").list.gather_every(2, 1),
+        )
+        .with_columns(
+            ORF_exon_len=(
+                (pl.col("ORF_exon_start") - pl.col("ORF_exon_end")).list.eval(
+                    pl.element().abs() + 1
+                )
+            )
+        )
+    )
+    # load in all CDS properties in h5 db
+    h5_cols = [
+        "transcript_id",
+        "seqname",
+        "strand",
+        "CDS_coords",
+        "canonical_TIS_coord",
+        "canonical_LTS_coord",
+    ]
+    mask = pl.Series(list(f[f"transcript/canonical_TIS_idx"])) != -1
+    df_CDS = (
+        pl.DataFrame(
+            {f"{h}": np.array(f[f"transcript/{h}"])[mask.arg_true()] for h in h5_cols}
+        )
+        .with_columns(
+            pl.col("CDS_coords").map_elements(list, pl.List(pl.Int64)),
+            pl.col(pl.Binary).cast(pl.String),
+        )
+        .with_columns(
+            CDS_exon_start=pl.col("CDS_coords").list.gather_every(2, 0),
+            CDS_exon_end=pl.col("CDS_coords").list.gather_every(2, 1),
+            CDS_start_range=(
+                pl.when(pl.col("strand") == "+")
+                .then(pl.col("CDS_coords").list.get(0))
+                .otherwise(pl.col("CDS_coords").list.get(-2))
+            ),
+            CDS_end_range=(
+                pl.when(pl.col("strand") == "+")
+                .then(pl.col("CDS_coords").list.get(-1))
+                .otherwise(pl.col("CDS_coords").list.get(1))
+            ),
+        )
+        .drop("CDS_coords")
+    )
+    # close h5 db handle
+    f.file.close()
+    # To evaluate CDS variants, filter df and df_CDS by seqname (to prevent OOM)
+    df_grps = []
+    for seqname, df_grp in tqdm(
+        df.group_by("seqname"), total=df["seqname"].unique().len(), desc="seqname"
    ):
-        TIS_mask = row["TIS_coord"] == TIS_coords
-        TTS_mask = row["TTS_coord"] == TTS_coords
-        len_mask = row.ORF_len == cds_lens
-        is_cds.append(np.logical_and.reduce([TIS_mask, TTS_mask, len_mask]).any())
-
-        if row["canonical_TIS_idx"] != -1:
-            if row["canonical_TIS_idx"] == row["TIS_pos"] - 1:
-                orf_type.append("annotated CDS")
-            elif row["TIS_pos"] > row["canonical_TTS_idx"] + 1:
-                orf_type.append("dORF")
-            elif row["TTS_pos"] < row["canonical_TIS_idx"] + 1:
-                orf_type.append("uORF")
-            elif row["TIS_pos"] < row["canonical_TIS_idx"] + 1:
-                if row["TTS_pos"] == row["canonical_TTS_idx"] + 1:
-                    orf_type.append("N-terminal extension")
-                else:
-                    orf_type.append("uoORF")
-            elif row["TTS_pos"] > row["canonical_TTS_idx"] + 1:
-                orf_type.append("doORF")
-            else:
-                if row["TTS_pos"] == row["canonical_TTS_idx"] + 1:
-                    orf_type.append("N-terminal truncation")
-                else:
-                    orf_type.append("intORF")
-        else:
-            shares_TIS_coord = row["TIS_coord"] in TIS_coords
-            shares_TTS_coord = row["TTS_coord"] in TTS_coords
-            if shares_TIS_coord or shares_TTS_coord:
-                orf_type.append("CDS variant")
+        df_CDS_grp = df_CDS.filter(pl.col("seqname") == seqname[0])
+        # detect CDS variant information
+        df_grp = df_grp.with_columns(
+            has_CDS_TIS=(pl.col("TIS_coord").is_in(df_CDS_grp["canonical_TIS_coord"])),
+            has_CDS_TTS=(pl.col("LTS_coord").is_in(df_CDS_grp["canonical_LTS_coord"])),
+        )
+        sel_cols = [
+            "ORF_id",
+            "CDS_exon_start",
+            "CDS_exon_end",
+            "ORF_exon_start",
+            "ORF_exon_end",
+        ]
+        var_feats = {
+            "ORF_id": [],
+            "shared_in_frame_CDS_region": [],
+            "shared_in_frame_CDS_frac": [],
+            "ORF_coords_no_CDS": [],
+            "has_CDS_clones": [],
+        }
+        for r in df_grp.iter_rows(named=True):
+            var_feats["ORF_id"].append(r["ORF_id"])
+            df_CDS_row = df_CDS_grp.filter(
+                pl.when(pl.col("strand") == "+")
+                .then(
+                    (pl.col("CDS_start_range") < r["ORF_exon_end"][-1])
+                    & (pl.col("CDS_end_range") > r["ORF_exon_start"][0])
+                )
+                .otherwise(
+                    (pl.col("CDS_start_range") < r["ORF_exon_end"][0])
+                    & (pl.col("CDS_end_range") > r["ORF_exon_start"][-1])
+                )
+            )
+            if len(df_CDS_row) > 0:
+                df_CDS_row = df_CDS_row.with_columns(
+                    ORF_id=pl.lit(r["ORF_id"]),
+                    ORF_exon_start=pl.lit(r["ORF_exon_start"]),
+                    ORF_exon_end=pl.lit(r["ORF_exon_end"]),
+                    ORF_exon_len=pl.lit(r["ORF_exon_len"]),
+                )
+                has_CDS_clones = (
+                    (df_CDS_row["CDS_exon_start"] == df_CDS_row["ORF_exon_start"])
+                    & (df_CDS_row["CDS_exon_end"] == df_CDS_row["ORF_exon_end"])
+                ).any()
+                # ORF start is within CDS exon boundaries
+                cond_1 = (pl.col("ORF_exon_start") >= pl.col("CDS_exon_start")) & (
+                    pl.col("ORF_exon_start") < pl.col("CDS_exon_end")
+                )
+                # ORF end is within CDS exon boundaries
+                cond_2 = (pl.col("ORF_exon_end") > pl.col("CDS_exon_start")) & (
+                    pl.col("ORF_exon_end") <= pl.col("CDS_exon_end")
+                )
+                cond_3 = (
+                    pl.when(pl.col("strand") == "+")
+                    .then(
+                        (pl.col("ORF_exon_start") - pl.col("CDS_exon_start")) % 3 == 0
+                    )
+                    .otherwise(
+                        (pl.col("ORF_exon_end") - pl.col("CDS_exon_end")) % 3 == 0
+                    )
+                )
+                # filter specific exon regions if they're not overlapping with CDS exon regions
+                df_CDS_exons = (
+                    df_CDS_row.explode(["CDS_exon_start", "CDS_exon_end"])
+                    .explode(["ORF_exon_start", "ORF_exon_end", "ORF_exon_len"])
+                    .with_columns(shared_in_frame_CDS_region=(cond_1 | cond_2) & cond_3)
+                    .filter(pl.col("shared_in_frame_CDS_region"))
+                )
+                out = eval_overlap(*[df_CDS_exons[c] for c in sel_cols])
+                in_frame_CDSs = (
df_CDS_exons.group_by("transcript_id") + .agg(pl.col("shared_in_frame_CDS_region").any()) + .select("transcript_id") + .to_series() + .to_list() + ) else: - orf_type.append("other") - df_out["ORF_type"] = orf_type - df_out["ORF_equals_CDS"] = is_cds - df_out.loc[df_out["tr_biotype"] == b"lncRNA", "ORF_type"] = "lncRNA-ORF" - # decode strs - for header in DECODE: - df_out[header] = df_out[header].str.decode("utf-8") - df_out["ORF_id"] = df_out["tr_id"] + "_" + df_out["TIS_pos"].astype(str) - # re-arrange columns - o_headers = [h for h in OUT_HEADERS if h in df_out.columns] - df_out = df_out.loc[:, o_headers].sort_values("output_rank") - # remove duplicates - if correction and remove_duplicates: - df_out = df_out.drop_duplicates("ORF_id") - if exclude_invalid_TTS: - df_out = df_out[df_out["TTS_on_transcript"]] - df_out = df_out[df_out["ORF_len"] > min_ORF_len] - df_out = df_out[df_out["start_codon"].str.contains(start_codons)] - df_out.to_csv(f"{out_prefix}.csv", index=None) - f.file.close() + has_CDS_clones = False + out = [0, []] + in_frame_CDSs = [] + var_feats["has_CDS_clones"].append(has_CDS_clones) + var_feats["shared_in_frame_CDS_frac"].append(out[0]) + var_feats["ORF_coords_no_CDS"].append(out[1]) + var_feats["shared_in_frame_CDS_region"].append(in_frame_CDSs) + df_grp = df_grp.join(pl.DataFrame(var_feats), on="ORF_id") + df_grps.append(df_grp) + df = pl.concat(df_grps).with_columns( + pl.col("shared_in_frame_CDS_frac").truediv(pl.col("ORF_len")) + ) + # Filter CDS variants and custom filters + conds_xtr = [ + pl.col("TTS_on_transcript") if exclude_invalid_TTS else pl.lit(True), + pl.col("start_codon").str.contains(start_codons), + pl.col("ORF_len") >= min_ORF_len, + ] + conds_cds_var = [ + pl.col("has_CDS_clones") == False, + pl.col("shared_in_frame_CDS_frac") < 1, + ] + c_xtr = pl.lit(True).and_(*conds_xtr) + c_clone = pl.col("has_CDS_clones") == False + c_cds_var = pl.lit(True).and_(*conds_cds_var) + c_1 = pl.col("ORF_type") == "annotated CDS" + c_2 = pl.col("ORF_type").is_in( + [ + "N-terminal truncation", + "N-terminal extension", + "C-terminal trunctation", + "C-terminal extension", + ] + ) + c_3 = pl.col("ORF_type").is_in( + ["uORF", "uoORF", "dORF", "doORF", "intORF", "lncRNA-ORF"] + ) + if "transcript_biotype" in df.columns: + c_bio = pl.col("transcript_biotype") == "protein_coding" + else: + c_bio = pl.lit(False) + if not unfiltered: + filter_suffix = "" + df_filts = [] + for _, df_grp in df.group_by("TIS_coord"): + df_filt = df_grp.filter( + pl.when((c_1 & c_xtr).any()) + .then(c_1 & c_xtr) + .when((c_2 & c_xtr & c_clone).any()) + .then(c_2 & c_xtr & c_clone) + .when((c_3 & c_xtr & c_cds_var).any()) + .then(c_3 & c_xtr & c_cds_var) + .otherwise(c_xtr & c_cds_var) + & pl.when((c_1 | c_bio).any()).then(c_1 | c_bio).otherwise(pl.lit(True)) + ) + df_filts.append(df_filt) + df_out = pl.concat(df_filts) + else: + filter_suffix = ".unfiltered" + df_out = df + df_out = ( + df_out.with_columns( + (pl.col(f"{prefix}score").rank(method="ordinal", descending=True)).alias( + f"{prefix}rank" + ) + ) + .select(out_headers) + .sort(f"{prefix}rank") + .rename(RENAME_HEADERS) + ) + df_out.write_csv(f"{out_prefix}{filter_suffix}.csv") return df_out @@ -447,8 +702,8 @@ def construct_output_table( def process_seq_preds(ids, preds, seqs, min_prob): df = pd.DataFrame( columns=[ - "ID", - "tr_len", + "transcript_id", + "transcript_len", "TIS_pos", "output", "start_codon", @@ -488,27 +743,36 @@ def create_multiqc_reports(df, out_prefix): with open(output, "w") as f: f.write(RIBOTIE_MQC_HEADER) 
     return df_out
 
 
@@ -447,8 +702,8 @@ def construct_output_table(
 def process_seq_preds(ids, preds, seqs, min_prob):
     df = pd.DataFrame(
         columns=[
-            "ID",
-            "tr_len",
+            "transcript_id",
+            "transcript_len",
             "TIS_pos",
             "output",
             "start_codon",
@@ -488,27 +743,36 @@ def create_multiqc_reports(df, out_prefix):
     with open(output, "w") as f:
         f.write(RIBOTIE_MQC_HEADER)
         f.write(START_CODON_MQC_HEADER)
-    df.start_codon.value_counts().to_csv(output, sep="\t", header=False, mode="a")
+    start_codons = df["start_codon"].value_counts()
+    with open(output, mode="a") as f:
+        start_codons.write_csv(f, separator="\t", include_header=False)
     # Transcript biotypes
-    output = out_prefix + ".biotypes_variant_mqc.tsv"
-    with open(output, "w") as f:
-        f.write(RIBOTIE_MQC_HEADER)
-        f.write(BIOTYPE_VARIANT_MQC_HEADER)
-    orf_biotypes = pd.Series(index=ORF_BIOTYPE_ORDER, data=0)
-    counts = df[df.ORF_type == "CDS variant"].tr_biotype.value_counts()
-    orf_biotypes[counts.index] = counts
-    orf_biotypes.to_csv(output, sep="\t", header=False, mode="a")
+    if "transcript_biotype" in df.columns:
+        output = out_prefix + ".biotypes_variant_mqc.tsv"
+        with open(output, "w") as f:
+            f.write(RIBOTIE_MQC_HEADER)
+            f.write(BIOTYPE_VARIANT_MQC_HEADER)
+        orf_biotypes = (
+            df.filter(pl.col("ORF_type") == "varRNA-ORF")["transcript_biotype"]
+            .value_counts()
+            .sort(pl.col("transcript_biotype").cast(pl.Enum(ORF_BIOTYPE_ORDER)))
+        )
+        with open(output, mode="a") as f:
+            orf_biotypes.write_csv(f, separator="\t", include_header=False)
     # ORF types
     output = out_prefix + ".ORF_types_mqc.tsv"
     with open(output, "w") as f:
         f.write(RIBOTIE_MQC_HEADER)
         f.write(ORF_TYPE_MQC_HEADER)
-    orf_types = pd.Series(index=ORF_TYPE_ORDER, data=0)
-    counts = df.ORF_type.value_counts()
-    orf_types[counts.index] = counts
-    orf_types.to_csv(output, sep="\t", header=False, mode="a")
+    orf_types = (
+        df["ORF_type"]
+        .value_counts()
+        .sort(pl.col("ORF_type").cast(pl.Enum(ORF_TYPE_ORDER)))
+    )
+    with open(output, mode="a") as f:
+        orf_types.write_csv(f, separator="\t", include_header=False)
     # ORF lengths
     # output = out_prefix + ".ORF_lens_mqc.tsv"
@@ -528,47 +792,43 @@ def csv_to_gtf(h5_path, df, out_prefix, exclude_annotated=False):
     if exclude_annotated:
         df = df.filter(pl.col("ORF_type") != "annotated CDS")
     df = df.fill_null("NA")
-    df = df.sort("tr_id")
+    df = df.sort("transcript_id")
     f = h5py.File(h5_path, "r")
-    f_ids = np.array(f["transcript/id"])
+    f_ids = np.array(f["transcript/transcript_id"])
     # fast id mapping
     xsorted = np.argsort(f_ids)
-    pred_to_h5_args = xsorted[np.searchsorted(f_ids[xsorted], df["tr_id"])]
+    pred_to_h5_args = xsorted[np.searchsorted(f_ids[xsorted], df["transcript_id"])]
     # obtain exons
     exons_coords = np.array(f["transcript/exon_coords"])[pred_to_h5_args]
     f.close()
     gff_parts = []
-    for tis, stop_codon_start, strand, exon_coord in zip(
-        df["TIS_coord"], df["TTS_coord"], df["strand"], exons_coords
+    for TIS, LTS, TTS, strand, exon_coord in zip(
+        df["TIS_coord"], df["LTS_coord"], df["TTS_coord"], df["strand"], exons_coords
     ):
-        start_codon_stop = find_distant_exon_coord(tis, 2, strand, exon_coord)
+        start_codon_stop = find_distant_exon_coord(TIS, 2, strand, exon_coord)
         start_parts, start_exons = transcript_region_to_exons(
-            tis, start_codon_stop, strand, exon_coord
+            TIS, start_codon_stop, strand, exon_coord
         )
         # acquire cds stop coord from stop codon coord.
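+        # Coordinate convention assumed by this loop: TTS marks the first
+        # position of the stop codon (-1 when absent) and LTS the last
+        # translated position, so the CDS spans TIS..LTS and the stop codon
+        # TTS..TTS+2 along the transcript.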
-        if stop_codon_start != -1:
-            stop_codon_stop = find_distant_exon_coord(
-                stop_codon_start, 2, strand, exon_coord
-            )
+        if TTS != -1:
+            stop_codon_stop = find_distant_exon_coord(TTS, 2, strand, exon_coord)
             stop_parts, stop_exons = transcript_region_to_exons(
-                stop_codon_start, stop_codon_stop, strand, exon_coord
+                TTS, stop_codon_stop, strand, exon_coord
             )
-            tts = find_distant_exon_coord(stop_codon_start, -1, strand, exon_coord)
         else:
             stop_parts, stop_exons = np.empty(start_parts.shape), np.empty(
                 start_exons.shape
             )
-            tts = -1
-        cds_parts, cds_exons = transcript_region_to_exons(tis, tts, strand, exon_coord)
+        cds_parts, cds_exons = transcript_region_to_exons(TIS, LTS, strand, exon_coord)
         tr_coord = np.array([exon_coord[0], exon_coord[-1]])
         exons = np.arange(1, len(exon_coord) // 2 + 1)
         coords_packed = np.vstack(
             [
                 tr_coord.reshape(-1, 2),
                 exon_coord.reshape(-1, 2),
-                start_parts.reshape(-1, 2),
-                cds_parts.reshape(-1, 2),
-                stop_parts.reshape(-1, 2),
+                np.array(start_parts).reshape(-1, 2),
+                np.array(cds_parts).reshape(-1, 2),
+                np.array(stop_parts).reshape(-1, 2),
             ]
         ).astype(int)
         exons_packed = np.hstack(
@@ -590,10 +850,10 @@ def csv_to_gtf(h5_path, df, out_prefix, exclude_annotated=False):
         property_list = [
             f'gene_id "{row["gene_id"]}',
             f'transcript_id "{row["ORF_id"]}',
-            f'gene_name "{row["gene_name"]}',
-            f'transcript_biotype "{row["tr_biotype"]}',
-            f'tag "{row["tr_tag"]}',
-            f'transcript_support_level "{row["tr_support_lvl"]}',
+            # f'gene_name "{row["gene_name"]}',
+            # f'transcript_biotype "{row["transcript_biotype"]}',
+            # f'tag "{row["tag"]}',
+            # f'transcript_support_level "{row["tr_support_lvl"]}',
         ]
         if feature not in ["transcript"]:
             property_list.insert(
@@ -601,10 +861,11 @@ def csv_to_gtf(h5_path, df, out_prefix, exclude_annotated=False):
                 f'exon_number "{exon}',
             )
         if feature not in ["transcript", "exon"]:
-            property_list.insert(
-                3,
-                f'ORF_id "{row["ORF_id"]}", model_output "{row["output"]}"; ORF_type "{row["ORF_type"]}',
+            entries = np.array(
+                ["ORF_id", "ORF_type", "ribotie_score", "tis_transformer_score"]
             )
+            entries = entries[np.isin(entries, df.columns)]
+            property_list.insert(3, "; ".join([f'{a} "{row[a]}"' for a in entries]))
         properties = '"; '.join(property_list)
         gtf_lines.append(
             "\t".join(
diff --git a/transcript_transformer/ribotie.py b/transcript_transformer/ribotie.py
index fd185cd..f6571cb 100644
--- a/transcript_transformer/ribotie.py
+++ b/transcript_transformer/ribotie.py
@@ -1,7 +1,6 @@
 import os
 import sys
 import numpy as np
-import polars as pl
 import yaml
 import h5py
 from importlib import resources as impresources
@@ -9,7 +8,7 @@ from .transcript_transformer import train, predict
 from .argparser import Parser
-from .pretrained import riboformer_models
+from .pretrained import ribotie_models
 from . import configs
 from .data import process_seq_data, process_ribo_data
 from .processing import construct_output_table, csv_to_gtf, create_multiqc_reports
@@ -22,7 +21,7 @@ def parse_args():
     data_parser.add_argument(
         "--prob_cutoff",
         type=float,
-        default=0.15,
+        default=0.125,
         help="Determines the minimum model output score required for model "
         "predictions to be included in the result table.",
     )
@@ -64,6 +63,11 @@ def parse_args():
         action="store_true",
         help="Include annotated CDS regions in generated GTF file containing predicted translated ORFs.",
     )
+    data_parser.add_argument(
+        "--unfiltered",
+        action="store_true",
+        help="Do not apply filtering to the set of positive predictions.",
+    )
     # data_parser.add_argument(
     #     "--pretrained_model",
     #     type=json.loads,
@@ -84,7 +88,7 @@ def parse_args():
     parser.add_comp_args()
     parser.add_training_args()
     parser.add_train_loading_args(pretrain=True, auto=True)
-    default_config = f"{impresources.files(configs) / 'riboformer_defaults.yml'}"
+    default_config = f"{impresources.files(configs) / 'ribotie_defaults.yml'}"
     args = parser.parse_arguments(sys.argv[1:], [default_config])
     if args.out_prefix is None:
         args.out_prefix = f"{os.path.splitext(args.conf[0])[0]}_"
@@ -176,9 +180,9 @@ def main():
     if not (args.data or args.results or args.pretrain):
         if "pretrained_model" not in args:
             args = load_args(
-                (impresources.files(riboformer_models) / "50perc_06_23.yml"), args
+                (impresources.files(ribotie_models) / "50perc_06_23.yml"), args
             )
-        args.model_dir = str(impresources.files(riboformer_models)._paths[0])
+        args.model_dir = str(impresources.files(ribotie_models)._paths[0])
         for i, ribo_set in enumerate(args.ribo_ids):
             args_set = deepcopy(args)
             args_set.ribo_ids = [ribo_set]
@@ -221,13 +225,14 @@ def main():
                 exclude_invalid_TTS=not args.include_invalid_TTS,
                 ribo=out,
                 parallel=args.parallel,
+                unfiltered=args.unfiltered,
             )
             if df is not None:
-                csv_to_gtf(
-                    args.h5_path, pl.from_pandas(df), out_prefix, args.exclude_annotated
-                )
+                csv_to_gtf(args.h5_path, df, out_prefix, args.exclude_annotated)
                 os.makedirs(f"{args.out_prefix}/multiqc", exist_ok=True)
-                create_multiqc_reports(df, f"{args.out_prefix}/multiqc/{ribo_set_str}")
+                create_multiqc_reports(
+                    df, f"{os.path.dirname(args.out_prefix)}/multiqc/{ribo_set_str}"
+                )
 
 
 def merge_outputs(prefix, keys):
diff --git a/transcript_transformer/tis_transformer.py b/transcript_transformer/tis_transformer.py
index 6293e0d..ed564bb 100644
--- a/transcript_transformer/tis_transformer.py
+++ b/transcript_transformer/tis_transformer.py
@@ -40,7 +40,7 @@ def parse_args():
     default_config = f"{impresources.files(configs) / 'tis_transformer_defaults.yml'}"
     args = parser.parse_arguments(sys.argv[1:], [default_config])
     if args.out_prefix is None:
-        args.out_prefix = f"{os.path.splitext(args.conf[0])[0]}_"
+        args.out_prefix = f"{os.path.splitext(args.h5_path)[0]}_"
     assert ~args.results and ~args.data, (
         "cannot only do processing of data and results, disable either"
         " --data_process or --result_process"
@@ -66,8 +67,8 @@ def main():
         args.input_type = "hdf5"
     # determine optimal allocation of seqnames to train/val/test set
     f = h5py.File(args.h5_path, "r")["transcript"]
-    contigs = np.array(f["contig"])
-    tr_lens = np.array(f["tr_len"])
+    contigs = np.array(f["seqname"])
+    tr_lens = np.array(f["transcript_len"])
     f.file.close()
     # determine nt count per seqname
     contig_set = np.unique(contigs)
@@ -85,7 +86,7 @@ def main():
     f = h5py.File(args.h5_path, "a")
     grp = f["transcript"]
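+    # Fast id mapping: argsort + searchsorted yield, for each prediction id,
+    # the row index of the matching transcript id in the unsorted h5 arrays.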
-    f_tr_ids = np.array(grp["id"])
+    f_tr_ids = np.array(grp["transcript_id"])
     xsorted = np.argsort(f_tr_ids)
     out = np.load(f"{prefix}.npy", allow_pickle=True)
     tr_ids = np.hstack([o[0] for o in out])
@@ -96,12 +97,12 @@ def main():
     for idx, (_, pred, _) in zip(pred_to_h5_args, out):
         pred_arr[idx] = pred
     dtype = h5py.vlen_dtype(np.dtype("float32"))
-    if "seq_output" in grp.keys():
+    if "tis_transformer_score" in grp.keys():
         print("--> Overwriting results in local h5 database...")
-        del grp["seq_output"]
+        del grp["tis_transformer_score"]
     else:
         print("--> Writing results to local h5 database...")
-    grp.create_dataset("seq_output", data=pred_arr, dtype=dtype)
+    grp.create_dataset("tis_transformer_score", data=pred_arr, dtype=dtype)
     f.close()
     if not args.no_backup:
         if not args.backup_path:
@@ -109,12 +110,12 @@ def main():
         if os.path.isfile(args.backup_path):
             f = h5py.File(args.backup_path, "a")
             grp = f["transcript"]
-            if "seq_output" in grp.keys():
+            if "tis_transformer_score" in grp.keys():
                 print("--> Overwriting results in backup h5 database...")
-                del grp["seq_output"]
+                del grp["tis_transformer_score"]
             else:
                 print("--> Writing results to backup h5 database...")
-            grp.create_dataset("seq_output", data=pred_arr, dtype=dtype)
+            grp.create_dataset("tis_transformer_score", data=pred_arr, dtype=dtype)
             f.close()
     if not args.data:
         f = h5py.File(args.h5_path, "r")
diff --git a/transcript_transformer/transcript_loader.py b/transcript_transformer/transcript_loader.py
index 14e654f..4df05dd 100644
--- a/transcript_transformer/transcript_loader.py
+++ b/transcript_transformer/transcript_loader.py
@@ -1,10 +1,10 @@
 import h5py
 import numpy as np
 import torch
-from h5max import load_sparse
+from h5max import load_sparse_matrix
 from torch.utils.data import DataLoader
 import pytorch_lightning as pl
 
 
 def collate_fn(batch):
     """
@@ -56,7 +56,13 @@ def collate_fn(batch):
     else:
         arr = np.empty(shape=(len(batch[1]), max_len + 2), dtype=int)
         for i, (x, l) in enumerate(zip(batch[1], max_len - lens)):
-            arr[i] = np.concatenate(([5], x[k], [6], [7] * l))
+            try:
+                arr[i] = np.concatenate(([5], x[k], [6], [7] * l))
+            except ValueError as err:
+                raise ValueError(
+                    f"failed to pad feature '{k}' (length {len(x[k])}, pad {l}) "
+                    f"for sample {batch[0][i]}"
+                ) from err
     x_dict[k] = torch.LongTensor(arr)
 
     x_dict.update({"x_id": batch[0], "y": y_b})
@@ -186,13 +192,13 @@ def __init__(
     def setup(self, stage=None):
         f = h5py.File(self.h5_path, "r")[self.exp_path]
         self.seqn_list = np.array(f[self.seqn_path])
-        self.transcript_lens = np.array(f["tr_len"])
+        self.transcript_lens = np.array(f["transcript_len"])
         # evaluate conditions
         # Identical mask over the samples applied to all datasets
         global_masks = []
         for key, cond_f in self.cond["global"].items():
             mask = cond_f(np.array(f[key]))
-            if (key != "tr_len") and (self.leaky_frac > 0):
+            if (key != "transcript_len") and (self.leaky_frac > 0):
                 prob_mask = np.random.uniform(size=len(mask)) > (1 - self.leaky_frac)
                 mask[prob_mask] = True
             global_masks.append(mask)
@@ -220,7 +226,7 @@ def setup(self, stage=None):
                 mask = cond_f(grouped_feature)
                 # if data filtering is not based on transcript length,
                 # allow leaky filtering
-                if (key != "tr_len") and (self.leaky_frac > 0):
+                if (key != "transcript_len") and (self.leaky_frac > 0):
                     prob_mask = np.random.uniform(size=len(mask)) > (
                         1 - self.leaky_frac
                     )
@@ -467,9 +473,11 @@ def __getitem__(self, index):
         id_prefix = "&".join(ribo_set) + "|"
         ribo_path = f"riboseq/{ribo_id}/5"
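+        # Reads are stored as sparse matrices (h5max); in parallel mode each
+        # ribo-seq dataset is read from its own file handle (self.r), otherwise
+        # from the shared handle (self.f).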
         if self.parallel:
-            x = load_sparse(self.r[ribo_id][ribo_path], idx, format="csr").T
+            x = load_sparse_matrix(
+                self.r[ribo_id][ribo_path], idx, format="csr"
+            ).T
         else:
-            x = load_sparse(self.f[ribo_path], idx, format="csr").T
+            x = load_sparse_matrix(self.f[ribo_path], idx, format="csr").T
         if self.offsets is not None:
             for col_i, (_, shift) in enumerate(
                 self.offsets[ribo_id].items()
diff --git a/transcript_transformer/util_functions.py b/transcript_transformer/util_functions.py
index 760a629..1699326 100644
--- a/transcript_transformer/util_functions.py
+++ b/transcript_transformer/util_functions.py
@@ -1,75 +1,7 @@
 from datetime import datetime
 import numpy as np
 import heapq
-
-
-cdn_prot_dict = {
-    "ATA": "I",
-    "ATC": "I",
-    "ATT": "I",
-    "ATG": "M",
-    "ACA": "T",
-    "ACC": "T",
-    "ACG": "T",
-    "ACT": "T",
-    "AAC": "N",
-    "AAT": "N",
-    "AAA": "K",
-    "AAG": "K",
-    "AGC": "S",
-    "AGT": "S",
-    "AGA": "R",
-    "AGG": "R",
-    "CTA": "L",
-    "CTC": "L",
-    "CTG": "L",
-    "CTT": "L",
-    "CCA": "P",
-    "CCC": "P",
-    "CCG": "P",
-    "CCT": "P",
-    "CAC": "H",
-    "CAT": "H",
-    "CAA": "Q",
-    "CAG": "Q",
-    "CGA": "R",
-    "CGC": "R",
-    "CGG": "R",
-    "CGT": "R",
-    "GTA": "V",
-    "GTC": "V",
-    "GTG": "V",
-    "GTT": "V",
-    "GCA": "A",
-    "GCC": "A",
-    "GCG": "A",
-    "GCT": "A",
-    "GAC": "D",
-    "GAT": "D",
-    "GAA": "E",
-    "GAG": "E",
-    "GGA": "G",
-    "GGC": "G",
-    "GGG": "G",
-    "GGT": "G",
-    "TCA": "S",
-    "TCC": "S",
-    "TCG": "S",
-    "TCT": "S",
-    "TTC": "F",
-    "TTT": "F",
-    "TTA": "L",
-    "TTG": "L",
-    "TAC": "Y",
-    "TAT": "Y",
-    "TAA": "_",
-    "TAG": "_",
-    "TGC": "C",
-    "TGT": "C",
-    "TGA": "_",
-    "TGG": "W",
-    "NNN": "_",
-}
+from transcript_transformer import CDN_PROT_DICT, PROT_IDX_DICT, DNA_IDX_DICT
 
 
 def construct_prot(seq):
@@ -91,13 +23,12 @@ def construct_prot(seq):
         if "N" in cdn:
             string += "_"
         else:
-            string += cdn_prot_dict[cdn]
+            string += CDN_PROT_DICT[cdn]
 
     return string, has_stop, stop_codon
 
 
-def DNA2vec(dna_seq):
-    seq_dict = {"A": 0, "T": 1, "U": 1, "C": 2, "G": 3, "N": 4}
+def DNA2vec(dna_seq, seq_dict=DNA_IDX_DICT):
     dna_vec = np.zeros(len(dna_seq), dtype=int)
     for idx in np.arange(len(dna_seq)):
         dna_vec[idx] = seq_dict[dna_seq[idx]]
@@ -105,12 +36,28 @@ def DNA2vec(dna_seq):
     return dna_vec
 
 
+def prot2vec(prot_seq, prot_dict=PROT_IDX_DICT):
+    prot_vec = np.zeros(len(prot_seq), dtype=int)
+    for idx in np.arange(len(prot_seq)):
+        prot_vec[idx] = prot_dict[prot_seq[idx]]
+
+    return list(prot_vec)
+
+
+def listify(array):
+    return [list(a) for a in array]
+
+
 def time():
     return datetime.now().strftime("%H:%M:%S %m-%d ")
 
 
-def vec2DNA(tr_seq, np_dict=np.array(["A", "T", "C", "G", "N"])):
-    return "".join(np_dict[tr_seq])
+def vec2DNA(vec, np_dict=np.array(["A", "T", "C", "G", "N"])):
+    return "".join(np_dict[vec])
+
+
+def vec2prot(vec, np_dict=np.array(list(PROT_IDX_DICT.keys()))):
+    return "".join(np_dict[vec])
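+
+
+# Illustrative round trip, assuming PROT_IDX_DICT maps residues to 0..20 in
+# key order (as defined in this package's __init__):
+#     vec2prot(prot2vec("MKT")) == "MKT"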
 
 
 def divide_keys_by_size(size_dict, num_chunks):
@@ -223,8 +170,8 @@ def transcript_region_to_exons(
     stop coordinates must exist on exons.
 
     Args:
-        start_coord (str): start coordinate
-        stop_coord (str): stop coordinate, NOT start of stop codon for CDSs
+        start_coord (int): start coordinate
+        stop_coord (int): stop coordinate, NOT start of stop codon for CDSs
         strand (str): strand, either + or -
         exons (list): list of exon bound coordinates following gtf file
            conventions. E.g. positive strand: [1 2 4 5] negative strand: [4 5 1 2]
@@ -235,6 +182,8 @@ def transcript_region_to_exons(
         E.g. [{start_coord} 2 4 {stop_coord}]
     """
     pos_strand = strand == "+"
+    if isinstance(exons, list):
+        exons = np.array(exons)
     if stop_coord == -1:
         if pos_strand:
             stop_coord = exons[-1]
@@ -263,7 +212,7 @@ def transcript_region_to_exons(
         exon_idx_stop = len(exons) + 1
     exon_numbers = np.arange(exon_idx_start // 2, exon_idx_stop // 2)
 
-    return genome_parts, exon_numbers + 1
+    return list(genome_parts), list(exon_numbers + 1)
 
 
 def get_exon_dist_map(tr_regions, strand):
@@ -333,3 +282,39 @@ def find_distant_exon_coord(ref_coord, distance, strand, exons):
         dist_coord = -1
 
     return dist_coord
+
+
+def get_str2str_idx_map(source, dest):
+    xsorted = np.argsort(dest)
+    return xsorted[np.searchsorted(dest[xsorted], source)]
+
+
+def co_to_idx(start, end):
+    return start - 1, end
+
+
+def slice_gen(
+    seq,
+    start,
+    end,
+    strand,
+    co=True,
+    to_vec=True,
+    seq_dict={"A": 0, "T": 1, "C": 2, "G": 3, "N": 4},
+    comp_dict={0: 1, 1: 0, 2: 3, 3: 2, 4: 4},
+):
+    """get sequence following gtf-coordinate system"""
+    if co:
+        start, end = co_to_idx(start, end)
+    sl = seq[start:end].seq
+
+    if to_vec:
+        sl = list(map(lambda x: seq_dict[x.upper()], sl))
+
+    if strand in ["-", -1, False]:
+        if comp_dict is not None:
+            sl = list(map(lambda x: comp_dict[x], sl))[::-1]
+        else:
+            sl = sl[::-1]
+
+    return np.array(sl)
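+
+
+# Usage sketch for slice_gen, assuming `rec` is a SeqRecord-like object whose
+# slices expose `.seq` (e.g. a value of SeqIO.to_dict(SeqIO.parse(...))):
+#     fwd = slice_gen(rec, 1, 6, "+")  # 1-based inclusive gtf coordinates
+#     rev = slice_gen(rec, 1, 6, "-")  # reverse complement, as index vector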