Skip to content

Commit

Permalink
Merge pull request jeromekelleher#304 from hyanwong/svg-improvements
Browse files Browse the repository at this point in the history
Prettify sample_group draw_svg()
  • Loading branch information
jeromekelleher authored Sep 25, 2024
2 parents 8e08ed8 + 15e3f07 commit aecd697
Showing 1 changed file with 120 additions and 12 deletions.
132 changes: 120 additions & 12 deletions sc2ts/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import warnings
import dataclasses
import datetime
import re
from typing import List

import numba
Expand Down Expand Up @@ -1439,32 +1440,139 @@ class SampleGroupInfo:
ts: tskit.TreeSequence
attach_date: None

def draw_svg(self, size=(800, 600), time_scale=None):
def draw_svg(self, size=(800, 600), time_scale=None, y_axis=True, mutation_labels=None, style=None, highlight_universal_mutations=None, x_regions=None, **kwargs):
"""
Draw an SVG representation of the tree of samples that trace to a single origin.
The default style is to colour mutations such that sites with a single
mutation in the tree are dark red, whereas sites with multiple mutations
show those mutations in red or magenta (magenta when a mutation immediately
reverts its parent mutation). Any identical mutations (from the same inherited
to derived state at the same site, i.e. recurrent mutations) have the count of
recurrent mutations appended to the label, e.g. "C842T (1/2)".
If highlight_universal_mutations is set, then mutations in the ancestry of all
the samples (i.e. between the root and the MRCA of all the samples) are highlighted
in bold and with thicker symbol lines
If genetic_regions is set, it should be a dictionary mapping gene names to
(start, end) tuples. These will be drawn as coloured rectangles on the x-axis. If None,
a default selection of SARS-CoV-2 genes will be used.
"""
if x_regions is None:
x_regions = {
"ORF1a": (266, 13468),
"ORF1b": (13468, 21555),
"Spike": (21563, 25384),
}


if style is None:
style = ""
ts = self.ts
assert ts.num_trees == 1
y_ticks = {
ts.nodes_time[u]: ts.node(u).metadata["date"] for u in list(ts.samples())
}
y_ticks[ts.nodes_time[ts.first().root]] = self.attach_date
if time_scale == "rank":
times = list(np.unique(ts.nodes_time))
y_ticks = {times.index(k): v for k, v in y_ticks.items()}

mut_labels = {}
for site in ts.sites():
# TODO Viz the recurrent mutations
for mut in site.mutations:
mut_labels[mut.id] = (
f"{site.ancestral_state}{int(site.position)}{mut.derived_state}"
)

return self.ts.draw_svg(
shared_nodes = []
if highlight_universal_mutations is not None:
# find edges above
tree = ts.first()
shared_nodes = [tree.root]
while tree.num_children(shared_nodes[-1]) == 1:
shared_nodes.append(tree.children(shared_nodes[-1])[0])

multiple_mutations = []
universal_mutations = []
reverted_mutations = []
if mutation_labels is None:
mutation_labels = collections.defaultdict(list)
for site in ts.sites():
# TODO Viz the recurrent mutations
for mut in site.mutations:
if mut.node in shared_nodes:
universal_mutations.append(mut.id)
if len(site.mutations) > 1:
multiple_mutations.append(mut.id)
inherited_state = site.ancestral_state
if mut.parent >= 0:
parent = ts.mutation(mut.parent)
inherited_state = parent.derived_state
parent_inherited_state = site.ancestral_state
if parent.parent >= 0:
parent_inherited_state = ts.mutation(parent.parent).derived_state
if parent_inherited_state == mut.derived_state:
reverted_mutations.append(mut.id)
# Reverse map label name to mutation id, so we can count duplicates
label = f"{inherited_state}{int(site.position)}{mut.derived_state}"
mutation_labels[label].append(mut.id)
# If more than one mutation has the same label, add a prefix with the counts
mutation_labels = {
m_id: label + (f" ({i+1}/{len(ids)})" if len(ids) > 1 else "")
for label, ids in mutation_labels.items() for i, m_id in enumerate(ids)}
# some default styles
styles = [".mut .lab {fill: darkred} .mut .sym {stroke: darkred} .background path {fill: white}"]
if len(multiple_mutations) > 0:
lab_css = ", ".join(f".mut.m{m} .lab" for m in multiple_mutations)
sym_css = ", ".join(f".mut.m{m} .sym" for m in multiple_mutations)
styles.append(lab_css + "{fill: red}" + sym_css + "{stroke: red}")
if len(reverted_mutations) > 0:
lab_css = ", ".join(f".mut.m{m} .lab" for m in reverted_mutations)
sym_css = ", ".join(f".mut.m{m} .sym" for m in reverted_mutations)
styles.append(lab_css + "{fill: magenta}" + sym_css + "{stroke: magenta}")
if len(universal_mutations) > 0:
lab_css = ", ".join(f".mut.m{m} .lab" for m in universal_mutations)
sym_css = ", ".join(f".mut.m{m} .sym" for m in universal_mutations)
sym_ax_css = ", ".join(f".x-axis .mut.m{m} .sym" for m in universal_mutations)
styles.append(lab_css + "{font-weight: bold}" + sym_css + "{stroke-width: 3}")
styles.append(sym_ax_css + "{stroke-width: 8}")
svg = self.ts.draw_svg(
size=size,
time_scale=time_scale,
y_axis=True,
mutation_labels=mut_labels,
mutation_labels=mutation_labels,
y_ticks=y_ticks,
style="".join(styles) + style,
**kwargs,
)

# Hack to add genes to the X axis
if len(x_regions) > 0:
assert svg.startswith("<svg")
header = svg[:svg.find(">") + 1]
footer = "</svg>"

# Find SVG positions of the X axis
m = re.search(r'class="x-axis".*?class="ax-line" x1="([\d\.]+)" x2="([\d\.]+)" y1="([\d\.]+)"', svg)
assert m is not None
x1, x2, y1 = float(m.group(1)), float(m.group(2)), float(m.group(3))
xdiff = x2 - x1
x_box_svg = '<rect fill="yellow" stroke="black" x="{x}" width="{w}" y="{y}" height="{h}" />'
x_name_svg = '<text text-anchor="middle" alignment-baseline="hanging" x="{x}" y="{y}">{name}</text>'
x_scale = xdiff / ts.sequence_length
x_boxes = [
x_box_svg.format(
x=x1 + p1 * x_scale,
w=(p2-p1) * x_scale,
y=y1,
h=20) # height of the box: hardcoded for now to match font height
for p1, p2 in x_regions.values()
]
x_names = [
x_name_svg.format(x=x1 + (p[0] + p[1])/2 * x_scale, y=y1+2, name=name)
for name, p in x_regions.items()
]
# add the new SVG to the old
svg = (header + "".join(x_boxes) + "".join(x_names) + footer) + svg
# Now wrap both in another SVG
svg = header + svg + footer

return tskit.drawing.SVGString(svg)

def get_sample_metadata(self, key):
ret = []
for u in self.ts.samples():
Expand Down

0 comments on commit aecd697

Please sign in to comment.