Commit aed24bc

Merge pull request jeromekelleher#320 from jeromekelleher/dont-mask

Quick hack to turn off masking

jeromekelleher authored Oct 1, 2024
2 parents 3fa1e58 + 020fecc commit aed24bc
Showing 8 changed files with 349 additions and 362 deletions.
268 changes: 185 additions & 83 deletions notebooks/test_ts.ipynb

Large diffs are not rendered by default.

83 changes: 0 additions & 83 deletions sc2ts/alignments.py
@@ -1,51 +1,15 @@
 import logging
 import pathlib
-import dataclasses
-import collections.abc
-import hashlib
 import bz2
 
 import lmdb
-import numba
 import tqdm
 import numpy as np
 
 from . import core
 
 logger = logging.getLogger(__name__)
 
 GAP = core.ALLELES.index("-")
 MISSING = -1
 
 
-@numba.njit
-def mask_alignment(a, start=0, window_size=7):
-    """
-    Following the approach in fa2vcf, if any base has two or more ambiguous
-    or gap characters within distance window_size of it, mark it as missing data.
-    """
-    if window_size < 1:
-        raise ValueError("Window must be >= 1")
-    b = a.copy()
-    n = len(a)
-    masked_sites = []
-    for j in range(start, n):
-        ambiguous = 0
-        k = j - 1
-        while k >= start and k >= j - window_size:
-            if b[k] == GAP or b[k] == MISSING:
-                ambiguous += 1
-            k -= 1
-        k = j + 1
-        while k < n and k <= j + window_size:
-            if b[k] == GAP or b[k] == MISSING:
-                ambiguous += 1
-            k += 1
-        if ambiguous > 1:
-            a[j] = MISSING
-            masked_sites.append(j)
-    return masked_sites
 
 
 def encode_alignment(h):
     # Map anything that's not ACGT- to N
@@ -62,20 +26,6 @@ def decode_alignment(a):
     return alleles[a]
 
 
-def base_composition(haplotype, excluded_sites=None):
-    """
-    Haplotype includes an arbitrary character at the start.
-    Also, excluded site positions are 1-based.
-    """
-    if excluded_sites is not None:
-        mask = np.zeros(len(haplotype), dtype=bool)
-        mask[excluded_sites] = True
-        # Remove the first site from both haplotype and mask.
-        masked_haplotype = haplotype[1:][~mask[1:]]
-        return collections.Counter(masked_haplotype)
-    return collections.Counter(haplotype[1:])
 
 
 def compress_alignment(a):
     return bz2.compress(a.astype("S"))

@@ -147,36 +97,3 @@ def __iter__(self):
     def __len__(self):
         with self.env.begin() as txn:
             return txn.stat()["entries"]
-
-
-@dataclasses.dataclass
-class MaskedAlignment:
-    alignment: np.ndarray
-    masked_sites: np.ndarray
-    original_base_composition: dict
-    original_md5: str
-    masked_base_composition: str
-
-    def qc_summary(self):
-        return {
-            "num_masked_sites": self.masked_sites.shape[0],
-            "original_base_composition": self.original_base_composition,
-            "original_md5": self.original_md5,
-            "masked_base_composition": self.masked_base_composition,
-        }
-
-
-def encode_and_mask(alignment, window_size=7):
-    # TODO make window_size param
-    a = encode_alignment(alignment)
-    masked_sites = mask_alignment(a, start=1, window_size=window_size)
-    return MaskedAlignment(
-        alignment=a,
-        masked_sites=np.array(masked_sites, dtype=int),
-        original_base_composition=base_composition(haplotype=alignment),
-        original_md5=hashlib.md5(alignment[1:]).hexdigest(),
-        masked_base_composition=base_composition(
-            haplotype=alignment,
-            excluded_sites=masked_sites,
-        ),
-    )
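For reference, the sliding-window rule deleted above marks a base as missing when two or more gap or missing characters fall within window_size positions of it. Below is a pure-Python/NumPy sketch of that rule; mask_alignment_sketch and the GAP value are illustrative (GAP=4 assumes the ACGT- allele order used above), not part of this commit.

import numpy as np

GAP = 4      # assumed: index of "-" in the "ACGT-" allele encoding
MISSING = -1

def mask_alignment_sketch(a, start=0, window_size=7):
    b = a.copy()  # decisions are made against the unmodified input
    masked = []
    for j in range(start, len(a)):
        lo = max(start, j - window_size)
        hi = min(len(a), j + window_size + 1)
        window = np.concatenate([b[lo:j], b[j + 1:hi]])  # neighbours, excluding j
        if np.sum((window == GAP) | (window == MISSING)) > 1:
            a[j] = MISSING
            masked.append(j)
    return masked

a = np.array([0, 1, 4, 2, -1, 3, 0, 1], dtype=np.int8)
print(mask_alignment_sketch(a))  # sites with two or more noisy neighbours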
10 changes: 10 additions & 0 deletions sc2ts/cli.py
@@ -350,6 +350,14 @@ def summarise_base(ts, date, progress):
"is greater than this, randomly subsample."
),
)
@click.option(
"--max-missing-sites",
default=None,
type=int,
help=(
"The maximum number of missing sites in a sample to be accepted for inclusion"
),
)
@click.option(
"--random-seed",
default=42,
@@ -386,6 +394,7 @@ def extend(
     min_root_mutations,
     retrospective_window,
     max_daily_samples,
+    max_missing_sites,
     num_threads,
     random_seed,
     progress,
@@ -427,6 +436,7 @@ def extend(
         min_root_mutations=min_root_mutations,
         retrospective_window=retrospective_window,
         max_daily_samples=max_daily_samples,
+        max_missing_sites=max_missing_sites,
         random_seed=random_seed,
         num_threads=num_threads,
         show_progress=progress,
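The new option defaults to None, which extend() below turns into np.inf, i.e. no limit. A minimal sketch of the acceptance rule; accept() is an illustrative stand-in, not sc2ts code.

import numpy as np

def accept(num_missing_sites, max_missing_sites=None):
    # None means "no limit", mirroring the defaulting done in extend().
    if max_missing_sites is None:
        max_missing_sites = np.inf
    return num_missing_sites <= max_missing_sites

print(accept(1500))        # True: no limit set
print(accept(1500, 1000))  # False: the sample would be filtered out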
109 changes: 54 additions & 55 deletions sc2ts/inference.py
@@ -282,7 +282,7 @@ def initial_ts(additional_problematic_sites=list()):
     # 1-based coordinates
     for pos in range(1, L):
         if pos not in problematic_sites:
-            tables.sites.add_row(pos, reference[pos], metadata={"masked_samples": 0})
+            tables.sites.add_row(pos, reference[pos], metadata={"missing_samples": 0})
     # TODO should probably make the ultimate ancestor time something less
     # plausible or at least configurable. However, this will be removed
     # in later versions when we remove the dependence on tsinfer.
@@ -332,19 +332,19 @@ class Sample:
     date: str
     pango: str = "Unknown"
     metadata: Dict = dataclasses.field(default_factory=dict)
-    alignment_qc: Dict = dataclasses.field(default_factory=dict)
-    masked_sites: List = dataclasses.field(default_factory=list)
-    # FIXME need a better name for this, as it's a different thing from
-    # the original alignment. Haplotype is probably good, as it's
-    # what it would be in the tskit/tsinfer world.
-    alignment: List = None
+    alignment_composition: Dict = None
+    haplotype: List = None
     hmm_match: HmmMatch = None
     hmm_reruns: Dict = dataclasses.field(default_factory=dict)
 
     @property
     def is_recombinant(self):
         return len(self.hmm_match.path) > 1
 
+    @property
+    def num_missing_sites(self):
+        return int(np.sum(self.haplotype == -1))
+
     def summary(self):
         hmm_match = "No match" if self.hmm_match is None else self.hmm_match.summary()
         s = f"{self.strain} {self.date} {self.pango} {hmm_match}"
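A toy example of what the new num_missing_sites property computes (values are illustrative, not sc2ts data):

import numpy as np

haplotype = np.array([0, 2, -1, 3, -1, 1], dtype=np.int8)  # -1 marks missing sites
print(int(np.sum(haplotype == -1)))  # 2, what num_missing_sites would report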
@@ -471,16 +471,32 @@ def check_base_ts(ts):
     assert len(sc2ts_md["samples_strain"]) == ts.num_samples
 
 
+def make_sample(strain, date, pango, metadata, alignment):
+
+    sample = Sample(
+        strain,
+        date,
+        pango,
+        metadata,
+        haplotype=alignments.encode_alignment(alignment),
+        # Need to do this here because encoding gets rid of
+        # ambiguous bases etc.
+        alignment_composition=collections.Counter(alignment),
+    )
+
+    return sample
+
+
 def preprocess(
     samples_md,
     base_ts,
     date,
     alignment_store,
     pango_lineage_key="pango",
     show_progress=False,
+    max_missing_sites=np.inf,
 ):
     keep_sites = base_ts.sites_position.astype(int)
-    problematic_sites = core.get_problematic_sites()
 
     samples = []
     with get_progress(samples_md, date, "preprocess", show_progress) as bar:
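make_sample computes collections.Counter on the raw alignment before encoding, because encoding collapses ambiguity codes. A toy illustration; the encoder here is a simplified stand-in, not the sc2ts encode_alignment.

import collections
import numpy as np

def encode(chars, alleles="ACGT-"):
    # Anything outside the allele set becomes missing (-1) in this stand-in.
    return np.array(
        [alleles.index(c) if c in alleles else -1 for c in chars], dtype=np.int8
    )

raw = "ACGTRN-A"  # R and N are ambiguity/unknown codes
print(collections.Counter(raw))  # keeps separate R and N counts
print(encode(raw))               # [ 0  1  2  3 -1 -1  4  0]: R and N collapse to -1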
@@ -491,29 +507,17 @@
             except KeyError:
                 logger.debug(f"No alignment stored for {strain}")
                 continue
-            sample = Sample(
-                strain, date, md.get(pango_lineage_key, "PangoUnknown"), metadata=md
-            )
-            ma = alignments.encode_and_mask(alignment)
-            # Always mask the problematic_sites as well. We need to do this
-            # for follow-up matching to inspect recombinants, as tsinfer
-            # needs us to keep all sites in the table when doing mirrored
-            # coordinates.
-            ma.alignment[problematic_sites] = -1
-            sample.alignment_qc = ma.qc_summary()
-            sample.masked_sites = ma.masked_sites
-            sample.alignment = ma.alignment[keep_sites]
-            samples.append(sample)
-            num_Ns = ma.original_base_composition.get("N", 0)
-            non_nuc_counts = dict(ma.original_base_composition)
-            for nuc in "ACGT":
-                non_nuc_counts.pop(nuc, None)
-            counts = ",".join(
-                f"{key}={count}" for key, count in sorted(non_nuc_counts.items())
-            )
-            num_masked = len(ma.masked_sites)
-            logger.debug(f"Mask {strain}: masked={num_masked} {counts}")
+            pango = md.get(pango_lineage_key, "PangoUnknown")
+            # NOTE everything we store about the sample is **excluding** the problematic_sites
+            sample = make_sample(strain, date, pango, md, alignment[keep_sites])
+            num_missing_sites = sample.num_missing_sites
+            logger.debug(f"Encoded {strain} {pango} missing={num_missing_sites}")
+            if sample.num_missing_sites <= max_missing_sites:
+                samples.append(sample)
+            else:
+                logger.debug(
+                    f"Filter {strain}: missing={num_missing_sites} > {max_missing_sites}"
+                )
 
     return samples
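Note the behavioural change: problematic sites are no longer masked in place, they are excluded up front, since only alignment[keep_sites] ever reaches make_sample. A toy illustration with hypothetical positions:

import numpy as np

alignment = np.array(list("XACGTACGT"))  # index 0 is an arbitrary placeholder
keep_sites = np.array([1, 2, 3, 5, 6, 7, 8])  # hypothetical: site 4 is problematic
print("".join(alignment[keep_sites]))  # ACGACGT: site 4 never enters the sample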


@@ -531,6 +535,7 @@ def extend(
     max_daily_samples=None,
     show_progress=False,
     retrospective_window=None,
+    max_missing_sites=None,
     random_seed=42,
     num_threads=0,
 ):
@@ -544,6 +549,8 @@
         min_root_mutations = 2
     if retrospective_window is None:
         retrospective_window = 30
+    if max_missing_sites is None:
+        max_missing_sites = np.inf
 
     check_base_ts(base_ts)
     logger.info(
@@ -554,35 +561,25 @@
     metadata_matches = list(metadata_db.get(date))
 
     logger.info(f"Got {len(metadata_matches)} metadata matches")
-    # first check for samples that are in the alignment_store
-    samples_with_aligments = []
-    for md in metadata_matches:
-        if md["strain"] in alignment_store:
-            samples_with_aligments.append(md)
-
-    logger.info(f"Verified {len(samples_with_aligments)} have alignments")
     # metadata_matches = list(
     #     metadata_db.query("SELECT * FROM samples WHERE strain=='SRR19463295'")
     # )
-    if max_daily_samples is not None:
-        if max_daily_samples < len(samples_with_aligments):
-            seed_prefix = bytes(np.array([random_seed], dtype=int).data)
-            seed_suffix = hashlib.sha256(date.encode()).digest()
-            rng = random.Random(seed_prefix + seed_suffix)
-            samples_with_aligments = rng.sample(
-                samples_with_aligments, max_daily_samples
-            )
-        logger.info(f"Subset to {len(metadata_matches)} samples")
-
     samples = preprocess(
-        samples_with_aligments,
+        metadata_matches,
         base_ts,
         date,
         alignment_store,
         pango_lineage_key="Viridian_pangolin",  # TODO parametrise
         show_progress=show_progress,
+        max_missing_sites=max_missing_sites,
     )
 
+    if max_daily_samples is not None:
+        if max_daily_samples < len(samples):
+            seed_prefix = bytes(np.array([random_seed], dtype=int).data)
+            seed_suffix = hashlib.sha256(date.encode()).digest()
+            rng = random.Random(seed_prefix + seed_suffix)
+            samples = rng.sample(samples, max_daily_samples)
+            logger.info(f"Subset to {len(metadata_matches)} samples")
+
     if len(samples) == 0:
         logger.warning(f"Nothing to do for {date}")
         return base_ts
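Subsampling now runs on the preprocessed (and missingness-filtered) samples rather than on the raw metadata matches, and it stays reproducible per (random_seed, date). A self-contained sketch of the seeding scheme used in the added code:

import hashlib
import random
import numpy as np

def subsample(samples, max_daily_samples, date, random_seed=42):
    if max_daily_samples is not None and max_daily_samples < len(samples):
        # Seed the RNG from the global seed plus a hash of the date.
        seed_prefix = bytes(np.array([random_seed], dtype=int).data)
        seed_suffix = hashlib.sha256(date.encode()).digest()
        rng = random.Random(seed_prefix + seed_suffix)
        samples = rng.sample(samples, max_daily_samples)
    return samples

print(subsample(list(range(10)), 3, "2020-03-01"))
print(subsample(list(range(10)), 3, "2020-03-01"))  # same subset: reproducible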
@@ -657,7 +654,8 @@ def add_sample_to_tables(sample, tables, flags=tskit.NODE_IS_SAMPLE, group_id=None):
     sc2ts_md = {
         "hmm_match": sample.hmm_match.asdict(),
         "hmm_reruns": {k: m.asdict() for k, m in sample.hmm_reruns.items()},
-        "qc": sample.alignment_qc,
+        "alignment_composition": dict(sample.alignment_composition),
+        "num_missing_sites": sample.num_missing_sites,
     }
     if group_id is not None:
         sc2ts_md["group_id"] = group_id
@@ -793,7 +791,7 @@ def add_matching_results(

     # Group matches by path and set of immediate reversions.
     grouped_matches = collections.defaultdict(list)
-    site_masked_samples = np.zeros(int(ts.sequence_length), dtype=int)
+    site_missing_samples = np.zeros(int(ts.sequence_length), dtype=int)
     num_samples = 0
     for sample in match_db.get(where_clause):
         path = tuple(sample.hmm_match.path)
@@ -841,7 +839,8 @@
             continue
 
         for sample in group:
-            site_masked_samples[sample.masked_sites] += 1
+            missing_sites = np.where(sample.haplotype == -1)[0]
+            site_missing_samples[missing_sites] += 1
 
         flat_ts = match_path_ts(group)
         if flat_ts.num_mutations == 0 or flat_ts.num_samples == 1:
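With masked_sites gone, per-site missing-data counts are derived directly from the haplotypes. A small sketch of the accumulation done in this loop, with toy haplotypes:

import numpy as np

L = 10  # toy sequence length
site_missing_samples = np.zeros(L, dtype=int)
haplotypes = [
    np.array([0, -1, 2, 3, -1, 0, 1, 2, 3, 0], dtype=np.int8),
    np.array([0, -1, 2, 3, 0, 0, 1, -1, 3, 0], dtype=np.int8),
]
for h in haplotypes:
    missing_sites = np.where(h == -1)[0]
    site_missing_samples[missing_sites] += 1
print(site_missing_samples)  # [0 2 0 0 1 0 0 1 0 0]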
@@ -881,7 +880,7 @@
     tables.sites.clear()
     for site in ts.sites():
         md = site.metadata
-        md["masked_samples"] += int(site_masked_samples[int(site.position)])
+        md["missing_samples"] += int(site_missing_samples[int(site.position)])
         tables.sites.append(site.replace(metadata=md))
 
     # NOTE: Doing the parsimony heuristic updates really is complicated a lot
@@ -1074,7 +1073,7 @@ def match_tsinfer(
 ):
     if len(samples) == 0:
         return []
-    genotypes = np.array([sample.alignment for sample in samples], dtype=np.int8).T
+    genotypes = np.array([sample.haplotype for sample in samples], dtype=np.int8).T
     input_ts = ts
     if mirror_coordinates:
         ts = mirror_ts_coordinates(ts)
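The genotype matrix handed to the matching engine is just the per-sample haplotypes stacked and transposed, one row per site. A toy sketch:

import numpy as np

haplotypes = [
    np.array([0, 1, -1, 3], dtype=np.int8),
    np.array([0, 2, 2, 3], dtype=np.int8),
]
genotypes = np.array(haplotypes, dtype=np.int8).T
print(genotypes.shape)  # (4, 2): one row per site, one column per sample
print(genotypes[2])     # [-1  2]: sample 0 is missing at this site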