
Commit 010b1aa
Merge pull request jeromekelleher#309 from jeromekelleher/better-lineage-logging

Better lineage logging
jeromekelleher authored Sep 26, 2024
2 parents 28d4231 + 7f9848f commit 010b1aa
Showing 3 changed files with 67 additions and 42 deletions.
run.sh: 10 changes (5 additions & 5 deletions)

@@ -27,12 +27,12 @@ options+="--num-mismatches $mismatches"
 
 mkdir -p $resultsdir
 
-# date=2000-01-01
-# last_ts=$resultsdir/initial.ts
-# python3 -m sc2ts initialise $last_ts $matches
+date=2000-01-01
+last_ts=$resultsdir/initial.ts
+python3 -m sc2ts initialise $last_ts $matches
 
-date=2020-03-01
-last_ts="$results_prefix$date".ts
+# date=2020-03-01
+# last_ts="$results_prefix$date".ts
 
 dates=`python3 -m sc2ts list-dates --after $date $metadata | grep -v 2021-12-31`
 for date in $dates; do
sc2ts/cli.py: 38 changes (22 additions & 16 deletions)

@@ -16,7 +16,6 @@
 import tszip
 import tsinfer
 import click
-import daiquiri
 import humanize
 import pandas as pd
 
@@ -100,20 +99,28 @@ def setup_logging(verbosity, log_file=None):
     log_level = "INFO"
     if verbosity > 1:
         log_level = "DEBUG"
-    outputs = ["stderr"]
+    handler = logging.StreamHandler()
     if log_file is not None:
-        outputs = [daiquiri.output.File(log_file)]
-    # Note using set_excepthook=False means that we don't write errors
-    # to the log, so if something happens we'll only see it if we look
-    # at the console output. For development this is better than having
-    # to go to the log to see the traceback, but for production it may
-    # be better to let daiquiri record the errors as well.
-    daiquiri.setup(outputs=outputs, set_excepthook=False)
-    # Only show stuff coming from sc2ts and the relevant bits of tsinfer.
-    logger = logging.getLogger("sc2ts")
-    logger.setLevel(log_level)
-    logger = logging.getLogger("tsinfer.inference")
-    logger.setLevel(log_level)
+        handler = logging.FileHandler(log_file)
+    # default time format has millisecond precision which we don't need
+    time_format = "%Y-%m-%d %H:%M:%S"
+    fmt = logging.Formatter(
+        "%(asctime)s %(levelname)s %(name)s %(message)s", datefmt=time_format
+    )
+    handler.setFormatter(fmt)
+
+    # This is mainly used to output messages about major events. Possibly
+    # should do this with a separate logger entirely, rather than use
+    # the "WARNING" channel.
+    warn_handler = logging.StreamHandler()
+    warn_handler.setFormatter(logging.Formatter("%(levelname)s %(message)s"))
+    warn_handler.setLevel(logging.WARN)
+
+    for name in ["sc2ts", "tsinfer.inference"]:
+        logger = logging.getLogger(name)
+        logger.setLevel(log_level)
+        logger.addHandler(handler)
+        logger.addHandler(warn_handler)
 
 
 # TODO add options to list keys, dump specific alignments etc
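To make the new scheme concrete, here is a minimal standalone sketch (not part of the commit; the log file name is invented) of where records end up under this configuration: full detail goes to the file, while WARNING and above is echoed to the console.

import logging

detail = logging.FileHandler("sc2ts.log")  # hypothetical file name
detail.setFormatter(
    logging.Formatter(
        "%(asctime)s %(levelname)s %(name)s %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",  # second precision, as in the commit
    )
)
major = logging.StreamHandler()  # stderr by default
major.setFormatter(logging.Formatter("%(levelname)s %(message)s"))
major.setLevel(logging.WARN)

logger = logging.getLogger("sc2ts")
logger.setLevel(logging.DEBUG)
logger.addHandler(detail)
logger.addHandler(major)

logger.debug("recorded in sc2ts.log only")
logger.warning("recorded in sc2ts.log and echoed to stderr")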
@@ -325,7 +332,7 @@ def summarise_base(ts, date, progress):
     default=2,
     show_default=True,
     type=int,
-    help="Minimum number of shared mutations for reconsidered sample groups"
+    help="Minimum number of shared mutations for reconsidered sample groups",
 )
 @click.option(
     "--retrospective-window",
@@ -505,7 +512,6 @@ def tally_lineages(ts, metadata, verbose):
     df.to_csv(sys.stdout, sep="\t", index=False)
 
 
-
 def examine_recombinant(work):
     base_ts = tszip.load(work.ts_path)
     # NOTE: this is needed because we have to have all the sites in the trees
sc2ts/inference.py: 61 changes (40 additions & 21 deletions)
@@ -331,6 +331,7 @@ def increment_time(date, ts):
 class Sample:
     strain: str
     date: str
+    pango: str = "Unknown"
     metadata: Dict = dataclasses.field(default_factory=dict)
     alignment_qc: Dict = dataclasses.field(default_factory=dict)
     masked_sites: List = dataclasses.field(default_factory=list)
@@ -346,9 +347,8 @@ def is_recombinant(self):
         return len(self.hmm_match.path) > 1
 
     def summary(self):
-        pango = self.metadata.get("Viridian_pangolin", "Unknown")
         hmm_match = "No match" if self.hmm_match is None else self.hmm_match.summary()
-        s = f"{self.strain} {self.date} {pango} {hmm_match}"
+        s = f"{self.strain} {self.date} {self.pango} {hmm_match}"
         for name, hmm_match in self.hmm_reruns.items():
             s += f"; {name}: {hmm_match.summary()}"
         return s
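A quick sketch of the effect, assuming sc2ts is importable and the elided fields (hmm_match, hmm_reruns) keep their default empty values; strain and lineage names are invented:

from sc2ts.inference import Sample

sample = Sample(strain="ERR0000001", date="2020-03-01", pango="B.1.1.7")
print(sample.summary())
# ERR0000001 2020-03-01 B.1.1.7 No match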
@@ -472,7 +472,14 @@ def check_base_ts(ts):
     assert len(sc2ts_md["samples_strain"]) == ts.num_samples
 
 
-def preprocess(samples_md, base_ts, date, alignment_store, show_progress=False):
+def preprocess(
+    samples_md,
+    base_ts,
+    date,
+    alignment_store,
+    pango_lineage_key="pango",
+    show_progress=False,
+):
     keep_sites = base_ts.sites_position.astype(int)
     problematic_sites = core.get_problematic_sites()
 
@@ -485,7 +492,9 @@ def preprocess(samples_md, base_ts, date, alignment_store, show_progress=False):
         except KeyError:
             logger.debug(f"No alignment stored for {strain}")
             continue
-        sample = Sample(strain, date, metadata=md)
+        sample = Sample(
+            strain, date, md.get(pango_lineage_key, "PangoUnknown"), metadata=md
+        )
         ma = alignments.encode_and_mask(alignment)
         # Always mask the problematic_sites as well. We need to do this
         # for follow-up matching to inspect recombinants, as tsinfer
@@ -571,6 +580,7 @@ def extend(
         base_ts,
         date,
         alignment_store,
+        pango_lineage_key="Viridian_pangolin",  # TODO parametrise
         show_progress=show_progress,
     )
 
@@ -598,7 +608,7 @@
     ts = add_exact_matches(ts=ts, match_db=match_db, date=date)
 
     logger.info(f"Update ARG with low-cost samples for {date}")
-    ts = add_matching_results(
+    ts, _ = add_matching_results(
         f"match_date=='{date}' and hmm_cost>0 and hmm_cost<={hmm_cost_threshold}",
         ts=ts,
         match_db=match_db,
@@ -612,7 +622,7 @@
         logger.info("Looking for retrospective matches")
         assert min_group_size is not None
         earliest_date = parse_date(date) - datetime.timedelta(days=retrospective_window)
-        ts = add_matching_results(
+        ts, groups = add_matching_results(
             f"match_date<'{date}' AND match_date>'{earliest_date}'",
             ts=ts,
             match_db=match_db,
Expand All @@ -625,6 +635,8 @@ def extend(
show_progress=show_progress,
phase="retro",
)
for group in groups:
logger.warning(f"Add retro group {dict(group.pango_count)}")
return update_top_level_metadata(ts, date)
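Combined with the WARN-level console handler added in cli.py above, each retrospectively added group now surfaces as a console line along these lines (lineage names and counts invented for illustration):

WARNING Add retro group {'BA.2': 4, 'BA.2.9': 1}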


@@ -728,22 +740,25 @@ class SampleGroup:
     immediate_reversions: List = None
     additional_keys: Dict = None
     sample_hash: str = None
+    date_count: dict = dataclasses.field(default_factory=collections.Counter)
 
     def __post_init__(self):
         strains = []
         for s in self.samples:
+            self.date_count[s.date] += 1
             strains.append(s.strain)
         m = hashlib.md5()
-        for strain in sorted(strains):
+        for strain in sorted(self.strains):
             m.update(strain.encode())
         self.sample_hash = m.hexdigest()
 
     @property
     def strains(self):
         return [s.strain for s in self.samples]
 
-    @property
-    def date_count(self):
-        return collections.Counter([s.date for s in self.samples])
+    @property
+    def pango_count(self):
+        return collections.Counter([s.pango for s in self.samples])
 
     def __len__(self):
         return len(self.samples)
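Both counters are plain collections.Counter aggregations over the group's samples, e.g. (lineage values invented):

>>> import collections
>>> collections.Counter(["BA.2", "BA.2", "BA.5"])
Counter({'BA.2': 2, 'BA.5': 1})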

@@ -752,11 +767,12 @@ def __iter__(self):
 
     def summary(self):
         return (
-            f"Group {self.sample_hash} {len(self.samples)} samples "
+            f"{self.sample_hash} n={len(self.samples)} "
             f"{dict(self.date_count)} "
-            f"attaching at {path_summary(self.path)}, "
-            f"immediate_reversions={self.immediate_reversions}, "
-            f"additional_keys={self.additional_keys};"
+            f"{dict(self.pango_count)} "
+            f"immediate_reversions={self.immediate_reversions} "
+            f"additional_keys={self.additional_keys} "
+            f"path={path_summary(self.path)} "
             f"strains={self.strains}"
         )
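For orientation, the reworked summary now renders along these lines (hash, dates, lineages, and strains all invented; path summary elided):

0a1b2c3d... n=2 {'2020-03-01': 2} {'B.1': 2} immediate_reversions=None additional_keys=None path=<path summary> strains=['strainA', 'strainB']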

@@ -796,7 +812,7 @@ def add_matching_results(
 
     if num_samples == 0:
         logger.info("No candidate samples found in MatchDb")
-        return ts
+        return ts, []
 
     groups = [
         SampleGroup(
@@ -812,6 +828,7 @@
     tables = ts.dump_tables()
 
     attach_nodes = []
+    added_groups = []
     with get_progress(groups, date, f"add({phase})", show_progress) as bar:
         for group in bar:
             if (
@@ -852,12 +869,14 @@
             attach_depth = max(tree.depth(u) for u in poly_ts.samples())
             nodes = attach_tree(ts, tables, group, poly_ts, date, additional_node_flags)
             logger.debug(
-                f"Attach {phase} {group.summary()}; "
-                f"depth={attach_depth} total_muts{poly_ts.num_mutations} "
+                f"Attach {phase} "
+                f"depth={attach_depth} total_muts={poly_ts.num_mutations} "
                 f"root_muts={num_root_mutations} "
-                f"recurrent_muts={num_recurrent_mutations} attach_nodes={nodes}"
+                f"recurrent_muts={num_recurrent_mutations} attach_nodes={len(nodes)} "
+                f"group={group.summary()}"
             )
             attach_nodes.extend(nodes)
+            added_groups.append(group)
 
     # Update the sites with metadata for these newly added samples.
     tables.sites.clear()
@@ -880,7 +899,7 @@
     ts = push_up_reversions(ts, attach_nodes, date)
     ts = coalesce_mutations(ts, attach_nodes)
     ts = delete_immediate_reversion_nodes(ts, attach_nodes)
-    return ts
+    return ts, added_groups
 
 
 def solve_num_mismatches(k):
