Skip to content

Commit

Permalink
Update variant_statistics.py
Browse files Browse the repository at this point in the history
update def per_sample_datatype() to 1) return a float 2) include documentation 3) update some return logic
  • Loading branch information
matren395 committed Jan 17, 2024
1 parent e6a242f commit f1b383f
Showing 1 changed file with 20 additions and 7 deletions.
27 changes: 20 additions & 7 deletions gnomad_qc/v4/variant_qc/variant_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,15 @@ def per_sample_datatype(
data_type: str = "exomes",
test: bool = False,
overwrite: bool = False,
) -> None:
) -> float:
"""
Return a float of the mean number of variants called per sample for a chosen data type .
:param data_type: String of either "exomes" or "genomes" for the sample type.
:param test: Boolean for if you would like to use a small test set.
:param overwrite: Boolean to overwrite checkpoint files if requested.
"""

# Read in release data and metadata and VDS
ht_release_data = release_sites(data_type).ht()
if test:
Expand All @@ -54,25 +62,30 @@ def per_sample_datatype(

# Select data down to: only variant locus & alleles and only if samples do or don't have it
# Create then select down to GT
vds_data_filtered = vds_data_filtered.annotate_entries(GT=hl.vds.lgt_to_gt(vds_data_filtered.LGT, vds_data_filtered.LA))
vds_data_filtered = vds_data_filtered.annotate_entries(
GT=hl.vds.lgt_to_gt(vds_data_filtered.LGT, vds_data_filtered.LA)
)
vds_data_filtered = vds_data_filtered.select_entries("GT")
vds_data_filtered = vds_data_filtered.select_rows()
vds_data_filtered = (
vds_data_filtered.select_cols()
)
vds_data_filtered = vds_data_filtered.select_cols()

# Perform Hail's Sample QC module
sample_qc_ht = hl.sample_qc(vds_data_filtered).cols().key_by("s")

sample_qc_ht = sample_qc_ht.checkpoint(
f"gs://gnomad-tmp-4day/sample_qc_mt_per_{data_type}_sample.ht", overwrite=overwrite
f"gs://gnomad-tmp-4day/sample_qc_mt_per_{data_type}_sample.ht",
overwrite=overwrite,
)

# Column 'n_called' is a per-sample metric of the number of variants called
# Is what we want to report and return
mean_called_per = sample_qc_ht.aggregate(
hl.agg.mean(sample_qc_ht.sample_qc.n_called)
)

print("mean called per exome: ", mean_called_per)
logger.info(f"mean called per exome: {mean_called_per}")

return mean_called_per


def variant_types():
Expand Down

0 comments on commit f1b383f

Please sign in to comment.