major update: refactor, data performance, tis transformer functionality, formatting, bam support, yaml file restructure, etc.
jdcla committed Nov 27, 2023
1 parent ff661bd commit 0bee155
Showing 12 changed files with 1,483 additions and 724 deletions.
4 changes: 3 additions & 1 deletion setup.cfg
@@ -21,10 +21,11 @@ install_requires =
     numpy >= 1.21.0
     scipy >= 1.7.3
     pandas >= 2.0.1
+    biobear >= 0.13.1
     pyarrow >= 12.0.1
     h5max >= 0.3.0
     fasta-reader >= 3.0.1
-    polars == 0.16.13
+    polars >= 0.16.13
     tqdm >= 4.65.0
     gtfparse >= 2.0.1
     pyfaidx >= 0.7.2.1
@@ -33,3 +34,4 @@ install_requires =
 console_scripts =
     transcript_transformer = transcript_transformer.transcript_transformer:main
     riboformer = transcript_transformer.riboformer:main
+    tis_transformer = transcript_transformer.tis_transformer:main
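The new `tis_transformer` console script resolves to `transcript_transformer.tis_transformer:main`. A minimal sketch of the equivalent programmatic call; the config filename and the assumption that `main()` reads `sys.argv` (as setuptools console_scripts targets typically do) are illustrative, not shown in this diff:

```python
# Hypothetical sketch: invoking the new entry point without the installed
# console script. Assumes main() parses sys.argv itself; "config.yml" is a
# placeholder path.
import sys

from transcript_transformer.tis_transformer import main

if __name__ == "__main__":
    sys.argv = ["tis_transformer", "config.yml"]
    main()
```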
51 changes: 29 additions & 22 deletions template.yml
@@ -15,11 +15,11 @@ ribo_paths :
 ########################################################
 h5_path : path/to/hdf5_file.h5
 ########################################################
-## For models using transcript sequence data.
-## This setting is used by the TIS-transformer.
-## Set to false when training on ribo-seq data
+## path prefix used for output files predictions
+## defaults to hdf5 path
+########################################################
+# out_prefix : riboformer/template_
 ########################################################
-seq : false
 #
 ####################
 ## ADVANCED SETUP ##
@@ -28,17 +28,26 @@ seq : false
 ########################################################
 ## A custom set of riboseq data selected for training.
 ## Use ids applied in ribo_paths, leave commented if NA.
+## Replicates can be merged where the number of mapped
+## reads are summed for multiple experiments.
 ########################################################
+## example: only use SRR000001 and SRR000003
+#ribo:
+#    - SRR000001
+#    - SRR000003
+#
+## example: SRR000001 and SRR000002 are merged (replicates)
 #ribo:
-#    SRR000001
-#    SRR000003
+#    - - SRR000001
+#      - SRR000002
+#    - - SRR000003
 #
 ########################################################
-## It is also possible to set offsets per read length.
+## It is possible to set offsets per read length.
 ## NOT RECOMMENDED: loses read length information.
-## Functionality exists merely for benchmark
+## Functionality exists merely for benchmarking
 ########################################################
-#ribo:
+#offsets:
 #    SRR000001:
 #        28 : 7
 #        29 : 10
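The nested list syntax above feeds the replicate-group normalization added to `parse_config_file` later in this commit (see the argparser.py hunks below). A minimal sketch of that step, using a mixed flat/nested input for illustration:

```python
# Mirrors the list comprehension added to parse_config_file in this commit:
# plain ids become singleton groups, nested lists stay as merged replicates.
ribo = [["SRR000001", "SRR000002"], "SRR000003"]  # as parsed from the YAML

ribo_ids = [r if isinstance(r, list) else [r] for r in ribo]
print(ribo_ids)  # [['SRR000001', 'SRR000002'], ['SRR000003']]
```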
@@ -47,20 +56,18 @@
 ########################################################
 ## Training times can be sped up by removing transcripts
 ## with few reads. This does not affect samples within
-## the validation/test set. Filtering is performed based
+## the test set. Filtering is performed based
 ## on the number of reads on a transcript.
-## format: riboseq/{id}/5/num_reads > "lambda x: x > {cutoff}"
 ########################################################
+## example: omit readless transcripts during training/validation
 #cond :
-#    riboseq/SRR000001/5/num_reads : "lambda x: x > 0"
-#    riboseq/SRR000002/5/num_reads : "lambda x: x > 0"
-#    riboseq/SRR000003/5/num_reads : "lambda x: x > 0"
+#    ribo:
+#        num_reads : x > 0
 #
-########################################################
-## Replicates can be merged where the number of mapped
-## reads are summed for multiple experiments.
-## expects list of merged sets based on rank in 'ribo' or
-## 'ribo_paths'
-########################################################
-#merged: [0, 1, 2]
-#
+## example: custom rules per data set
+#cond :
+#    ribo:
+#        num_reads :
+#            SRR000001 : "x > 10"
+#            SRR000002 : "x > 0"
+#            SRR000003 : "x > 0"
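The rule strings above are evaluated with `x` bound to the corresponding HDF5 column (see the `conds` construction in argparser.py below). A sketch of how such strings can become callable filters; the default-argument binding is a choice made here so each lambda keeps its own expression, not a claim about the package's code:

```python
import numpy as np

# Hypothetical sketch: turn per-dataset rule strings into callables, grouped
# per dataset as in the new conds structure.
ribo_ids = [["SRR000001"], ["SRR000002"], ["SRR000003"]]
cond = {"num_reads": {"SRR000001": "x > 10", "SRR000002": "x > 0"}}

grouped = [{} for _ in ribo_ids]
for key, per_id in cond.items():
    for sample_id, expr in per_id.items():
        grp_idx = [i for i, grp in enumerate(ribo_ids) if sample_id in grp][0]
        # bind expr per iteration so each lambda evaluates its own rule
        grouped[grp_idx][key] = lambda x, expr=expr: eval(expr)

print(grouped[0]["num_reads"](np.array([0, 5, 20])))  # [False False  True]
```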
146 changes: 79 additions & 67 deletions transcript_transformer/argparser.py
@@ -1,4 +1,3 @@
-import itertools
 import json
 import yaml
 import argparse
@@ -57,7 +56,7 @@ def add_comp_args(self):
         comp_parse.add_argument(
             "--max_memory",
             type=int,
-            default=24000,
+            default=30000,
             help="Value (GPU vRAM) used to bucket batches based on rough estimates. "
             "Reduce this setting if running out of memory",
         )
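The help text says batches are bucketed from a rough per-batch vRAM estimate capped by `--max_memory`. A hypothetical sketch of such greedy bucketing; the cost model is invented for illustration and is not the estimator used by the package:

```python
# Hypothetical greedy bucketing under a memory budget. The per-transcript
# cost (length * cost_per_position) is a stand-in for a real vRAM estimate.
def bucket_batches(seq_lens, max_memory, cost_per_position=1.0):
    batches, current, current_cost = [], [], 0.0
    for idx, length in enumerate(seq_lens):
        cost = length * cost_per_position
        if current and current_cost + cost > max_memory:
            batches.append(current)
            current, current_cost = [], 0.0
        current.append(idx)
        current_cost += cost
    if current:
        batches.append(current)
    return batches

print(bucket_batches([500, 800, 1200, 300], max_memory=2000))  # [[0, 1], [2, 3]]
```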
@@ -83,7 +82,7 @@ def add_architecture_args(self):
         tf_parse.add_argument(
             "--num_tokens",
             type=int,
-            default=5,
+            default=8,
             help="number of unique nucleotide input tokens (for sequence input)",
         )
         tf_parse.add_argument(
@@ -207,28 +206,33 @@ def add_train_loading_args(self, pretrain=False):
             type=str,
             help="Path to checkpoint pretrained model",
         )
-        dl_parse.add_argument(
-            "--train",
-            type=str,
-            nargs="*",
-            default=[],
-            help="chromosomes used for training. If not specified, "
-            "training is performed on all available chromosomes excluding val/test contigs",
-        )
-        dl_parse.add_argument(
-            "--val",
-            type=str,
-            nargs="*",
-            default=[],
-            help="chromosomes used for validation",
-        )
-        dl_parse.add_argument(
-            "--test",
-            type=str,
-            nargs="*",
-            default=[],
-            help="chromosomes used for testing",
-        )
+        dl_parse.add_argument(
+            "--train",
+            type=str,
+            nargs="*",
+            default=[],
+            help="chromosomes used for training. If not specified, "
+            "training is performed on all available chromosomes excluding val/test contigs",
+        )
+        dl_parse.add_argument(
+            "--val",
+            type=str,
+            nargs="*",
+            default=[],
+            help="chromosomes used for validation",
+        )
+        dl_parse.add_argument(
+            "--test",
+            type=str,
+            nargs="*",
+            default=[],
+            help="chromosomes used for testing",
+        )
+        dl_parse.add_argument(
+            "--strict_validation",
+            action="store_true",
+            help="does not apply custom loading filters (see 'cond') defined in config file to validation set",
+        )
         dl_parse.add_argument(
             "--leaky_frac",
             type=float,
@@ -254,11 +258,6 @@ def add_train_loading_args(self, pretrain=False):
             default=2000,
             help="maximum of transcripts per batch",
         )
-        dl_parse.add_argument(
-            "--ribo_offset",
-            action="store_true",
-            help="offset mapped ribosome reads by read length",
-        )
 
     def add_training_args(self):
         tr_parse = self.add_argument_group("Model training arguments")
@@ -293,7 +292,7 @@ def add_training_args(self):
         tr_parse.add_argument(
             "--patience",
             type=int,
-            default=8,
+            default=5,
             help="Number of epochs required without the validation loss reducing "
             "to stop training",
         )
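The `--patience` default controls early stopping: training halts once the validation loss has not improved for `patience` consecutive epochs. A generic sketch of that logic; the loop below is illustrative, and the package itself most likely delegates this to its training framework:

```python
# Generic early stopping: stop after `patience` epochs without a new best
# validation loss; returns the epoch at which training would stop.
def train_with_early_stopping(epoch_losses, patience=5):
    best, bad_epochs = float("inf"), 0
    for epoch, loss in enumerate(epoch_losses):
        if loss < best:
            best, bad_epochs = loss, 0
        else:
            bad_epochs += 1
        if bad_epochs >= patience:
            return epoch
    return len(epoch_losses) - 1

print(train_with_early_stopping([1.0, 0.8, 0.7, 0.71, 0.72, 0.73, 0.74, 0.75]))  # 7
```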
@@ -344,7 +343,7 @@ def add_custom_data_args(self):
     def add_preds_args(self):
         pr_parse = self.add_argument_group("Model prediction processing arguments")
         pr_parse.add_argument(
-            "--prob_th",
+            "--min_prob",
             type=float,
             default=0.01,
             help="minimum prediction threshold at which additional information is processed",
@@ -353,14 +352,7 @@
"--out_prefix",
type=str,
default="results",
help="path (prefix) of output files",
)
pr_parse.add_argument(
"--output_type",
type=str,
default="npy",
choices=["npy", "h5"],
help="file type of raw model predictions",
help="path (prefix) of output files, ignored if using config input file",
)

def add_misc_args(self):
Expand All @@ -378,6 +370,10 @@ def parse_config_file(args):
args.y_path = "tis"
args.seqn_path = "contig"
args.id_path = "id"
args.out_prefix = None
args.offsets = None
args.ribo_ids = []
args.cond = None

     # read dict and add to args
     with open(args.input_config, "r") as fh:
@@ -397,37 +393,53 @@ def parse_config_file(args):
     )
     cond_2 = ("ribo" in args) and (len(args.ribo) > 0)
     args.use_ribo = cond_1 or cond_2
-    args.ribo_shifts = {}
+
-    # ribo takes precedence over ribo_paths, can also includes read shifts
+    # ribo takes precedence over ribo_paths
     if args.use_ribo:
-        if cond_2:
-            if type(args.ribo) == dict:
-                args.ribo_ids = list(args.ribo.keys())
-                args.ribo_shifts = args.ribo
-            else:
-                args.ribo_ids = args.ribo
+        if "ribo" in args:
+            args.ribo_ids = [r if type(r) == list else [r] for r in args.ribo]
         else:
-            args.ribo_ids = list(args.ribo_paths.keys())
-    else:
-        args.ribo_ids = []
-
-    # conditions used to remove transcripts from training data
-    if "cond" in input_config.keys():
-        args.cond = {k: eval(v) for k, v in input_config["cond"].items()}
-    else:
-        args.cond = None
+            args.ribo_ids = [[r] for r in args.ribo_paths.keys()]
+        flat_ids = sum(args.ribo_ids, [])
+        assert len(np.unique(flat_ids)) == len(
+            flat_ids
+        ), "ribo_id is used multiple times"
 
-    # creates sets of merged (or solo) datasets based on merged input
-    if "merge" not in input_config.keys() or not args.use_ribo:
-        args.merge = []
-    args.merge_dict = {}
-    merg_mask = ~np.isin(
-        np.arange(len(args.ribo_ids)), list(itertools.chain(*args.merge))
+    # conditions used to remove transcripts from training/validation data
+    conds = {"global": {}, "grouped": [{} for l in range(len(args.ribo_ids))]}
+    conds["global"]["tr_len"] = lambda x: np.logical_and(
+        x > args.min_seq_len, x < args.max_seq_len
     )
-    for data_idx in np.where(merg_mask)[0]:
-        args.merge += [[data_idx]]
-    for i, set in enumerate(args.merge):
-        args.merge_dict[i] = set
+    if args.cond is not None:
+        if "ribo" in args.cond.keys() and args.use_ribo:
+            # key is overwritten if condition present for multiple group members
+            for key, item in args.cond["ribo"].items():
+                if type(item) == dict:
+                    # add condition to listed data sets
+                    for id, cond in item.items():
+                        grp_idx = [
+                            i for i, grp in enumerate(args.ribo_ids) if id in grp
+                        ]
+                        tmp_dict = {f"{key}": lambda x: eval(cond)}
+                        conds["grouped"][grp_idx[0]].update(tmp_dict)
+                else:
+                    # add condition to all groups
+                    for grp_idx, grp in enumerate(args.ribo_ids):
+                        tmp_dict = {f"{key}": lambda x: eval(item)}
+                        conds["grouped"][grp_idx].update(tmp_dict)
+            del args.cond["ribo"]
+        for key, item in args.cond.items():
+            if type(item) == dict:
+                # add condition to listed data sets
+                for id, cond in item.items():
+                    grp_idx = [i for i, grp in enumerate(args.ribo_ids) if id in grp]
+                    tmp_dict = {f"{key}": lambda x: eval(cond)}
+                    conds["grouped"][grp_idx[0]].update(tmp_dict)
+            else:
+                # add global condition
+                tmp_dict = {f"{key}": lambda x: eval(item)}
+                conds["global"].update(tmp_dict)
 
+    args.cond = conds
 
     return args
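For the template's example rules, the `conds` structure built above would resolve to something like the sketch below. The `min_seq_len`/`max_seq_len` values are placeholders; their defaults are defined elsewhere in the argparser:

```python
import numpy as np

# Sketch of the resulting conds structure for three datasets with the
# template.yml example rules; sequence-length bounds are placeholders.
min_seq_len, max_seq_len = 0, 30000

conds = {
    "global": {
        "tr_len": lambda x: np.logical_and(x > min_seq_len, x < max_seq_len),
    },
    "grouped": [
        {"num_reads": lambda x: x > 10},  # SRR000001
        {"num_reads": lambda x: x > 0},   # SRR000002
        {"num_reads": lambda x: x > 0},   # SRR000003
    ],
}

print(conds["global"]["tr_len"](np.array([0, 150, 40000])))  # [False  True False]
```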