
Allow internal/non-contemporary samples for dating sc2ts, pedigrees #433

Open · nspope wants to merge 6 commits into main from noncontemporary-samples
Conversation

@nspope (Contributor) commented Nov 1, 2024

A minimal case:

Mutations: 16

--- TRUE TREES ---
[ts.draw_text() output, garbled in extraction: four marginal trees with contemporary tips at time 0.00, noncontemporary samples at time 6.00, and root ages up to 13.00]

--- DATED WITH INTERNAL SAMPLES ---
[foo.draw_text() output, garbled in extraction: the same four marginal trees after dating; sample ages are preserved at 0.00 and 6.00, and estimated root ages reach 15.85]
Code to simulate and run on test case
import msprime
import numpy as np
import tskit


def simulate_internal_samples(pop_size=2, sample_time=6, L=4, r=0.1, random_seed=1):
    """ 
    Simulate ts with noncontemporary samples that have descendants 
    https://github.com/tskit-dev/msprime/discussions/2260
    """
    rng = np.random.default_rng(random_seed)
    
    samples = [
        msprime.SampleSet(pop_size),
        msprime.SampleSet(1, ploidy=1, time=sample_time + 1),
    ]
    
    # recent history up to sample time
    ts1 = msprime.sim_ancestry(
        samples=samples,
        population_size=pop_size,
        model="dtwf",
        end_time=sample_time,
        sequence_length=L,
        recombination_rate=r,
        random_seed=random_seed,
    )
    
    # remove dummy node above the sample time
    ts1 = ts1.simplify(range(2 * pop_size), keep_unary=True)
    
    # history before sample time
    ts2 = msprime.sim_ancestry(
        samples=[msprime.SampleSet(pop_size, time=sample_time)],
        model="dtwf",
        population_size=pop_size,
        sequence_length=L,
        recombination_rate=r,
        random_seed=random_seed + 1000,
    )
    
    # remap roots to samples in ts2
    roots = [n.id for n in ts1.nodes() if n.time == sample_time]
    tips = ts2.samples()
    rng.shuffle(tips)
    node_mapping = [tskit.NULL for _ in ts1.nodes()]
    for t, n in zip(tips[:len(roots)], roots):
        node_mapping[n] = t
    ts = ts2.union(ts1, node_mapping, check_shared_equality=False)

    return ts


# ---- some minimal test cases --- #

import tsdate

def get_internal_samples(ts):
    always_internal = np.full(ts.num_nodes, False)
    always_internal[list(ts.samples())] = True
    for t in ts.trees():
        for n in t.samples():
            if t.num_children(n) == 0:
                always_internal[n] = False
    return np.flatnonzero(always_internal)

mu = 0.1
ts = simulate_internal_samples(random_seed=1).simplify()
ts = msprime.sim_mutations(ts, rate=mu, random_seed=1)
print(f"Mutations: {ts.num_mutations}")
print(f"--- TRUE TREES ---\n", ts.draw_text())

# (1) case with internal-only samples (may be unary) and noncontemporary leaf samples
# NB: sample ages can be modified, because of how we force positive branch lengths.
#     this'll be fixed (or at least minimized) by setting `constr_iterations`
foo = tsdate.date(ts, mutation_rate=mu, rescaling_intervals=1, constr_iterations=100)
assert np.allclose(foo.nodes_time[list(foo.samples())], ts.nodes_time[list(foo.samples())])
print(f"--- DATED WITH INTERNAL SAMPLES ---\n", foo.draw_text())

# (2) when there's a unary node that's not a sample, we should fail
internal_samples = get_internal_samples(ts)
tab = ts.dump_tables()
new_flags = tab.nodes.flags.copy()
new_flags[internal_samples] = 0
tab.nodes.flags = new_flags
ts_no_internal = tab.tree_sequence()
try:
    bar = tsdate.date(ts_no_internal, mutation_rate=mu, rescaling_intervals=1, constr_iterations=100)
except ValueError:
    pass

# get rid of unary node, and it'll work
bar = tsdate.date(ts_no_internal.simplify(), mutation_rate=mu, rescaling_intervals=1, constr_iterations=100)

Some caveats:

  • Internal samples can be unary over part or all of their span; but we'll still error out if non-sample nodes are unary
  • Because of how we're forcing the positive branch length constraint, ages of internal samples (e.g. with descendants) can be pushed upward. This can be fixed (or minimized) by setting constr_iterations to 100 or something similar.
  • Rescaling is done as usual, but without changing the ages of sample nodes. This might cause nodes to "bunch up" if, e.g., the mutation rate is misspecified, a sample age is wrong, or there's other weirdness in the inputs. There are better ways to do this, but it'll take a bit more work to get something reliable.
  • I've only tried this out on minimal examples, and not inferred TS's, as I'm not quite clear on how to add internal samples in this case (maybe it's straightforward). @hyanwong maybe you could try this out on sc2ts?

@hyanwong (Member) commented Nov 1, 2024

This is great!

I'm just thinking about how I might get a sensible mutation rate to use. I presume one way, if we think that most of the node times are +- correct, is to ensure that the TS is simplified, then simply count the mutations and divide by the total mutational area, i.e.

rate = ts.num_mutations / ((ts.edges_right - ts.edges_left) * (ts.nodes_time[ts.edges_parent] - ts.nodes_time[ts.edges_child])).sum()

@hyanwong (Member) commented Nov 1, 2024

An initial run on a huge covid arg gives:

File ~/Library/jupyterlab-desktop/jlab_server/lib/python3.12/site-packages/tsdate/variational.py:693, in ExpectationPropagation.iterate(self, max_shape, min_step, em_maxitt, em_reltol, regularise, check_valid)
    691 # exponential regularization on roots
    692 if regularise:
--> 693     self.propagate_prior(
    694         self.unconstrained_roots,
    695         self.node_posterior,
    696         self.node_factors,
    697         self.node_scale,
    698         max_shape,
    699         em_maxitt,
    700         em_reltol,
    701     )
    703 # absorb the scaling term into the factors
    704 self.rescale_factors(
    705     self.edge_parents,
    706     self.edge_children,
   (...)
    711     self.node_scale,
    712 )

ZeroDivisionError: division by zero

I assume there is some deeper numba stuff going on: is there a way to switch numba off to get a deeper trace?
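(For reference: assuming the hot paths here are numba-compiled, the JIT can be turned off with numba's standard environment variable, so the compiled functions run as plain Python and raise ordinary tracebacks. The script name below is hypothetical.)

```shell
# Disable numba JIT compilation for this run to recover a full Python traceback
NUMBA_DISABLE_JIT=1 python run_dating.py
```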

@hyanwong (Member) commented Nov 1, 2024

Probably the reason for this is that there are very few non-sample nodes, and the root is a sample node. For example, if I simplify to the first 50 nodes, I get a tree sequence consisting of a single tree like this
[screenshot of the resulting single tree]

small_ts.trees.zip

@nspope (Contributor, Author) commented Nov 1, 2024

In several places in the code, I'm assuming the time scale starts at zero, but it seems this is never checked. I suppose we should enforce this. Can you try it on the big covid sequence with (a) node times shifted so that the most recent sample node is at 0, and (b) rescaling turned off with rescaling_intervals=0? This allows that tiny example to run through.

@hyanwong (Member) commented Nov 1, 2024

Good point. I have been using rescaling_intervals=0 for this. When I transform the times to start at 0, and flag the root(s) as not-a-sample, I can get dating to work with pretty large subsets of the data (e.g. 350k samples), which is fantastic.

It still fails on larger sample sizes though, but now with a slightly different (assertion) error:

File ~/Library/jupyterlab-desktop/jlab_server/lib/python3.12/site-packages/tsdate/variational.py:89, in _damp()
     87 b = 1.0 if (x[1] - y[1] > x[1] * s) else (1 - s) * x[1] / y[1]
     88 d = min(a, b)
---> 89 assert 0.0 < d <= 1.0
     90 return d

AssertionError:

@nspope (Contributor, Author) commented Nov 1, 2024

This is probably due to floating point error in that comparison against 1. Seems like this is a good case for catching these sorts of numerical issues. Can you share the trees over slack or email, and the code you're using to prepare them?

@jeromekelleher (Member) commented

Amazing!

@hyanwong (Member) commented Nov 4, 2024

Just to post some of @nspope's thoughts. The most recent branch fixes the sample-root issue and also allows rescaling, but it's not clear to Nate or myself that rescaling is sensible here anyway. Internal samples seem to be dealt with just fine. Re the 0.0 < d <= 1.0 assertion failure, he says:

It does seem like we're hitting against similar issues with these trees as we were when trying to date TS's with unary nodes -- there's lots of constraints, so the gamma shape parameters go to the limit. I'd suggest comparing dates with different numbers of EP iterations to see if convergence is really slow. If it is, then rescaling might help.

We can use the current code to date relatively large portions of the sc2ts ARGs (e.g. >400k samples), so we could try chopping the TS into chunks to get something working. Alternatively, I wonder if there is a way to skip certain edges if they create particularly problematic constraints, and output the list of skipped edges separately? This could be useful to identify badly inferred relationships. For this I suggest that it would be much better to omit bad edges than bad nodes, if possible.

@nspope (Contributor, Author) commented Nov 4, 2024

I'd skip rescaling for the covid trees -- there are so many constraints (including a fixed root age) that I really doubt it's necessary (in unconstrained ARGs it's generally most helpful towards the root).

Regarding the assertion -- I can do something hacky to bypass the issue, if you want to give it a go on the full ARG this or next week (and do it properly later).

@nspope (Contributor, Author) commented Nov 4, 2024

With the last commit it'll run through on the 450K covid example you sent @hyanwong -- can you try it on a larger example?

@hyanwong (Member) commented Nov 17, 2024

At the end of the dating process, we need to pick dates for the nodes (and mutations) that are consistent with the parent-older-than-child requirement of a tree sequence. There are two parts to this: a least-squares step and a simple upwards iteration. However, both of these can cause samples at non-zero times to have their times changed, which is not ideal.

A relatively simple fix would be to rescale all the node times at the end of the routine. If we have a fixed node N2 at time t2, closest in time to a younger fixed node N1 at time t1, but we estimate its time as t2*, we simply need to multiply the times of non-fixed nodes between t1 and t2* by (t2-t1)/(t2*-t1), and do so iteratively up the ARG from the tips.
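A sketch of that rescaling for a single (t1, t2) pair, with names of my own choosing; the full version would sweep this over successive pairs of fixed-node times, from the tips upwards:

```python
import numpy as np


def rescale_slice(times, t1, t2, t2_star):
    """Map estimated times in (t1, t2_star] linearly onto (t1, t2], so the
    fixed node estimated at t2_star lands exactly at its known age t2.
    Times above t2_star are left for the next iteration up the ARG."""
    out = np.asarray(times, dtype=float).copy()
    mask = (out > t1) & (out <= t2_star)
    out[mask] = t1 + (out[mask] - t1) * (t2 - t1) / (t2_star - t1)
    return out
```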

@nspope (Contributor, Author) commented Nov 17, 2024

The least-squares approach will respect sample ages, but can't satisfy the constraint exactly. So we always have to fall back on the simple approach of bumping parents up, which is where the problem lies. Post-hoc rescaling like you suggest is a nice solution.
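For concreteness, here is a naive sketch of the "bumping parents up" pass (my own minimal version, not tsdate's implementation). Note how a fixed-age sample that appears as a parent gets pushed upward, which is exactly the problem described above:

```python
import numpy as np


def bump_parents(parent, child, time, eps=1e-6, max_iter=100):
    """Force parent_time > child_time on every edge by repeatedly sweeping
    the edge list and pushing violating parents just above their children."""
    time = np.asarray(time, dtype=float).copy()
    for _ in range(max_iter):
        changed = False
        for p, c in zip(parent, child):
            if time[p] <= time[c]:
                time[p] = time[c] + eps  # may move a fixed sample's age
                changed = True
        if not changed:
            break
    return time
```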

@hyanwong (Member) commented Nov 17, 2024

Here's a (currently failing) algorithm for rescaling using a single pass through the nodes in time order. It requires us to ensure that all sample nodes that occur at the same time are given the same inferred time (otherwise we could "squash" children into a zero-length timeslice and end up with child_time == parent_time). We should be able to enforce this same-time requirement using the iterative approach, but it is not coded up below, so this is not currently working.

# TODO: use the iterative approach to ensure that all sample nodes that occur at
# the same time are given the same inferred time, and in the same order

# tmp_ts is the currently dated ts; orig_ts is the ts with the true sample dates.
# node_is_sample is a boolean array marking sample nodes, e.g.
# node_is_sample = np.bitwise_and(tmp_ts.nodes_flags, tskit.NODE_IS_SAMPLE).astype(bool)

tmp_sample_times = np.unique(tmp_ts.nodes_time[node_is_sample])

if len(tmp_sample_times) > 1:
    orig_sample_times = np.unique(orig_ts.nodes_time[node_is_sample])
    tmp_times = tmp_ts.nodes_time
    new_times = np.zeros_like(tmp_ts.nodes_time)
    idx = 0
    started = False
    for u in np.lexsort((np.arange(tmp_ts.num_nodes), tmp_ts.nodes_time)):
        time_u = tmp_times[u]
        if not started and not node_is_sample[u]:
            pass  # skip until the youngest sample node is found
        elif time_u == tmp_sample_times[idx]:
            # same time as the current sample timeslice boundary
            started = True
            new_times[u] = orig_sample_times[idx]
        else:
            # advance idx to the timeslice containing time_u
            while idx < len(tmp_sample_times) - 1 and time_u >= tmp_sample_times[idx + 1]:
                idx += 1
            tmp_diff = time_u - tmp_sample_times[idx]
            if idx == len(tmp_sample_times) - 1:
                # above the oldest sample: just add the (unrescaled) excess
                new_times[u] = orig_sample_times[idx] + tmp_diff
            else:
                rescaling = (orig_sample_times[idx + 1] - orig_sample_times[idx]) / (
                    tmp_sample_times[idx + 1] - tmp_sample_times[idx]
                )
                new_times[u] = orig_sample_times[idx] + tmp_diff * rescaling
        
    tables = tmp_ts.dump_tables()
    tables.nodes.time = new_times
    new_ts = tables.tree_sequence()

@hyanwong (Member) commented

It looks like you haven't switched to the ruff linter for your tsdate install @nspope, so CI isn't passing the linting tests.

nspope force-pushed the noncontemporary-samples branch from a142dec to 1791b09 on December 5, 2024
@hyanwong (Member) commented

I'd like to merge this @nspope. Is there any reason not to? For example, I see `# FIXME: this is a hacky bypass` in one place (but I think that's fine, even if we release this code as a new version, right?)

@hyanwong (Member) commented

@Mergifyio rebase

mergify bot (Contributor) commented Jan 16, 2025

rebase

✅ Branch has been successfully rebased

hyanwong force-pushed the noncontemporary-samples branch from 362aaa5 to 4e0a7b6 on January 16, 2025
@nspope (Contributor, Author) commented Jan 16, 2025

Don't merge it, because it's doing the wrong thing (resetting messages for one part of the algorithm, but not for other parts). That'll be fine for sc2ts where there's no prior (roots are fixed) and no unphased singletons. Otherwise the updates will get progressively more incorrect if the "hacky bypass" bit is triggered. Fixing this requires a small refactor, unfortunately.

@hyanwong (Member) commented

Ah, I hadn't realised that (or had forgotten). Thanks for the explanation.
