diff --git a/tests/test_cli.py b/tests/test_cli.py index 2d1d5d5e..05183f8c 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -232,10 +232,9 @@ def test_no_output_variational_gamma(self, tmp_path, capfd): @pytest.mark.parametrize(("flag", "log_status"), logging_flags.items()) def test_verbosity(self, tmp_path, caplog, flag, log_status): - popsize = 10000 ts = msprime.simulate( 10, - Ne=popsize, + Ne=10000, mutation_rate=1e-8, recombination_rate=1e-8, length=2e4, @@ -246,7 +245,9 @@ def test_verbosity(self, tmp_path, caplog, flag, log_status): caplog.set_level(getattr(logging, log_status)) # either tsdate preprocess or tsdate date (in_out method has debug asserts) self.run_tsdate_cli(tmp_path, ts, flag, cmd="preprocess") - self.run_tsdate_cli(tmp_path, ts, f"-n 10 --method inside_outside {flag}") + self.run_tsdate_cli( + tmp_path, ts, f"--mutation-rate 1e-8 --rescaling-intervals 0 {flag}" + ) assert log_status in caplog.text @pytest.mark.parametrize( diff --git a/tests/test_functions.py b/tests/test_functions.py index 5d54731a..30ad9a80 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -1551,14 +1551,14 @@ def test_node_metadata_inside_outside(self): ts = msprime.simulate( 10, mutation_rate=1, recombination_rate=1, length=20, random_seed=12 ) - algorithm = InsideOutsideMethod(ts, mutation_rate=None, population_size=10000) + algorithm = InsideOutsideMethod(ts, mutation_rate=1, population_size=10000) mn_post, *_ = algorithm.run( eps=1e-6, outside_standardize=True, ignore_oldest_root=False, probability_space=tsdate.base.LOG, ) - dts = tsdate.inside_outside(ts, population_size=10000, mutation_rate=None) + dts = tsdate.inside_outside(ts, population_size=10000, mutation_rate=1) unconstr_mn = [nd.metadata["mn"] for nd in dts.nodes() if "mn" in nd.metadata] assert np.allclose(unconstr_mn, mn_post) assert np.all(dts.tables.nodes.time >= mn_post) @@ -1825,8 +1825,8 @@ def test_node_selection_param(self): def test_sites_time_insideoutside(self): ts = utility_functions.two_tree_mutation_ts() - dated = tsdate.inside_outside(ts, mutation_rate=None, population_size=1) - algorithm = InsideOutsideMethod(ts, mutation_rate=None, population_size=1) + dated = tsdate.inside_outside(ts, mutation_rate=1, population_size=1) + algorithm = InsideOutsideMethod(ts, mutation_rate=1, population_size=1) mn_post, *_ = algorithm.run( eps=1e-6, outside_standardize=True, @@ -1931,14 +1931,14 @@ def test_sites_time_simulated_inside_outside(self): ts = msprime.simulate( 10, mutation_rate=1, recombination_rate=1, length=20, random_seed=12 ) - algorithm = InsideOutsideMethod(ts, mutation_rate=None, population_size=10000) + algorithm = InsideOutsideMethod(ts, mutation_rate=1, population_size=10000) mn_post, *_ = algorithm.run( eps=1e-6, outside_standardize=True, ignore_oldest_root=False, probability_space=tsdate.base.LOG, ) - dts = tsdate.inside_outside(ts, mutation_rate=None, population_size=10000) + dts = tsdate.inside_outside(ts, mutation_rate=1, population_size=10000) assert np.allclose( mn_post[ts.tables.mutations.node], tsdate.sites_time_from_ts(dts, unconstrained=True, min_time=0), diff --git a/tests/test_inference.py b/tests/test_inference.py index cd53105e..200519ca 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -65,11 +65,25 @@ def test_no_population_size(self): tsdate.inside_outside(ts, mutation_rate=None) def test_no_mutation(self): - ts = utility_functions.two_tree_mutation_ts() - with pytest.raises(ValueError, match="method requires mutation rate"): - tsdate.date(ts, method="maximization", population_size=1, mutation_rate=None) - with pytest.raises(ValueError, match="method requires mutation rate"): - tsdate.date(ts, method="variational_gamma", mutation_rate=None) + for ts in ( + utility_functions.two_tree_mutation_ts(), + utility_functions.single_tree_ts_mutation_n3(), + ): + with pytest.raises(ValueError, match="method requires mutation rate"): + tsdate.date( + ts, method="maximization", population_size=1, mutation_rate=None + ) + with pytest.raises(ValueError, match="method requires mutation rate"): + tsdate.date(ts, method="variational_gamma", mutation_rate=None) + if ts.num_trees > 1: + with pytest.raises(NotImplementedError, match="more than one tree"): + tsdate.date( + ts, method="inside_outside", population_size=1, mutation_rate=None + ) + else: + tsdate.date( + ts, method="inside_outside", population_size=1, mutation_rate=None + ) def test_not_needed_population_size(self): ts = utility_functions.two_tree_mutation_ts() @@ -87,9 +101,9 @@ def test_both_ne_and_population_size_specified(self): ts = utility_functions.two_tree_mutation_ts() with pytest.raises(ValueError, match="Only provide one of Ne"): tsdate.inside_outside( - ts, mutation_rate=None, population_size=PopulationSizeHistory(1), Ne=1 + ts, mutation_rate=1, population_size=PopulationSizeHistory(1), Ne=1 ) - tsdate.inside_outside(ts, mutation_rate=None, Ne=PopulationSizeHistory(1)) + tsdate.inside_outside(ts, mutation_rate=1, Ne=PopulationSizeHistory(1)) def test_inside_outside_dangling_failure(self): ts = utility_functions.single_tree_ts_n2_dangling() @@ -152,7 +166,7 @@ def test_no_posteriors(self): def test_discretised_posteriors(self): ts = utility_functions.two_tree_mutation_ts() ts, posteriors = tsdate.inside_outside( - ts, mutation_rate=None, population_size=1, return_posteriors=True + ts, mutation_rate=1, population_size=1, return_posteriors=True ) assert len(posteriors) == ts.num_nodes - ts.num_samples + 1 assert len(posteriors["time"]) > 0 @@ -180,13 +194,13 @@ def test_marginal_likelihood(self): ts = utility_functions.two_tree_mutation_ts() _, _, marg_lik = tsdate.inside_outside( ts, - mutation_rate=None, + mutation_rate=1, population_size=1, return_posteriors=True, return_likelihood=True, ) _, marg_lik_again = tsdate.inside_outside( - ts, mutation_rate=None, population_size=1, return_likelihood=True + ts, mutation_rate=1, population_size=1, return_likelihood=True ) assert marg_lik == marg_lik_again @@ -194,20 +208,12 @@ def test_intervals(self): ts = utility_functions.two_tree_ts() long_ts = utility_functions.two_tree_ts_extra_length() keep_ts = long_ts.keep_intervals([[0.0, 1.0]]) - delete_ts = long_ts.delete_intervals([[1.0, 1.5]]) - dated_ts = tsdate.inside_outside(ts, mutation_rate=None, population_size=1) - dated_keep_ts = tsdate.inside_outside( - keep_ts, mutation_rate=None, population_size=1 - ) - dated_deleted_ts = tsdate.inside_outside( - delete_ts, mutation_rate=None, population_size=1 - ) - assert np.allclose( - dated_ts.tables.nodes.time[:], dated_keep_ts.tables.nodes.time[:] - ) - assert np.allclose( - dated_ts.tables.nodes.time[:], dated_deleted_ts.tables.nodes.time[:] - ) + del_ts = long_ts.delete_intervals([[1.0, 1.5]]) + dat_ts = tsdate.inside_outside(ts, mutation_rate=1, population_size=1) + dat_keep_ts = tsdate.inside_outside(keep_ts, mutation_rate=1, population_size=1) + dat_del_ts = tsdate.inside_outside(del_ts, mutation_rate=1, population_size=1) + assert np.allclose(dat_ts.tables.nodes.time[:], dat_keep_ts.tables.nodes.time[:]) + assert np.allclose(dat_ts.tables.nodes.time[:], dat_del_ts.tables.nodes.time[:]) class TestSimulated: diff --git a/tsdate/cli.py b/tsdate/cli.py index 541ffd13..6cf3b792 100644 --- a/tsdate/cli.py +++ b/tsdate/cli.py @@ -155,7 +155,7 @@ def tsdate_cli_parser(): ) parser.add_argument( "--rescaling-intervals", - type=float, + type=int, help=( "The number of time intervals within which to estimate a time scaling " f"parameter. Default: None treated as {core.DEFAULT_RESCALING_INTERVALS}" diff --git a/tsdate/core.py b/tsdate/core.py index c3ae301c..aebc3f92 100644 --- a/tsdate/core.py +++ b/tsdate/core.py @@ -21,7 +21,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. """ -Infer the age of nodes conditional on a tree sequence topology. +Infer the age of nodes from mutational data, conditional on a tree sequence topology. """ import logging @@ -331,6 +331,15 @@ def run( num_threads=None, cache_inside=None, ): + if self.mutation_rate is None and self.recombination_rate is None: + if self.ts.num_trees > 1: + raise NotImplementedError( + "Specifying no mutation or recombination rate implies dating using " + "the topology-only clock. This produces biased results under " + "recombination (https://github.com/tskit-dev/tsdate/issues/292). " + "The topology-only clock has therefore been deprecated for tree " + "sequences representing more than one tree." + ) if self.provenance_params is not None: self.provenance_params.update( {k: v for k, v in locals().items() if k != "self"} @@ -375,7 +384,7 @@ def run( num_threads=None, cache_inside=None, ): - if self.mutation_rate is None: + if self.mutation_rate is None and self.recombination_rate is None: raise ValueError("Outside maximization method requires mutation rate") if self.provenance_params is not None: self.provenance_params.update( @@ -864,10 +873,9 @@ def date( Infer dates for nodes in a genealogical graph (or :ref:`ARG`) stored in the :ref:`succinct tree sequence` format. New times are assigned to nodes using the estimation algorithm specified by - ``method`` (see note below). If a ``mutation_rate`` is given, - the mutation clock is used. The recombination clock is unsupported at this - time. If neither a ``mutation_rate`` nor a ``recombination_rate`` is given, a - topology-only clock is used. Times associated with mutations and times associated + ``method`` (see note below). A ``mutation_rate`` must be given (the recombination_rate + parameter, implementing a recombination clock, is unsupported at this + time). Times associated with mutations and times associated with non-fixed (non-sample) nodes are overwritten. For example: .. code-block:: python diff --git a/tsdate/variational.py b/tsdate/variational.py index 0aa51910..78ded664 100644 --- a/tsdate/variational.py +++ b/tsdate/variational.py @@ -654,7 +654,7 @@ def iterate( regularise=True, check_valid=False, # for debugging ): - # pass through singleton blocks + logger.debug("Passing through singleton blocks") self.propagate_likelihood( self.block_order, self.block_nodes[ROOTWARD], @@ -670,7 +670,7 @@ def iterate( USE_BLOCK_LIKELIHOOD, ) - # rootward + leafward pass through edges + logger.debug("Rootward + leafward pass through edges") self.propagate_likelihood( self.edge_order, self.edge_parents, @@ -686,8 +686,8 @@ def iterate( USE_EDGE_LIKELIHOOD, ) - # exponential regularization on roots if regularise: + logger.debug("Exponential regularization on roots") self.propagate_prior( self.roots, self.node_posterior, @@ -698,7 +698,7 @@ def iterate( em_reltol, ) - # absorb the scaling term into the factors + logger.debug("Absorbing scaling term into the factors") self.rescale_factors( self.edge_parents, self.edge_children, @@ -799,6 +799,7 @@ def run( muts_timing = time.time() mutations_phased = self.mutation_blocks == tskit.NULL + logger.debug("Passing through unphased singletons") self.propagate_mutations( # unphased singletons self.mutation_order[~mutations_phased], self.mutation_posterior, @@ -813,6 +814,7 @@ def run( self.node_scale, USE_BLOCK_LIKELIHOOD, ) + logger.debug("Passing through phased mutations") self.propagate_mutations( # phased mutations self.mutation_order[mutations_phased], self.mutation_posterior,