Skip to content

Commit

Permalink
Ban topology-only dating
Browse files Browse the repository at this point in the history
Fixes #73 . Also adds some debugging comments to variational_gamma
  • Loading branch information
hyanwong committed Nov 5, 2024
1 parent d379356 commit 573483f
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 44 deletions.
7 changes: 4 additions & 3 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,10 +232,9 @@ def test_no_output_variational_gamma(self, tmp_path, capfd):

@pytest.mark.parametrize(("flag", "log_status"), logging_flags.items())
def test_verbosity(self, tmp_path, caplog, flag, log_status):
popsize = 10000
ts = msprime.simulate(
10,
Ne=popsize,
Ne=10000,
mutation_rate=1e-8,
recombination_rate=1e-8,
length=2e4,
Expand All @@ -246,7 +245,9 @@ def test_verbosity(self, tmp_path, caplog, flag, log_status):
caplog.set_level(getattr(logging, log_status))
# either tsdate preprocess or tsdate date (in_out method has debug asserts)
self.run_tsdate_cli(tmp_path, ts, flag, cmd="preprocess")
self.run_tsdate_cli(tmp_path, ts, f"-n 10 --method inside_outside {flag}")
self.run_tsdate_cli(
tmp_path, ts, f"--mutation-rate 1e-8 --rescaling-intervals 0 {flag}"
)
assert log_status in caplog.text

@pytest.mark.parametrize(
Expand Down
12 changes: 6 additions & 6 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1551,14 +1551,14 @@ def test_node_metadata_inside_outside(self):
ts = msprime.simulate(
10, mutation_rate=1, recombination_rate=1, length=20, random_seed=12
)
algorithm = InsideOutsideMethod(ts, mutation_rate=None, population_size=10000)
algorithm = InsideOutsideMethod(ts, mutation_rate=1, population_size=10000)
mn_post, *_ = algorithm.run(
eps=1e-6,
outside_standardize=True,
ignore_oldest_root=False,
probability_space=tsdate.base.LOG,
)
dts = tsdate.inside_outside(ts, population_size=10000, mutation_rate=None)
dts = tsdate.inside_outside(ts, population_size=10000, mutation_rate=1)
unconstr_mn = [nd.metadata["mn"] for nd in dts.nodes() if "mn" in nd.metadata]
assert np.allclose(unconstr_mn, mn_post)
assert np.all(dts.tables.nodes.time >= mn_post)
Expand Down Expand Up @@ -1825,8 +1825,8 @@ def test_node_selection_param(self):

def test_sites_time_insideoutside(self):
ts = utility_functions.two_tree_mutation_ts()
dated = tsdate.inside_outside(ts, mutation_rate=None, population_size=1)
algorithm = InsideOutsideMethod(ts, mutation_rate=None, population_size=1)
dated = tsdate.inside_outside(ts, mutation_rate=1, population_size=1)
algorithm = InsideOutsideMethod(ts, mutation_rate=1, population_size=1)
mn_post, *_ = algorithm.run(
eps=1e-6,
outside_standardize=True,
Expand Down Expand Up @@ -1931,14 +1931,14 @@ def test_sites_time_simulated_inside_outside(self):
ts = msprime.simulate(
10, mutation_rate=1, recombination_rate=1, length=20, random_seed=12
)
algorithm = InsideOutsideMethod(ts, mutation_rate=None, population_size=10000)
algorithm = InsideOutsideMethod(ts, mutation_rate=1, population_size=10000)
mn_post, *_ = algorithm.run(
eps=1e-6,
outside_standardize=True,
ignore_oldest_root=False,
probability_space=tsdate.base.LOG,
)
dts = tsdate.inside_outside(ts, mutation_rate=None, population_size=10000)
dts = tsdate.inside_outside(ts, mutation_rate=1, population_size=10000)
assert np.allclose(
mn_post[ts.tables.mutations.node],
tsdate.sites_time_from_ts(dts, unconstrained=True, min_time=0),
Expand Down
54 changes: 30 additions & 24 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,25 @@ def test_no_population_size(self):
tsdate.inside_outside(ts, mutation_rate=None)

def test_no_mutation(self):
ts = utility_functions.two_tree_mutation_ts()
with pytest.raises(ValueError, match="method requires mutation rate"):
tsdate.date(ts, method="maximization", population_size=1, mutation_rate=None)
with pytest.raises(ValueError, match="method requires mutation rate"):
tsdate.date(ts, method="variational_gamma", mutation_rate=None)
for ts in (
utility_functions.two_tree_mutation_ts(),
utility_functions.single_tree_ts_mutation_n3(),
):
with pytest.raises(ValueError, match="method requires mutation rate"):
tsdate.date(
ts, method="maximization", population_size=1, mutation_rate=None
)
with pytest.raises(ValueError, match="method requires mutation rate"):
tsdate.date(ts, method="variational_gamma", mutation_rate=None)
if ts.num_trees > 1:
with pytest.raises(NotImplementedError, match="more than one tree"):
tsdate.date(
ts, method="inside_outside", population_size=1, mutation_rate=None
)
else:
tsdate.date(
ts, method="inside_outside", population_size=1, mutation_rate=None
)

def test_not_needed_population_size(self):
ts = utility_functions.two_tree_mutation_ts()
Expand All @@ -87,9 +101,9 @@ def test_both_ne_and_population_size_specified(self):
ts = utility_functions.two_tree_mutation_ts()
with pytest.raises(ValueError, match="Only provide one of Ne"):
tsdate.inside_outside(
ts, mutation_rate=None, population_size=PopulationSizeHistory(1), Ne=1
ts, mutation_rate=1, population_size=PopulationSizeHistory(1), Ne=1
)
tsdate.inside_outside(ts, mutation_rate=None, Ne=PopulationSizeHistory(1))
tsdate.inside_outside(ts, mutation_rate=1, Ne=PopulationSizeHistory(1))

def test_inside_outside_dangling_failure(self):
ts = utility_functions.single_tree_ts_n2_dangling()
Expand Down Expand Up @@ -152,7 +166,7 @@ def test_no_posteriors(self):
def test_discretised_posteriors(self):
ts = utility_functions.two_tree_mutation_ts()
ts, posteriors = tsdate.inside_outside(
ts, mutation_rate=None, population_size=1, return_posteriors=True
ts, mutation_rate=1, population_size=1, return_posteriors=True
)
assert len(posteriors) == ts.num_nodes - ts.num_samples + 1
assert len(posteriors["time"]) > 0
Expand Down Expand Up @@ -180,34 +194,26 @@ def test_marginal_likelihood(self):
ts = utility_functions.two_tree_mutation_ts()
_, _, marg_lik = tsdate.inside_outside(
ts,
mutation_rate=None,
mutation_rate=1,
population_size=1,
return_posteriors=True,
return_likelihood=True,
)
_, marg_lik_again = tsdate.inside_outside(
ts, mutation_rate=None, population_size=1, return_likelihood=True
ts, mutation_rate=1, population_size=1, return_likelihood=True
)
assert marg_lik == marg_lik_again

def test_intervals(self):
ts = utility_functions.two_tree_ts()
long_ts = utility_functions.two_tree_ts_extra_length()
keep_ts = long_ts.keep_intervals([[0.0, 1.0]])
delete_ts = long_ts.delete_intervals([[1.0, 1.5]])
dated_ts = tsdate.inside_outside(ts, mutation_rate=None, population_size=1)
dated_keep_ts = tsdate.inside_outside(
keep_ts, mutation_rate=None, population_size=1
)
dated_deleted_ts = tsdate.inside_outside(
delete_ts, mutation_rate=None, population_size=1
)
assert np.allclose(
dated_ts.tables.nodes.time[:], dated_keep_ts.tables.nodes.time[:]
)
assert np.allclose(
dated_ts.tables.nodes.time[:], dated_deleted_ts.tables.nodes.time[:]
)
del_ts = long_ts.delete_intervals([[1.0, 1.5]])
dat_ts = tsdate.inside_outside(ts, mutation_rate=1, population_size=1)
dat_keep_ts = tsdate.inside_outside(keep_ts, mutation_rate=1, population_size=1)
dat_del_ts = tsdate.inside_outside(del_ts, mutation_rate=1, population_size=1)
assert np.allclose(dat_ts.tables.nodes.time[:], dat_keep_ts.tables.nodes.time[:])
assert np.allclose(dat_ts.tables.nodes.time[:], dat_del_ts.tables.nodes.time[:])


class TestSimulated:
Expand Down
2 changes: 1 addition & 1 deletion tsdate/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def tsdate_cli_parser():
)
parser.add_argument(
"--rescaling-intervals",
type=float,
type=int,
help=(
"The number of time intervals within which to estimate a time scaling "
f"parameter. Default: None treated as {core.DEFAULT_RESCALING_INTERVALS}"
Expand Down
20 changes: 14 additions & 6 deletions tsdate/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
Infer the age of nodes conditional on a tree sequence topology.
Infer the age of nodes from mutational data, conditional on a tree sequence topology.
"""

import logging
Expand Down Expand Up @@ -331,6 +331,15 @@ def run(
num_threads=None,
cache_inside=None,
):
if self.mutation_rate is None and self.recombination_rate is None:
if self.ts.num_trees > 1:
raise NotImplementedError(
"Specifying no mutation or recombination rate implies dating using "
"the topology-only clock. This produces biased results under "
"recombination (https://github.com/tskit-dev/tsdate/issues/292). "
"The topology-only clock has therefore been deprecated for tree "
"sequences representing more than one tree."
)
if self.provenance_params is not None:
self.provenance_params.update(
{k: v for k, v in locals().items() if k != "self"}
Expand Down Expand Up @@ -375,7 +384,7 @@ def run(
num_threads=None,
cache_inside=None,
):
if self.mutation_rate is None:
if self.mutation_rate is None and self.recombination_rate is None:
raise ValueError("Outside maximization method requires mutation rate")
if self.provenance_params is not None:
self.provenance_params.update(
Expand Down Expand Up @@ -864,10 +873,9 @@ def date(
Infer dates for nodes in a genealogical graph (or :ref:`ARG<tutorials:sec_args>`)
stored in the :ref:`succinct tree sequence<tskit:sec_introduction>` format.
New times are assigned to nodes using the estimation algorithm specified by
``method`` (see note below). If a ``mutation_rate`` is given,
the mutation clock is used. The recombination clock is unsupported at this
time. If neither a ``mutation_rate`` nor a ``recombination_rate`` is given, a
topology-only clock is used. Times associated with mutations and times associated
``method`` (see note below). A ``mutation_rate`` must be given (the recombination_rate
parameter, implementing a recombination clock, is unsupported at this
time). Times associated with mutations and times associated
with non-fixed (non-sample) nodes are overwritten. For example:
.. code-block:: python
Expand Down
10 changes: 6 additions & 4 deletions tsdate/variational.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,7 @@ def iterate(
regularise=True,
check_valid=False, # for debugging
):
# pass through singleton blocks
logger.debug("Passing through singleton blocks")
self.propagate_likelihood(
self.block_order,
self.block_nodes[ROOTWARD],
Expand All @@ -670,7 +670,7 @@ def iterate(
USE_BLOCK_LIKELIHOOD,
)

# rootward + leafward pass through edges
logger.debug("Rootward + leafward pass through edges")
self.propagate_likelihood(
self.edge_order,
self.edge_parents,
Expand All @@ -686,8 +686,8 @@ def iterate(
USE_EDGE_LIKELIHOOD,
)

# exponential regularization on roots
if regularise:
logger.debug("Exponential regularization on roots")
self.propagate_prior(
self.roots,
self.node_posterior,
Expand All @@ -698,7 +698,7 @@ def iterate(
em_reltol,
)

# absorb the scaling term into the factors
logger.debug("Absorbing scaling term into the factors")
self.rescale_factors(
self.edge_parents,
self.edge_children,
Expand Down Expand Up @@ -799,6 +799,7 @@ def run(

muts_timing = time.time()
mutations_phased = self.mutation_blocks == tskit.NULL
logger.debug("Passing through unphased singletons")
self.propagate_mutations( # unphased singletons
self.mutation_order[~mutations_phased],
self.mutation_posterior,
Expand All @@ -813,6 +814,7 @@ def run(
self.node_scale,
USE_BLOCK_LIKELIHOOD,
)
logger.debug("Passing through phased mutations")
self.propagate_mutations( # phased mutations
self.mutation_order[mutations_phased],
self.mutation_posterior,
Expand Down

0 comments on commit 573483f

Please sign in to comment.