diff --git a/sc2ts/info.py b/sc2ts/info.py index 794e41a..0a9f25a 100644 --- a/sc2ts/info.py +++ b/sc2ts/info.py @@ -1302,41 +1302,59 @@ def recombinant_samples_report(self, nodes): def _repr_html_(self): return self.summary()._repr_html_() - # TODO fix these horrible tick labels by doing the histogram manually. + def _histogram(self, data, title, bins=None, xlabel=None, ylabel=None): + fig, ax = plt.subplots(1, 1) + fig.suptitle(title) + ax.hist(data, rwidth=0.9, bins=bins) + ax.set_xlabel(xlabel) + ax.set_ylabel(ylabel) + return fig, [ax] + def plot_mutations_per_node_distribution(self): nodes_with_many_muts = np.sum(self.nodes_num_mutations >= 10) - plt.title(f"Nodes with >= 10 muts: {nodes_with_many_muts}") - plt.hist(self.nodes_num_mutations, range(10), rwidth=0.9) - plt.xlabel("Number of mutations") - plt.ylabel("Number of nodes") + return self._histogram( + self.nodes_num_mutations, + title=f"Nodes with >= 10 muts: {nodes_with_many_muts}", + bins=range(10), + xlabel="Number of mutations", + ylabel="Number of nodes", + ) def plot_missing_sites_per_sample(self): - plt.title("Missing sites per sample") - plt.hist(self.nodes_num_missing_sites[self.ts.samples()], rwidth=0.9) + return self._histogram( + self.nodes_num_missing_sites[self.ts.samples()], + title="Missing sites per sample", + ) def plot_deletion_sites_per_sample(self): - plt.title("Deletion sites per sample") - plt.hist(self.nodes_num_deletion_sites[self.ts.samples()], rwidth=0.9) + return self._histogram( + self.nodes_num_deletion_sites[self.ts.samples()], + title="Deletion sites per sample", + ) def plot_branch_length_distributions( self, log_scale=True, min_value=1, exact_match=False, max_value=400 ): + fig, ax = plt.subplots(1, 1) ts = self.ts branch_length = ts.nodes_time[ts.edges_parent] - ts.nodes_time[ts.edges_child] select = branch_length >= min_value if exact_match: select &= ts.nodes_flags[ts.edges_child] & core.NODE_IS_EXACT_MATCH > 0 - plt.hist(branch_length[select], range(min_value, max_value)) - plt.xlabel("Length of branches") + ax.hist(branch_length[select], range(min_value, max_value)) + ax.set_xlabel("Length of branches") if log_scale: - plt.yscale("log") + ax.set_yscale("log") + return fig, [ax] def plot_mutations_per_site_distribution(self): + fig, ax = plt.subplots(1, 1) sites_with_many_muts = np.sum(self.sites_num_mutations >= 10) - plt.title(f"Sites with >= 10 muts: {sites_with_many_muts}") - plt.hist(self.sites_num_mutations, range(10), rwidth=0.9) - plt.xlabel("Number of mutations") - plt.ylabel("Number of site") + ax.set_title(f"Sites with >= 10 muts: {sites_with_many_muts}") + ax.hist(self.sites_num_mutations, range(10), rwidth=0.9) + ax.set_xlabel("Number of mutations") + ax.set_ylabel("Number of site") + return fig, [ax] def plot_mutation_spectrum(self, min_inheritors=1): counter = self.get_mutation_spectrum(min_inheritors) @@ -1359,10 +1377,11 @@ def plot_mutation_spectrum(self, min_inheritors=1): step = y / 10 for key in ["C>T", "G>T"]: rev_key = key[::-1] - ratio = counter[key] / counter[rev_key] + ratio = counter[key] / max(1, counter[rev_key]) # avoid division by zero text = f"{key} / {rev_key}={ratio:.2f}" y -= step ax.text(4, y, text) + return fig, [ax] def get_mutation_spectrum(self, min_inheritors=1): keep = self.mutations_num_inheritors >= min_inheritors @@ -1390,28 +1409,17 @@ def _add_genes_to_axis(self, ax): ax2.set_xticks(mids, minor=False) ax2.set_xticklabels(list(genes.keys()), rotation="vertical") - def plot_diversity(self, xlim=None): - fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(16, 8)) - site_div = self.ts.diversity(windows="sites", mode="site") - branch_div = self.ts.diversity(windows="sites", mode="branch") - ax1.plot(self.ts.sites_position, site_div) - ax2.plot(self.ts.sites_position, branch_div) - ax2.set_xlabel("Genome position") - ax1.set_ylabel("Site diversity") - ax2.set_ylabel("Branch diversity") - for ax in [ax1, ax2]: - self._add_genes_to_axis(ax) - if xlim is not None: - ax.set_xlim(xlim) - - def plot_ts_tv_per_site(self, annotate_threshold=0.9, xlim=None): + def _wide_plot(self, *args, **kwargs): + return plt.subplots(*args, figsize=(16, 4), **kwargs) + + def plot_ts_tv_per_site(self, annotate_threshold=0.9): nonzero = self.sites_num_transversions != 0 ratio = ( self.sites_num_transitions[nonzero] / self.sites_num_transversions[nonzero] ) pos = self.ts.sites_position[nonzero] - fig, ax = plt.subplots(1, 1, figsize=(16, 4)) + fig, ax = self._wide_plot(1, 1) ax.plot(pos, ratio) self._add_genes_to_axis(ax) @@ -1421,10 +1429,29 @@ def plot_ts_tv_per_site(self, annotate_threshold=0.9, xlim=None): plt.annotate( f"{int(pos[site])}", xy=(pos[site], ratio[site]), xycoords="data" ) - plt.ylabel("Ts/Tv") - plt.xlabel("Position on genome") - if xlim is not None: - plt.xlim(xlim) + ax.set_ylabel("Ts/Tv") + ax.set_xlabel("Position on genome") + return fig, [ax] + + def _plot_per_site_count(self, count, annotate_threshold): + fig, ax = self._wide_plot(1, 1) + pos = self.ts.sites_position + ax.plot(pos, count) + self._add_genes_to_axis(ax) + threshold = np.max(count) * annotate_threshold + + # Show runs of sites exceeding threshold + for v, start, length in zip(*find_runs(count > threshold)): + if v: + end = start + length + x, y = int(pos[start]), int(pos[min(self.ts.num_sites - 1, end)]) + if x == y - 1: + label = f"{x}" + else: + label = f"{x}-{y}" + plt.annotate(label, xy=(x, count[start]), xycoords="data") + ax.set_xlabel("Position on genome") + return fig, ax def plot_mutations_per_site(self, annotate_threshold=0.9, select=None): if select is None: @@ -1433,71 +1460,57 @@ def plot_mutations_per_site(self, annotate_threshold=0.9, select=None): count = np.bincount( self.ts.mutations_site[select], minlength=self.ts.num_sites ) - - pos = self.ts.sites_position + fig, ax = self._plot_per_site_count(count, annotate_threshold) zero_fraction = np.sum(count == 0) / self.ts.num_sites - - fig, ax = plt.subplots(1, 1, figsize=(16, 4)) - ax.plot(pos, count) - self._add_genes_to_axis(ax) - plt.annotate( + ax.annotate( f"{zero_fraction * 100:.2f}% sites have 0 mutations", - xy=(pos[0], np.max(count)), + xy=(self.ts.sites_position[0], np.max(count)), xycoords="data", ) - threshold = np.max(count) * annotate_threshold - top_sites = np.where(count > threshold)[0] - for site in top_sites: - plt.annotate( - f"{int(pos[site])}", xy=(pos[site], count[site]), xycoords="data" - ) - plt.ylabel("Number of mutations") - plt.xlabel("Position on genome") + ax.set_ylabel("Number of mutations") + return fig, [ax] def plot_missing_samples_per_site(self, annotate_threshold=0.5): - fig, ax = plt.subplots(1, 1, figsize=(16, 4)) - self._add_genes_to_axis(ax) - count = self.sites_num_missing_samples - pos = self.ts.sites_position - ax.plot(pos, count) - threshold = np.max(count) * annotate_threshold - # Show runs of sites exceeding threshold - for v, start, length in zip(*find_runs(count > threshold)): - if v: - end = start + length - x, y = int(pos[start]), int(pos[min(self.ts.num_sites - 1, end)]) - plt.annotate(f"{x}-{y}", xy=(x, count[start]), xycoords="data") - - plt.ylabel("Number missing samples") - plt.xlabel("Position on genome") + fig, ax = self._plot_per_site_count( + self.sites_num_missing_samples, annotate_threshold + ) + ax.set_ylabel("Number missing samples") + return fig, [ax] def plot_deletion_samples_per_site(self, annotate_threshold=0.5): - fig, ax = plt.subplots(1, 1, figsize=(16, 4)) - self._add_genes_to_axis(ax) - count = self.sites_num_deletion_samples - pos = self.ts.sites_position - ax.plot(pos, count) - threshold = np.max(count) * annotate_threshold - # Show runs of sites exceeding threshold - for v, start, length in zip(*find_runs(count > threshold)): - if v: - end = start + length - x, y = int(pos[start]), int(pos[min(self.ts.num_sites - 1, end)]) - plt.annotate(f"{x}-{y}", xy=(x, count[start]), xycoords="data") + fig, ax = self._plot_per_site_count( + self.sites_num_deletion_samples, annotate_threshold + ) + ax.set_ylabel("Number deletion samples") + return fig, [ax] - plt.ylabel("Number samples with deletion") - plt.xlabel("Position on genome") + def compute_deletion_overlaps(self, df_del): + ts = self.ts + overlaps = np.zeros(int(ts.sequence_length)) + df_del = self.deletions_summary() + for row in df_del.itertuples(): + overlaps[row.start : row.start + row.length] += 1 + return overlaps[ts.sites_position.astype(int)] + + def plot_deletion_overlaps(self, annotate_threshold=0.9): + df_del = self.deletions_summary() + fig, ax = self._plot_per_site_count( + self.compute_deletion_overlaps(df_del), annotate_threshold + ) + ax.set_ylabel("Overlapping deletions") + return fig, [ax] def plot_samples_per_day(self): - plt.figure(figsize=(16, 4)) + fig, ax = self._wide_plot(1, 1) t = np.arange(self.num_samples_per_day.shape[0]) - plt.plot(self.time_zero_as_date - t, self.num_samples_per_day) - plt.xlabel("Date") - plt.ylabel("Number of samples") + ax.plot(self.time_zero_as_date - t, self.num_samples_per_day) + ax.set_xlabel("Date") + ax.set_ylabel("Number of samples") + return fig, [ax] def plot_resources(self, start_date="2020-04-01"): ts = self.ts - fig, ax = plt.subplots(2, sharex=True, figsize=(16, 8)) + fig, ax = self._wide_plot(2, sharex=True) timestamp = np.zeros(ts.num_provenances, dtype="datetime64[s]") date = np.zeros(ts.num_provenances, dtype="datetime64[D]") num_samples = np.zeros(ts.num_provenances, dtype=int) @@ -1520,9 +1533,9 @@ def plot_resources(self, start_date="2020-04-01"): ax[0].set_ylabel("Elapsed time (mins)") ax[1].plot(date[keep], wall_time[keep] / num_samples[keep]) ax[1].set_ylabel("Elapsed time per sample (s)") - return ax + return fig, ax - def plot_recombinants_per_day(self): + def fixme_plot_recombinants_per_day(self): counter = collections.Counter() for u in self.recombinants: date = np.datetime64(self.nodes_metadata[u]["date_added"]) @@ -1543,7 +1556,7 @@ def plot_recombinants_per_day(self): ax2.set_ylabel("Fraction of samples recombinant") ax2.set_ylim(0, 0.01) - def plot_pango_lineage_subtree( + def draw_pango_lineage_subtree( self, lineage, position=None, @@ -1556,7 +1569,7 @@ def plot_pango_lineage_subtree( mutation_labels=None, size=None, style="", - **kwargs + **kwargs, ): if position is None: position = 21563 # pick the start of the spike @@ -1573,10 +1586,17 @@ def plot_pango_lineage_subtree( ts = tables.tree_sequence() tracked_nodes = self.pango_lineage_samples[lineage] tree = ts.at(position, tracked_samples=tracked_nodes) - order = np.array(list(tskit.drawing._postorder_tracked_minlex_traversal( - tree, collapse_tracked=collapse_tracked))) + order = np.array( + list( + tskit.drawing._postorder_tracked_minlex_traversal( + tree, collapse_tracked=collapse_tracked + ) + ) + ) if title is None: - simplified_ts = ts.simplify(order[np.where(ts.nodes_flags[order] & tskit.NODE_IS_SAMPLE)[0]]) + simplified_ts = ts.simplify( + order[np.where(ts.nodes_flags[order] & tskit.NODE_IS_SAMPLE)[0]] + ) num_trees = simplified_ts.num_trees tree_pos = simplified_ts.at(position).index title = ( @@ -1584,8 +1604,8 @@ def plot_pango_lineage_subtree( f"at position {position} (tree {tree_pos}/{num_trees})" # f" --- file: " # TODO - show filename ) - - # Find the actually shown nodes (i.e. if polytomies are packed, we may not + + # Find the actually shown nodes (i.e. if polytomies are packed, we may not # see some tips. This is copied from tskit.drawing.SvgTree.assign_x_coordinates shown_nodes = order if pack_untracked_polytomies: @@ -1602,9 +1622,9 @@ def plot_pango_lineage_subtree( shown_nodes.append(u) else: if len(untracked_children[u]) == 1: - # If only a single non-focal lineage, we might as well show it - for child in untracked_children[u]: - shown_nodes.append(child) + # If only a single non-focal lineage, we might as well show it + for child in untracked_children[u]: + shown_nodes.append(child) shown_nodes.append(u) prev = u @@ -1626,7 +1646,9 @@ def plot_pango_lineage_subtree( inherited_state = parent.derived_state parent_inherited_state = site.ancestral_state if parent.parent >= 0: - parent_inherited_state = ts.mutation(parent.parent).derived_state + parent_inherited_state = ts.mutation( + parent.parent + ).derived_state if parent_inherited_state == mut.derived_state: reverted_mutations.append(mut.id) # Reverse map label name to mutation id, so we can count duplicates @@ -1641,11 +1663,11 @@ def plot_pango_lineage_subtree( # some default styles styles = [ "".join(f".n{u} > .sym {{fill: cyan}}" for u in tracked_nodes), - ".lab.summary {font-size: 12px}", + ".lab.summary {font-size: 12px}", ".polytomy {font-size: 10px}", ".mut .lab {font-size: 10px}", ".y-axis .lab {font-size: 12px}", - ".mut .lab {fill: darkred} .mut .sym {stroke: darkred} .background path {fill: white}" + ".mut .lab {fill: darkred} .mut .sym {stroke: darkred} .background path {fill: white}", ] if len(multiple_mutations) > 0: lab_css = ", ".join(f".mut.m{m} .lab" for m in multiple_mutations) diff --git a/tests/test_inference.py b/tests/test_inference.py index 73229af..5b54f02 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -224,17 +224,6 @@ def test_high_recomb_mutation(self): self.check_double_mirror(ts) -# TODO move this to another file and test a bunch of stuff using the -# TI as a fixture -class TestTreeInfo: - def test_tree_info_values(self, fx_ts_map): - ts = fx_ts_map["2020-02-13"] - ti = sc2ts.TreeInfo(ts, show_progress=False) - assert list(ti.nodes_num_missing_sites[:5]) == [0, 0, 0, 560, 535] - assert list(ti.sites_num_missing_samples[:5]) == [4, 4, 4, 4, 4] - assert list(ti.sites_num_deletion_samples[:5]) == [0, 0, 0, 0, 0] - - class TestRealData: dates = [ "2020-01-01", diff --git a/tests/test_info.py b/tests/test_info.py index 8c0a2a9..f9198ac 100644 --- a/tests/test_info.py +++ b/tests/test_info.py @@ -1,6 +1,9 @@ +import inspect + import pytest import numpy as np import pandas as pd +import matplotlib import msprime import tskit @@ -8,6 +11,11 @@ from sc2ts import info +@pytest.fixture +def fx_ti_2020_02_13(fx_ts_map): + ts = fx_ts_map["2020-02-13"] + return info.TreeInfo(ts, show_progress=False) + class TestTallyLineages: @@ -93,7 +101,7 @@ def test_1tree_2mut_reversion(self): def test_2trees_0mut(self): ts = msprime.sim_ancestry( 2, - recombination_rate=1e6, # Nearly guarantee recomb. + recombination_rate=1e6, # Nearly guarantee recomb. sequence_length=2, ) assert ts.num_trees == 2 @@ -105,7 +113,7 @@ def test_2trees_1mut(self): ts = msprime.sim_ancestry( 4, ploidy=1, - recombination_rate=1e6, # Nearly guarantee recomb. + recombination_rate=1e6, # Nearly guarantee recomb. sequence_length=2, ) tables = ts.dump_tables() @@ -122,7 +130,7 @@ def test_2trees_2mut_diff_trees(self): ts = msprime.sim_ancestry( 4, ploidy=1, - recombination_rate=1e6, # Nearly guarantee recomb. + recombination_rate=1e6, # Nearly guarantee recomb. sequence_length=2, ) tables = ts.dump_tables() @@ -141,7 +149,7 @@ def test_2trees_2mut_same_tree(self): ts = msprime.sim_ancestry( 4, ploidy=1, - recombination_rate=1e6, # Nearly guarantee recomb. + recombination_rate=1e6, # Nearly guarantee recomb. sequence_length=2, ) tables = ts.dump_tables() @@ -156,3 +164,25 @@ def test_2trees_2mut_same_tree(self): expected[3] = 1 actual = info.get_num_muts(ts) np.testing.assert_equal(expected, actual) + + +class TestTreeInfo: + def test_tree_info_values(self, fx_ti_2020_02_13): + ti = fx_ti_2020_02_13 + assert list(ti.nodes_num_missing_sites[:5]) == [0, 0, 0, 560, 535] + assert list(ti.sites_num_missing_samples[:5]) == [4, 4, 4, 4, 4] + assert list(ti.sites_num_deletion_samples[:5]) == [0, 0, 0, 0, 0] + + @pytest.mark.parametrize( + "method", + [ + func + for (name, func) in inspect.getmembers(info.TreeInfo) + if name.startswith("plot") + ], + ) + def test_plots(self, fx_ti_2020_02_13, method): + fig, axes = method(fx_ti_2020_02_13) + assert isinstance(fig, matplotlib.figure.Figure) + for ax in axes: + assert isinstance(ax, matplotlib.axes.Axes)