Commit aee76cb: update
jdcla authored and jdcla committed Oct 25, 2024
1 parent b842a36 commit aee76cb
Showing 13 changed files with 1,305 additions and 784 deletions.
276 changes: 276 additions & 0 deletions transcript_transformer/__init__.py
@@ -0,0 +1,276 @@
# Define global variables

CDN_PROT_DICT = {
"ATA": "I",
"ATC": "I",
"ATT": "I",
"ATG": "M",
"ACA": "T",
"ACC": "T",
"ACG": "T",
"ACT": "T",
"AAC": "N",
"AAT": "N",
"AAA": "K",
"AAG": "K",
"AGC": "S",
"AGT": "S",
"AGA": "R",
"AGG": "R",
"CTA": "L",
"CTC": "L",
"CTG": "L",
"CTT": "L",
"CCA": "P",
"CCC": "P",
"CCG": "P",
"CCT": "P",
"CAC": "H",
"CAT": "H",
"CAA": "Q",
"CAG": "Q",
"CGA": "R",
"CGC": "R",
"CGG": "R",
"CGT": "R",
"GTA": "V",
"GTC": "V",
"GTG": "V",
"GTT": "V",
"GCA": "A",
"GCC": "A",
"GCG": "A",
"GCT": "A",
"GAC": "D",
"GAT": "D",
"GAA": "E",
"GAG": "E",
"GGA": "G",
"GGC": "G",
"GGG": "G",
"GGT": "G",
"TCA": "S",
"TCC": "S",
"TCG": "S",
"TCT": "S",
"TTC": "F",
"TTT": "F",
"TTA": "L",
"TTG": "L",
"TAC": "Y",
"TAT": "Y",
"TAA": "_",
"TAG": "_",
"TGC": "C",
"TGT": "C",
"TGA": "_",
"TGG": "W",
"NNN": "_",
}

PROT_IDX_DICT = {
"A": 0,
"R": 1,
"N": 2,
"D": 3,
"C": 4,
"E": 5,
"Q": 6,
"G": 7,
"H": 8,
"O": 9,
"I": 10,
"L": 11,
"K": 12,
"M": 13,
"F": 14,
"P": 15,
"S": 16,
"T": 17,
"W": 18,
"Y": 19,
"V": 20,
}

DNA_IDX_DICT = {
"A": 0,
"T": 1,
"C": 2,
"G": 3,
"N": 4,
}

IDX_PROT_DICT = {v: k for k, v in PROT_IDX_DICT.items()}
IDX_DNA_DICT = {v: k for k, v in DNA_IDX_DICT.items()}

STANDARD_HEADERS = [
"CDS_coords",
"CDS_idxs",
"canonical_TIS_coord",
"canonical_TIS_exon",
"canonical_TIS_idx",
"canonical_LTS_coord",
"canonical_LTS_idx",
"canonical_TTS_coord",
"canonical_TTS_idx",
"canonical_protein_seq",
"exon_coords",
"exon_idxs",
"gene_id",
"has_annotated_start_codon",
"has_annotated_stop_codon",
"seq",
"seqname",
"source",
"strand",
"transcript_id",
"transcript_len",
]

RENAME_HEADERS = {
"has_annotated_start_codon": "CDS_has_annotated_start_codon",
"has_annotated_stop_codon": "CDS_has_annotated_stop_codon",
}

STANDARD_OUT_HEADERS = [
"seqname",
"ORF_id",
"ORF_len",
"transcript_id",
"transcript_len",
"start_codon",
"stop_codon",
"strand",
"ORF_type",
"TIS_pos",
"TTS_pos",
"has_CDS_clones",
"has_CDS_TIS",
"has_CDS_TTS",
"shared_in_frame_CDS_frac",
"dist_from_canonical_TIS",
"frame_wrt_canonical_TIS",
"TTS_on_transcript",
"TIS_coord",
"TIS_exon",
"TTS_coord",
"TTS_exon",
"LTS_coord",
"LTS_exon",
"gene_id",
"canonical_TIS_coord",
"canonical_TIS_pos",
"canonical_LTS_coord",
"canonical_LTS_pos",
"canonical_TTS_coord",
"canonical_TTS_pos",
"has_annotated_start_codon",
"has_annotated_stop_codon",
"protein_seq",
]

RIBO_OUT_HEADERS = [
"correction",
"reads_in_transcript",
"reads_in_ORF",
"reads_in_frame_frac",
"reads_5UTR",
"reads_3UTR",
"reads_coverage_frac",
"reads_entropy",
"reads_skew",
]


RIBOTIE_MQC_HEADER = """
# parent_id: 'ribotie'
# parent_name: "RiboTIE"
# parent_description: "Overview of open reading frames called as translating by RiboTIE"
# """

START_CODON_MQC_HEADER = """
# id: 'ribotie_start_codon_counts'
# section_name: 'Start Codon'
# description: "Start codon counts of all open reading frames called by RiboTIE"
# plot_type: 'bargraph'
# anchor: 'orf_start_codon_counts'
# pconfig:
# id: "orf_start_codon_counts_plot"
# title: "RiboTIE: Start Codons"
# colors:
# ATG : "#f8d7da"
# xlab: "# ORFs"
# cpswitch_counts_label: "Number of ORFs"
"""

BIOTYPE_VARIANT_MQC_HEADER = """
# id: 'ribotie_biotype_counts_variant'
# section_name: 'Transcript Biotypes (varRNA-ORF)'
# description: "Transcript biotypes of varRNA-ORFs called by RiboTIE"
# plot_type: 'bargraph'
# anchor: 'transcript_biotype_variant_counts'
# pconfig:
# id: "transcript_biotype_counts_variant_plot"
# title: "RiboTIE: varRNA-ORFs Transcript Biotypes"
# xlab: "# ORFs"
# cpswitch_counts_label: "Number of ORFs"
"""

ORF_TYPE_MQC_HEADER = """
# id: 'ribotie_orftype_counts'
# section_name: 'ORF types'
# description: "ORF types of all open reading frames called by RiboTIE"
# plot_type: 'bargraph'
# anchor: 'transcript_orftype_counts'
# pconfig:
# id: "transcript_orftype_counts_plot"
# title: "RiboTIE: ORF Types"
# xlab: "# ORFs"
# cpswitch_counts_label: "Number of ORFs"
"""

ORF_LEN_MQC_HEADER = """
# id: 'ribotie_orflen_hist'
# section_name: 'ORF lengths'
# description: "ORF lengths of all open reading frames called by RiboTIE"
# plot_type: 'linegraph'
# anchor: 'transcript_orflength_hist'
# pconfig:
# id: "transcript_orflength_hist_plot"
# title: "RiboTIE: ORF lengths"
# xlab: "Length"
# xLog: "True"
"""

ORF_TYPE_ORDER = [
"annotated CDS",
"N-terminal truncation",
"N-terminal extension",
"C-terminal truncation",
"C-terminal extension",
"uORF",
"uoORF",
"dORF",
"doORF",
"intORF",
"lncRNA-ORF",
"varRNA-ORF",
]

ORF_BIOTYPE_ORDER = [
"retained_intron",
"protein_coding",
"protein_coding_CDS_not_defined",
"nonsense_mediated_decay",
"processed_pseudogene",
"unprocessed_pseudogene",
"transcribed_unprocessed_pseudogene",
"transcribed_processed_pseudogene",
"translated_processed_pseudogene",
"transcribed_unitary_pseudogene",
"processed_transcript",
"TEC",
"artifact",
"non_stop_decay",
"misc_RNA",
]
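A quick orientation for the new module-level constants above: CDN_PROT_DICT maps codons to one-letter amino acids (with "_" for stop codons and the fully ambiguous "NNN"), and the *_IDX_DICT tables define the integer vocabularies for protein and DNA tokens. A minimal sketch of how they compose; the translate/encode helpers here are illustrative, not part of this commit:

    import numpy as np

    from transcript_transformer import CDN_PROT_DICT, DNA_IDX_DICT

    def translate(seq):
        # walk the sequence codon by codon; "_" flags stops and unknown codons
        return "".join(
            CDN_PROT_DICT.get(seq[i : i + 3], "_") for i in range(0, len(seq) - 2, 3)
        )

    def encode(seq):
        # nucleotides -> integer vocabulary used for model inputs
        return np.array([DNA_IDX_DICT[nt] for nt in seq], dtype=int)

    print(translate("ATGGCTTAA"))  # MA_
    print(encode("ATGN"))          # [0 1 3 4]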
10 changes: 5 additions & 5 deletions transcript_transformer/argparser.py
@@ -271,7 +271,7 @@ def add_train_loading_args(self, pretrain=False, auto=False):
nargs="*",
default=[],
help="chromosomes used for training. If not specified, "
"training is performed on all available chromosomes excluding val/test contigs",
"training is performed on all available chromosomes excluding val/test seqnames",
)
dl_parse.add_argument(
"--val",
@@ -483,7 +483,7 @@ def parse_arguments(self, argv, configs=[]):
args.model_dir = model_dir
# create output dir if non-existent
if args.out_prefix:
-    os.makedirs(os.path.dirname(args.out_prefix), exist_ok=True)
+    os.makedirs(os.path.dirname(os.path.abspath(args.out_prefix)), exist_ok=True)
# backward compatibility
if "seq" in args:
args.use_seq = args.seq
@@ -517,7 +517,7 @@ def parse_arguments(self, argv, configs=[]):

# conditions used to remove transcripts from training/validation data
conds = {"global": {}, "grouped": [{} for l in range(len(args.ribo_ids))]}
-    conds["global"]["tr_len"] = lambda x: np.logical_and(
+    conds["global"]["transcript_len"] = lambda x: np.logical_and(
x > args.min_seq_len, x < args.max_seq_len
)
if args.cond is not None:
@@ -557,7 +557,7 @@ def parse_arguments(self, argv, configs=[]):
# Default values
args.exp_path = "transcript"
args.y_path = "tis"
-    args.seqn_path = "contig"
-    args.id_path = "id"
+    args.seqn_path = "seqname"
+    args.id_path = "transcript_id"
print(args)
return args
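One fix above deserves a note: wrapping args.out_prefix in os.path.abspath before taking os.path.dirname matters when --out_prefix is a bare prefix with no directory part, because os.path.dirname then returns an empty string and os.makedirs("") raises FileNotFoundError. A small demonstration:

    import os

    out_prefix = "results_"                              # no directory component
    print(os.path.dirname(out_prefix))                   # "" : makedirs("") would raise FileNotFoundError
    print(os.path.dirname(os.path.abspath(out_prefix)))  # the current working directory
    os.makedirs(os.path.dirname(os.path.abspath(out_prefix)), exist_ok=True)  # safe no-op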
@@ -23,4 +23,5 @@ patience: 8
cond :
ribo:
num_reads : "x > 6"
+ has_annotated_start_codon: "x"

4 changes: 3 additions & 1 deletion transcript_transformer/configs/tis_transformer_defaults.yml
@@ -15,4 +15,6 @@ ff_glu: false
emb_dropout: 0.1
ff_dropout: 0.1
attn_dropout: 0.1
- local_window_size: 256
+ local_window_size: 256
+ cond:
+   has_annotated_start_codon: "x"
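The new cond entries in these config files follow the same pattern as num_reads above: each value is an expression over x, the corresponding HDF5 column, evaluated to a boolean mask (a bare "x" keeps rows where the column is truthy, e.g. transcripts with an annotated start codon). A hedged sketch of how such strings could be turned into callables; the package's actual parsing may differ:

    import numpy as np

    def compile_conditions(cond_cfg):
        # turn each YAML expression string into a mask function over a column array
        return {key: eval(f"lambda x: {expr}") for key, expr in cond_cfg.items()}

    conds = compile_conditions({"num_reads": "x > 6", "has_annotated_start_codon": "x"})
    reads = np.array([3, 10, 8])
    annot = np.array([True, False, True])
    print(conds["num_reads"](reads) & conds["has_annotated_start_codon"](annot))
    # [False False  True]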
464 changes: 222 additions & 242 deletions transcript_transformer/data.py

Large diffs are not rendered by default.

1,125 changes: 693 additions & 432 deletions transcript_transformer/processing.py

Large diffs are not rendered by default.

25 changes: 15 additions & 10 deletions transcript_transformer/ribotie.py
@@ -1,15 +1,14 @@
import os
import sys
import numpy as np
import polars as pl
import yaml
import h5py
from importlib import resources as impresources
from copy import deepcopy

from .transcript_transformer import train, predict
from .argparser import Parser
- from .pretrained import riboformer_models
+ from .pretrained import ribotie_models
from . import configs
from .data import process_seq_data, process_ribo_data
from .processing import construct_output_table, csv_to_gtf, create_multiqc_reports
@@ -22,7 +21,7 @@ def parse_args():
data_parser.add_argument(
"--prob_cutoff",
type=float,
-    default=0.15,
+    default=0.125,
help="Determines the minimum model output score required for model "
"predictions to be included in the result table.",
)
@@ -64,6 +63,11 @@ def parse_args():
action="store_true",
help="Include annotated CDS regions in generated GTF file containing predicted translated ORFs.",
)
+ data_parser.add_argument(
+     "--unfiltered",
+     action="store_true",
+     help="Don't apply filtering of top predictions",
+ )
# data_parser.add_argument(
# "--pretrained_model",
# type=json.loads,
@@ -84,7 +88,7 @@ def parse_args():
parser.add_comp_args()
parser.add_training_args()
parser.add_train_loading_args(pretrain=True, auto=True)
- default_config = f"{impresources.files(configs) / 'riboformer_defaults.yml'}"
+ default_config = f"{impresources.files(configs) / 'ribotie_defaults.yml'}"
args = parser.parse_arguments(sys.argv[1:], [default_config])
if args.out_prefix is None:
args.out_prefix = f"{os.path.splitext(args.conf[0])[0]}_"
@@ -176,9 +180,9 @@ def main():
if not (args.data or args.results or args.pretrain):
if "pretrained_model" not in args:
args = load_args(
-     (impresources.files(riboformer_models) / "50perc_06_23.yml"), args
+     (impresources.files(ribotie_models) / "50perc_06_23.yml"), args
)
- args.model_dir = str(impresources.files(riboformer_models)._paths[0])
+ args.model_dir = str(impresources.files(ribotie_models)._paths[0])
for i, ribo_set in enumerate(args.ribo_ids):
args_set = deepcopy(args)
args_set.ribo_ids = [ribo_set]
@@ -221,13 +225,14 @@ def main():
exclude_invalid_TTS=not args.include_invalid_TTS,
ribo=out,
parallel=args.parallel,
+     unfiltered=args.unfiltered,
)
if df is not None:
-     csv_to_gtf(
-         args.h5_path, pl.from_pandas(df), out_prefix, args.exclude_annotated
-     )
+     csv_to_gtf(args.h5_path, df, out_prefix, args.exclude_annotated)
os.makedirs(f"{args.out_prefix}/multiqc", exist_ok=True)
-     create_multiqc_reports(df, f"{args.out_prefix}/multiqc/{ribo_set_str}")
+     create_multiqc_reports(
+         df, f"{os.path.dirname(args.out_prefix)}/multiqc/{ribo_set_str}"
+     )


def merge_outputs(prefix, keys):
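Two behavioural changes in this file interact: the default --prob_cutoff drops from 0.15 to 0.125, and the new --unfiltered flag skips filtering of top predictions entirely. Reproducing the cutoff on a saved result table is straightforward; a sketch, assuming a pandas table with a model-score column (the column name "output" is an assumption, not confirmed by this diff):

    import pandas as pd

    def filter_calls(df: pd.DataFrame, prob_cutoff: float = 0.125, unfiltered: bool = False) -> pd.DataFrame:
        # keep everything when --unfiltered is set, otherwise apply the score threshold
        if unfiltered:
            return df
        return df[df["output"] >= prob_cutoff]  # "output" is a placeholder column name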
21 changes: 11 additions & 10 deletions transcript_transformer/tis_transformer.py
@@ -40,7 +40,8 @@ def parse_args():
default_config = f"{impresources.files(configs) / 'tis_transformer_defaults.yml'}"
args = parser.parse_arguments(sys.argv[1:], [default_config])
if args.out_prefix is None:
-     args.out_prefix = f"{os.path.splitext(args.conf[0])[0]}_"
+     print(args.conf[0])
+     args.out_prefix = f"{os.path.splitext(args.h5_path)[0]}_"
assert ~args.results and ~args.data, (
"cannot only do processing of data and results, disable either"
" --data_process or --result_process"
@@ -66,8 +67,8 @@ def main():
args.input_type = "hdf5"
# determine optimal allocation of seqnames to train/val/test set
f = h5py.File(args.h5_path, "r")["transcript"]
-     contigs = np.array(f["contig"])
-     tr_lens = np.array(f["tr_len"])
+     contigs = np.array(f["seqname"])
+     tr_lens = np.array(f["transcript_len"])
f.file.close()
# determine nt count per seqname
contig_set = np.unique(contigs)
@@ -85,7 +86,7 @@ def main():

f = h5py.File(args.h5_path, "a")
grp = f["transcript"]
-     f_tr_ids = np.array(grp["id"])
+     f_tr_ids = np.array(grp["transcript_id"])
xsorted = np.argsort(f_tr_ids)
out = np.load(f"{prefix}.npy", allow_pickle=True)
tr_ids = np.hstack([o[0] for o in out])
@@ -96,25 +97,25 @@ def main():
for idx, (_, pred, _) in zip(pred_to_h5_args, out):
pred_arr[idx] = pred
dtype = h5py.vlen_dtype(np.dtype("float32"))
-     if "seq_output" in grp.keys():
+     if "tis_transformer_score" in grp.keys():
print("--> Overwriting results in local h5 database...")
-         del grp["seq_output"]
+         del grp["tis_transformer_score"]
else:
print("--> Writing results to local h5 database...")
-     grp.create_dataset("seq_output", data=pred_arr, dtype=dtype)
+     grp.create_dataset("tis_transformer_score", data=pred_arr, dtype=dtype)
f.close()
if not args.no_backup:
if not args.backup_path:
args.backup_path = os.path.splitext(args.gtf_path)[0] + ".h5"
if os.path.isfile(args.backup_path):
f = h5py.File(args.backup_path, "a")
grp = f["transcript"]
-         if "seq_output" in grp.keys():
+         if "tis_transformer_score" in grp.keys():
print("--> Overwriting results in backup h5 database...")
-             del grp["seq_output"]
+             del grp["tis_transformer_score"]
else:
print("--> Writing results to backup h5 database...")
-         grp.create_dataset("seq_output", data=pred_arr, dtype=dtype)
+         grp.create_dataset("tis_transformer_score", data=pred_arr, dtype=dtype)
f.close()
if not args.data:
f = h5py.File(args.h5_path, "r")
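With seq_output renamed to tis_transformer_score, downstream consumers read the per-transcript scores back as variable-length float32 datasets. A minimal read-back sketch; "my_db.h5" stands in for the pipeline's --h5_path:

    import h5py
    import numpy as np

    with h5py.File("my_db.h5", "r") as f:  # placeholder path
        grp = f["transcript"]
        ids = np.array(grp["transcript_id"][:5])
        scores = grp["tis_transformer_score"][:5]  # one variable-length array per transcript
        for tr_id, s in zip(ids, scores):
            print(tr_id, s.max() if len(s) else float("nan"))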
24 changes: 17 additions & 7 deletions transcript_transformer/transcript_loader.py
@@ -1,10 +1,12 @@
import h5py
import numpy as np
import torch
- from h5max import load_sparse
+ from h5max import load_sparse_matrix
from torch.utils.data import DataLoader
import pytorch_lightning as pl

+ from pdb import set_trace


def collate_fn(batch):
"""
@@ -56,7 +58,13 @@ def collate_fn(batch):
else:
arr = np.empty(shape=(len(batch[1]), max_len + 2), dtype=int)
for i, (x, l) in enumerate(zip(batch[1], max_len - lens)):
-         arr[i] = np.concatenate(([5], x[k], [6], [7] * l))
+         try:
+             arr[i] = np.concatenate(([5], x[k], [6], [7] * l))
+         except:
+             print(k, len(x[k]), l)
+             print(x[k])
+             print(batch[0][i])
+             set_trace()
x_dict[k] = torch.LongTensor(arr)

x_dict.update({"x_id": batch[0], "y": y_b})
@@ -186,13 +194,13 @@ def __init__(
def setup(self, stage=None):
f = h5py.File(self.h5_path, "r")[self.exp_path]
self.seqn_list = np.array(f[self.seqn_path])
-     self.transcript_lens = np.array(f["tr_len"])
+     self.transcript_lens = np.array(f["transcript_len"])
# evaluate conditions
# Identical mask over the samples applied to all datasets
global_masks = []
for key, cond_f in self.cond["global"].items():
mask = cond_f(np.array(f[key]))
-     if (key != "tr_len") and (self.leaky_frac > 0):
+     if (key != "transcript_len") and (self.leaky_frac > 0):
prob_mask = np.random.uniform(size=len(mask)) > (1 - self.leaky_frac)
mask[prob_mask] = True
global_masks.append(mask)
@@ -220,7 +228,7 @@ def setup(self, stage=None):
mask = cond_f(grouped_feature)
# if data filtering is not based on transcript length,
# allow leaky filtering
-         if (key != "tr_len") and (self.leaky_frac > 0):
+         if (key != "transcript_len") and (self.leaky_frac > 0):
prob_mask = np.random.uniform(size=len(mask)) > (
1 - self.leaky_frac
)
@@ -467,9 +475,11 @@ def __getitem__(self, index):
id_prefix = "&".join(ribo_set) + "|"
ribo_path = f"riboseq/{ribo_id}/5"
if self.parallel:
-             x = load_sparse(self.r[ribo_id][ribo_path], idx, format="csr").T
+             x = load_sparse_matrix(
+                 self.r[ribo_id][ribo_path], idx, format="csr"
+             ).T
else:
-             x = load_sparse(self.f[ribo_path], idx, format="csr").T
+             x = load_sparse_matrix(self.f[ribo_path], idx, format="csr").T
if self.offsets is not None:
for col_i, (_, shift) in enumerate(
self.offsets[ribo_id].items()
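For context on the guarded padding step in collate_fn: each sequence appears to be wrapped as [5] + tokens + [6] and right-padded with 7 out to max_len + 2 columns (the special-token ids 5/6/7 are inferred from the surrounding code, not documented in this diff). An equivalent sketch of that layout, without the debug hooks:

    import numpy as np

    def pad_batch(seqs, max_len):
        # rows: [5] <tokens> [6] then 7-padding out to max_len + 2 columns
        arr = np.full((len(seqs), max_len + 2), 7, dtype=int)
        for i, x in enumerate(seqs):
            arr[i, 0] = 5
            arr[i, 1 : len(x) + 1] = x
            arr[i, len(x) + 1] = 6
        return arr

    print(pad_batch([np.array([0, 1]), np.array([2])], max_len=2))
    # [[5 0 1 6]
    #  [5 2 6 7]]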
139 changes: 62 additions & 77 deletions transcript_transformer/util_functions.py
@@ -1,75 +1,7 @@
from datetime import datetime
import numpy as np
import heapq


- cdn_prot_dict = {
-     "ATA": "I",
-     "ATC": "I",
-     "ATT": "I",
-     "ATG": "M",
-     "ACA": "T",
-     "ACC": "T",
-     "ACG": "T",
-     "ACT": "T",
-     "AAC": "N",
-     "AAT": "N",
-     "AAA": "K",
-     "AAG": "K",
-     "AGC": "S",
-     "AGT": "S",
-     "AGA": "R",
-     "AGG": "R",
-     "CTA": "L",
-     "CTC": "L",
-     "CTG": "L",
-     "CTT": "L",
-     "CCA": "P",
-     "CCC": "P",
-     "CCG": "P",
-     "CCT": "P",
-     "CAC": "H",
-     "CAT": "H",
-     "CAA": "Q",
-     "CAG": "Q",
-     "CGA": "R",
-     "CGC": "R",
-     "CGG": "R",
-     "CGT": "R",
-     "GTA": "V",
-     "GTC": "V",
-     "GTG": "V",
-     "GTT": "V",
-     "GCA": "A",
-     "GCC": "A",
-     "GCG": "A",
-     "GCT": "A",
-     "GAC": "D",
-     "GAT": "D",
-     "GAA": "E",
-     "GAG": "E",
-     "GGA": "G",
-     "GGC": "G",
-     "GGG": "G",
-     "GGT": "G",
-     "TCA": "S",
-     "TCC": "S",
-     "TCG": "S",
-     "TCT": "S",
-     "TTC": "F",
-     "TTT": "F",
-     "TTA": "L",
-     "TTG": "L",
-     "TAC": "Y",
-     "TAT": "Y",
-     "TAA": "_",
-     "TAG": "_",
-     "TGC": "C",
-     "TGT": "C",
-     "TGA": "_",
-     "TGG": "W",
-     "NNN": "_",
- }
+ from transcript_transformer import CDN_PROT_DICT, PROT_IDX_DICT, DNA_IDX_DICT


def construct_prot(seq):
@@ -91,26 +23,41 @@ def construct_prot(seq):
if "N" in cdn:
string += "_"
else:
-             string += cdn_prot_dict[cdn]
+             string += CDN_PROT_DICT[cdn]

return string, has_stop, stop_codon


- def DNA2vec(dna_seq):
-     seq_dict = {"A": 0, "T": 1, "U": 1, "C": 2, "G": 3, "N": 4}
+ def DNA2vec(dna_seq, seq_dict=DNA_IDX_DICT):
dna_vec = np.zeros(len(dna_seq), dtype=int)
for idx in np.arange(len(dna_seq)):
dna_vec[idx] = seq_dict[dna_seq[idx]]

return dna_vec


+ def prot2vec(prot_seq, prot_dict=PROT_IDX_DICT):
+     prot_vec = np.zeros(len(prot_seq), dtype=int)
+     for idx in np.arange(len(prot_seq)):
+         prot_vec[idx] = prot_dict[prot_seq[idx]]
+
+     return list(prot_vec)
+
+
+ def listify(array):
+     return [list(a) for a in array]


def time():
return datetime.now().strftime("%H:%M:%S %m-%d ")


- def vec2DNA(tr_seq, np_dict=np.array(["A", "T", "C", "G", "N"])):
-     return "".join(np_dict[tr_seq])
+ def vec2DNA(vec, np_dict=np.array(["A", "T", "C", "G", "N"])):
+     return "".join(np_dict[vec])


+ def vec2prot(vec, np_dict=np.array(list(PROT_IDX_DICT.keys()))):
+     return "".join(np_dict[vec])


def divide_keys_by_size(size_dict, num_chunks):
@@ -223,8 +170,8 @@ def transcript_region_to_exons(
stop coordinates must exist on exons.
Args:
-     start_coord (str): start coordinate
-     stop_coord (str): stop coordinate, NOT start of stop codon for CDSs
+     start_coord (int): start coordinate
+     stop_coord (int): stop coordinate, NOT start of stop codon for CDSs
strand (str): strand, either + or -
exons (list): list of exon bound coordinates following gtf file conventions.
E.g. positive strand: [1 2 4 5] negative strand: [4 5 1 2]
@@ -235,6 +182,8 @@ def transcript_region_to_exons(
E.g. [{start_coord} 2 4 {stop_coord}]
"""
pos_strand = strand == "+"
+     if type(exons) == list:
+         exons = np.array(exons)
if stop_coord == -1:
if pos_strand:
stop_coord = exons[-1]
@@ -263,7 +212,7 @@ def transcript_region_to_exons(
exon_idx_stop = len(exons) + 1
exon_numbers = np.arange(exon_idx_start // 2, exon_idx_stop // 2)

-     return genome_parts, exon_numbers + 1
+     return list(genome_parts), list(exon_numbers + 1)


def get_exon_dist_map(tr_regions, strand):
@@ -333,3 +282,39 @@ def find_distant_exon_coord(ref_coord, distance, strand, exons):
dist_coord = -1

return dist_coord


+ def get_str2str_idx_map(source, dest):
+     xsorted = np.argsort(dest)
+     return xsorted[np.searchsorted(dest[xsorted], source)]
+
+
+ def co_to_idx(start, end):
+     return start - 1, end
+
+
+ def slice_gen(
+     seq,
+     start,
+     end,
+     strand,
+     co=True,
+     to_vec=True,
+     seq_dict={"A": 0, "T": 1, "C": 2, "G": 3, "N": 4},
+     comp_dict={0: 1, 1: 0, 2: 3, 3: 2, 4: 4},
+ ):
+     """get sequence following gtf-coordinate system"""
+     if co:
+         start, end = co_to_idx(start, end)
+     sl = seq[start:end].seq
+
+     if to_vec:
+         sl = list(map(lambda x: seq_dict[x.upper()], sl))
+
+     if strand in ["-", -1, False]:
+         if comp_dict is not None:
+             sl = list(map(lambda x: comp_dict[x], sl))[::-1]
+         else:
+             sl = sl[::-1]
+
+     return np.array(sl)
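Usage sketches for two of the new helpers, with illustrative data: get_str2str_idx_map returns, for every id in source, its position in dest (both arrays of strings), and co_to_idx converts 1-based inclusive GTF coordinates to a 0-based half-open Python slice.

    import numpy as np

    from transcript_transformer.util_functions import get_str2str_idx_map, co_to_idx

    dest = np.array(["tr3", "tr1", "tr2"])
    source = np.array(["tr1", "tr3"])
    idxs = get_str2str_idx_map(source, dest)
    print(idxs, dest[idxs])   # [1 0] ['tr1' 'tr3']

    print(co_to_idx(1, 10))   # (0, 10)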
