Skip to content

Commit

Permalink
Merge pull request jeromekelleher#352 from jeromekelleher/deletion-ov…
Browse files Browse the repository at this point in the history
…erlaps

Deletion overlaps
  • Loading branch information
jeromekelleher authored Oct 9, 2024
2 parents bdf0a94 + d902938 commit e005e03
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 117 deletions.
226 changes: 124 additions & 102 deletions sc2ts/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -1302,41 +1302,59 @@ def recombinant_samples_report(self, nodes):
def _repr_html_(self):
return self.summary()._repr_html_()

# TODO fix these horrible tick labels by doing the histogram manually.
def _histogram(self, data, title, bins=None, xlabel=None, ylabel=None):
fig, ax = plt.subplots(1, 1)
fig.suptitle(title)
ax.hist(data, rwidth=0.9, bins=bins)
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
return fig, [ax]

def plot_mutations_per_node_distribution(self):
nodes_with_many_muts = np.sum(self.nodes_num_mutations >= 10)
plt.title(f"Nodes with >= 10 muts: {nodes_with_many_muts}")
plt.hist(self.nodes_num_mutations, range(10), rwidth=0.9)
plt.xlabel("Number of mutations")
plt.ylabel("Number of nodes")
return self._histogram(
self.nodes_num_mutations,
title=f"Nodes with >= 10 muts: {nodes_with_many_muts}",
bins=range(10),
xlabel="Number of mutations",
ylabel="Number of nodes",
)

def plot_missing_sites_per_sample(self):
plt.title("Missing sites per sample")
plt.hist(self.nodes_num_missing_sites[self.ts.samples()], rwidth=0.9)
return self._histogram(
self.nodes_num_missing_sites[self.ts.samples()],
title="Missing sites per sample",
)

def plot_deletion_sites_per_sample(self):
plt.title("Deletion sites per sample")
plt.hist(self.nodes_num_deletion_sites[self.ts.samples()], rwidth=0.9)
return self._histogram(
self.nodes_num_deletion_sites[self.ts.samples()],
title="Deletion sites per sample",
)

def plot_branch_length_distributions(
self, log_scale=True, min_value=1, exact_match=False, max_value=400
):
fig, ax = plt.subplots(1, 1)
ts = self.ts
branch_length = ts.nodes_time[ts.edges_parent] - ts.nodes_time[ts.edges_child]
select = branch_length >= min_value
if exact_match:
select &= ts.nodes_flags[ts.edges_child] & core.NODE_IS_EXACT_MATCH > 0
plt.hist(branch_length[select], range(min_value, max_value))
plt.xlabel("Length of branches")
ax.hist(branch_length[select], range(min_value, max_value))
ax.set_xlabel("Length of branches")
if log_scale:
plt.yscale("log")
ax.set_yscale("log")
return fig, [ax]

def plot_mutations_per_site_distribution(self):
fig, ax = plt.subplots(1, 1)
sites_with_many_muts = np.sum(self.sites_num_mutations >= 10)
plt.title(f"Sites with >= 10 muts: {sites_with_many_muts}")
plt.hist(self.sites_num_mutations, range(10), rwidth=0.9)
plt.xlabel("Number of mutations")
plt.ylabel("Number of site")
ax.set_title(f"Sites with >= 10 muts: {sites_with_many_muts}")
ax.hist(self.sites_num_mutations, range(10), rwidth=0.9)
ax.set_xlabel("Number of mutations")
ax.set_ylabel("Number of site")
return fig, [ax]

def plot_mutation_spectrum(self, min_inheritors=1):
counter = self.get_mutation_spectrum(min_inheritors)
Expand All @@ -1359,10 +1377,11 @@ def plot_mutation_spectrum(self, min_inheritors=1):
step = y / 10
for key in ["C>T", "G>T"]:
rev_key = key[::-1]
ratio = counter[key] / counter[rev_key]
ratio = counter[key] / max(1, counter[rev_key]) # avoid division by zero
text = f"{key} / {rev_key}={ratio:.2f}"
y -= step
ax.text(4, y, text)
return fig, [ax]

def get_mutation_spectrum(self, min_inheritors=1):
keep = self.mutations_num_inheritors >= min_inheritors
Expand Down Expand Up @@ -1390,28 +1409,17 @@ def _add_genes_to_axis(self, ax):
ax2.set_xticks(mids, minor=False)
ax2.set_xticklabels(list(genes.keys()), rotation="vertical")

def plot_diversity(self, xlim=None):
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(16, 8))
site_div = self.ts.diversity(windows="sites", mode="site")
branch_div = self.ts.diversity(windows="sites", mode="branch")
ax1.plot(self.ts.sites_position, site_div)
ax2.plot(self.ts.sites_position, branch_div)
ax2.set_xlabel("Genome position")
ax1.set_ylabel("Site diversity")
ax2.set_ylabel("Branch diversity")
for ax in [ax1, ax2]:
self._add_genes_to_axis(ax)
if xlim is not None:
ax.set_xlim(xlim)

def plot_ts_tv_per_site(self, annotate_threshold=0.9, xlim=None):
def _wide_plot(self, *args, **kwargs):
return plt.subplots(*args, figsize=(16, 4), **kwargs)

def plot_ts_tv_per_site(self, annotate_threshold=0.9):
nonzero = self.sites_num_transversions != 0
ratio = (
self.sites_num_transitions[nonzero] / self.sites_num_transversions[nonzero]
)
pos = self.ts.sites_position[nonzero]

fig, ax = plt.subplots(1, 1, figsize=(16, 4))
fig, ax = self._wide_plot(1, 1)
ax.plot(pos, ratio)
self._add_genes_to_axis(ax)

Expand All @@ -1421,10 +1429,29 @@ def plot_ts_tv_per_site(self, annotate_threshold=0.9, xlim=None):
plt.annotate(
f"{int(pos[site])}", xy=(pos[site], ratio[site]), xycoords="data"
)
plt.ylabel("Ts/Tv")
plt.xlabel("Position on genome")
if xlim is not None:
plt.xlim(xlim)
ax.set_ylabel("Ts/Tv")
ax.set_xlabel("Position on genome")
return fig, [ax]

def _plot_per_site_count(self, count, annotate_threshold):
fig, ax = self._wide_plot(1, 1)
pos = self.ts.sites_position
ax.plot(pos, count)
self._add_genes_to_axis(ax)
threshold = np.max(count) * annotate_threshold

# Show runs of sites exceeding threshold
for v, start, length in zip(*find_runs(count > threshold)):
if v:
end = start + length
x, y = int(pos[start]), int(pos[min(self.ts.num_sites - 1, end)])
if x == y - 1:
label = f"{x}"
else:
label = f"{x}-{y}"
plt.annotate(label, xy=(x, count[start]), xycoords="data")
ax.set_xlabel("Position on genome")
return fig, ax

def plot_mutations_per_site(self, annotate_threshold=0.9, select=None):
if select is None:
Expand All @@ -1433,71 +1460,57 @@ def plot_mutations_per_site(self, annotate_threshold=0.9, select=None):
count = np.bincount(
self.ts.mutations_site[select], minlength=self.ts.num_sites
)

pos = self.ts.sites_position
fig, ax = self._plot_per_site_count(count, annotate_threshold)
zero_fraction = np.sum(count == 0) / self.ts.num_sites

fig, ax = plt.subplots(1, 1, figsize=(16, 4))
ax.plot(pos, count)
self._add_genes_to_axis(ax)
plt.annotate(
ax.annotate(
f"{zero_fraction * 100:.2f}% sites have 0 mutations",
xy=(pos[0], np.max(count)),
xy=(self.ts.sites_position[0], np.max(count)),
xycoords="data",
)
threshold = np.max(count) * annotate_threshold
top_sites = np.where(count > threshold)[0]
for site in top_sites:
plt.annotate(
f"{int(pos[site])}", xy=(pos[site], count[site]), xycoords="data"
)
plt.ylabel("Number of mutations")
plt.xlabel("Position on genome")
ax.set_ylabel("Number of mutations")
return fig, [ax]

def plot_missing_samples_per_site(self, annotate_threshold=0.5):
fig, ax = plt.subplots(1, 1, figsize=(16, 4))
self._add_genes_to_axis(ax)
count = self.sites_num_missing_samples
pos = self.ts.sites_position
ax.plot(pos, count)
threshold = np.max(count) * annotate_threshold
# Show runs of sites exceeding threshold
for v, start, length in zip(*find_runs(count > threshold)):
if v:
end = start + length
x, y = int(pos[start]), int(pos[min(self.ts.num_sites - 1, end)])
plt.annotate(f"{x}-{y}", xy=(x, count[start]), xycoords="data")

plt.ylabel("Number missing samples")
plt.xlabel("Position on genome")
fig, ax = self._plot_per_site_count(
self.sites_num_missing_samples, annotate_threshold
)
ax.set_ylabel("Number missing samples")
return fig, [ax]

def plot_deletion_samples_per_site(self, annotate_threshold=0.5):
fig, ax = plt.subplots(1, 1, figsize=(16, 4))
self._add_genes_to_axis(ax)
count = self.sites_num_deletion_samples
pos = self.ts.sites_position
ax.plot(pos, count)
threshold = np.max(count) * annotate_threshold
# Show runs of sites exceeding threshold
for v, start, length in zip(*find_runs(count > threshold)):
if v:
end = start + length
x, y = int(pos[start]), int(pos[min(self.ts.num_sites - 1, end)])
plt.annotate(f"{x}-{y}", xy=(x, count[start]), xycoords="data")
fig, ax = self._plot_per_site_count(
self.sites_num_deletion_samples, annotate_threshold
)
ax.set_ylabel("Number deletion samples")
return fig, [ax]

plt.ylabel("Number samples with deletion")
plt.xlabel("Position on genome")
def compute_deletion_overlaps(self, df_del):
ts = self.ts
overlaps = np.zeros(int(ts.sequence_length))
df_del = self.deletions_summary()
for row in df_del.itertuples():
overlaps[row.start : row.start + row.length] += 1
return overlaps[ts.sites_position.astype(int)]

def plot_deletion_overlaps(self, annotate_threshold=0.9):
df_del = self.deletions_summary()
fig, ax = self._plot_per_site_count(
self.compute_deletion_overlaps(df_del), annotate_threshold
)
ax.set_ylabel("Overlapping deletions")
return fig, [ax]

def plot_samples_per_day(self):
plt.figure(figsize=(16, 4))
fig, ax = self._wide_plot(1, 1)
t = np.arange(self.num_samples_per_day.shape[0])
plt.plot(self.time_zero_as_date - t, self.num_samples_per_day)
plt.xlabel("Date")
plt.ylabel("Number of samples")
ax.plot(self.time_zero_as_date - t, self.num_samples_per_day)
ax.set_xlabel("Date")
ax.set_ylabel("Number of samples")
return fig, [ax]

def plot_resources(self, start_date="2020-04-01"):
ts = self.ts
fig, ax = plt.subplots(2, sharex=True, figsize=(16, 8))
fig, ax = self._wide_plot(2, sharex=True)
timestamp = np.zeros(ts.num_provenances, dtype="datetime64[s]")
date = np.zeros(ts.num_provenances, dtype="datetime64[D]")
num_samples = np.zeros(ts.num_provenances, dtype=int)
Expand All @@ -1520,9 +1533,9 @@ def plot_resources(self, start_date="2020-04-01"):
ax[0].set_ylabel("Elapsed time (mins)")
ax[1].plot(date[keep], wall_time[keep] / num_samples[keep])
ax[1].set_ylabel("Elapsed time per sample (s)")
return ax
return fig, ax

def plot_recombinants_per_day(self):
def fixme_plot_recombinants_per_day(self):
counter = collections.Counter()
for u in self.recombinants:
date = np.datetime64(self.nodes_metadata[u]["date_added"])
Expand All @@ -1543,7 +1556,7 @@ def plot_recombinants_per_day(self):
ax2.set_ylabel("Fraction of samples recombinant")
ax2.set_ylim(0, 0.01)

def plot_pango_lineage_subtree(
def draw_pango_lineage_subtree(
self,
lineage,
position=None,
Expand All @@ -1556,7 +1569,7 @@ def plot_pango_lineage_subtree(
mutation_labels=None,
size=None,
style="",
**kwargs
**kwargs,
):
if position is None:
position = 21563 # pick the start of the spike
Expand All @@ -1573,19 +1586,26 @@ def plot_pango_lineage_subtree(
ts = tables.tree_sequence()
tracked_nodes = self.pango_lineage_samples[lineage]
tree = ts.at(position, tracked_samples=tracked_nodes)
order = np.array(list(tskit.drawing._postorder_tracked_minlex_traversal(
tree, collapse_tracked=collapse_tracked)))
order = np.array(
list(
tskit.drawing._postorder_tracked_minlex_traversal(
tree, collapse_tracked=collapse_tracked
)
)
)
if title is None:
simplified_ts = ts.simplify(order[np.where(ts.nodes_flags[order] & tskit.NODE_IS_SAMPLE)[0]])
simplified_ts = ts.simplify(
order[np.where(ts.nodes_flags[order] & tskit.NODE_IS_SAMPLE)[0]]
)
num_trees = simplified_ts.num_trees
tree_pos = simplified_ts.at(position).index
title = (
f"Sc2ts genealogy of {len(tracked_nodes)} {lineage} samples "
f"at position {position} (tree {tree_pos}/{num_trees})"
# f" --- file: " # TODO - show filename
)
# Find the actually shown nodes (i.e. if polytomies are packed, we may not

# Find the actually shown nodes (i.e. if polytomies are packed, we may not
# see some tips. This is copied from tskit.drawing.SvgTree.assign_x_coordinates
shown_nodes = order
if pack_untracked_polytomies:
Expand All @@ -1602,9 +1622,9 @@ def plot_pango_lineage_subtree(
shown_nodes.append(u)
else:
if len(untracked_children[u]) == 1:
# If only a single non-focal lineage, we might as well show it
for child in untracked_children[u]:
shown_nodes.append(child)
# If only a single non-focal lineage, we might as well show it
for child in untracked_children[u]:
shown_nodes.append(child)
shown_nodes.append(u)
prev = u

Expand All @@ -1626,7 +1646,9 @@ def plot_pango_lineage_subtree(
inherited_state = parent.derived_state
parent_inherited_state = site.ancestral_state
if parent.parent >= 0:
parent_inherited_state = ts.mutation(parent.parent).derived_state
parent_inherited_state = ts.mutation(
parent.parent
).derived_state
if parent_inherited_state == mut.derived_state:
reverted_mutations.append(mut.id)
# Reverse map label name to mutation id, so we can count duplicates
Expand All @@ -1641,11 +1663,11 @@ def plot_pango_lineage_subtree(
# some default styles
styles = [
"".join(f".n{u} > .sym {{fill: cyan}}" for u in tracked_nodes),
".lab.summary {font-size: 12px}",
".lab.summary {font-size: 12px}",
".polytomy {font-size: 10px}",
".mut .lab {font-size: 10px}",
".y-axis .lab {font-size: 12px}",
".mut .lab {fill: darkred} .mut .sym {stroke: darkred} .background path {fill: white}"
".mut .lab {fill: darkred} .mut .sym {stroke: darkred} .background path {fill: white}",
]
if len(multiple_mutations) > 0:
lab_css = ", ".join(f".mut.m{m} .lab" for m in multiple_mutations)
Expand Down
11 changes: 0 additions & 11 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,17 +224,6 @@ def test_high_recomb_mutation(self):
self.check_double_mirror(ts)


# TODO move this to another file and test a bunch of stuff using the
# TI as a fixture
class TestTreeInfo:
def test_tree_info_values(self, fx_ts_map):
ts = fx_ts_map["2020-02-13"]
ti = sc2ts.TreeInfo(ts, show_progress=False)
assert list(ti.nodes_num_missing_sites[:5]) == [0, 0, 0, 560, 535]
assert list(ti.sites_num_missing_samples[:5]) == [4, 4, 4, 4, 4]
assert list(ti.sites_num_deletion_samples[:5]) == [0, 0, 0, 0, 0]


class TestRealData:
dates = [
"2020-01-01",
Expand Down
Loading

0 comments on commit e005e03

Please sign in to comment.