From b0ea6a1a43792a73001377097e3ef5ee9a09688a Mon Sep 17 00:00:00 2001 From: awohns Date: Wed, 29 Jun 2022 13:11:04 -0400 Subject: [PATCH] Allow ancient samples Rework build-prior and inside / outside logic to allow historical samples And speed up time constraint algorithms while also allowing nodes to be out of time order --- CHANGELOG.rst | 8 ++ requirements.txt | 2 +- tests/test_functions.py | 17 +--- tests/test_inference.py | 19 +++- tsdate/base.py | 16 +++- tsdate/core.py | 119 +++++++++++++++---------- tsdate/prior.py | 187 +++++++++++++++++++++++++++++++++------- 7 files changed, 269 insertions(+), 99 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 6d44bf32..4cd74157 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -2,6 +2,14 @@ [0.1.6] - ****-**-** -------------------- +**Features** + +- Historical samples can now be incorporated directly into the dating framework. + This is done by constructing a bespoke prior grid using + ``grid=tsdate.build_prior_grid(..., allow_historical_samples=True)`` and + passing that into ``tsdate.date``. It is also possible to set a variance for + historical sample nodes.
+ **Breaking changes** - The standalone ``preprocess_ts`` function now defaults to not removing unreferenced diff --git a/requirements.txt b/requirements.txt index e54c32cd..1e0cbccc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -tskit>=0.4.0 +tskit>=0.5.2 tsinfer>=0.3.0 flake8 numpy diff --git a/tests/test_functions.py b/tests/test_functions.py index fdcc6131..bd0cd374 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -1048,7 +1048,7 @@ def test_dangling_fails(self): print(ts.draw_text()) print("Samples:", ts.samples()) Ne = 0.5 - with pytest.raises(ValueError, match="simplified"): + with pytest.raises(ValueError, match="simplify"): tsdate.build_prior_grid(ts, Ne, timepoints=np.array([0, 1.2, 2])) # mut_rate = 1 # eps = 1e-6 @@ -1421,7 +1421,7 @@ def test_date_input(self): def test_sample_as_parent_fails(self): ts = utility_functions.single_tree_ts_n3_sample_as_parent() - with pytest.raises(NotImplementedError): + with pytest.raises(ValueError, match="samples at non-zero times"): tsdate.date(ts, mutation_rate=None, Ne=1) def test_recombination_not_implemented(self): @@ -1532,18 +1532,7 @@ def test_constrain_ages_topo(self): ts = utility_functions.two_tree_ts() post_mn = np.array([0.0, 0.0, 0.0, 2.0, 1.0, 3.0]) eps = 1e-6 - nodes_to_date = np.array([3, 4, 5]) - constrained_ages = constrain_ages_topo(ts, post_mn, eps, nodes_to_date) - assert np.array_equal( - np.array([0.0, 0.0, 0.0, 2.0, 2.000001, 3.0]), constrained_ages - ) - - def test_constrain_ages_topo_no_nodes_to_date(self): - ts = utility_functions.two_tree_ts() - post_mn = np.array([0.0, 0.0, 0.0, 2.0, 1.0, 3.0]) - eps = 1e-6 - nodes_to_date = None - constrained_ages = constrain_ages_topo(ts, post_mn, eps, nodes_to_date) + constrained_ages = constrain_ages_topo(ts, post_mn, eps) assert np.array_equal( np.array([0.0, 0.0, 0.0, 2.0, 2.000001, 3.0]), constrained_ages ) diff --git a/tests/test_inference.py b/tests/test_inference.py index 59d5ff32..1660d6f9 100644 --- 
a/tests/test_inference.py +++ b/tests/test_inference.py @@ -61,7 +61,7 @@ def test_bad_Ne(self): def test_dangling_failure(self): ts = utility_functions.single_tree_ts_n2_dangling() - with pytest.raises(ValueError, match="simplified"): + with pytest.raises(ValueError, match="simplify"): tsdate.date(ts, mutation_rate=None, Ne=1) def test_unary_failure(self): @@ -271,7 +271,7 @@ def test_fails_multi_root(self): with pytest.raises(ValueError): tsdate.date(multiroot_ts, Ne=1, mutation_rate=2, priors=good_priors) - def test_non_contemporaneous(self): + def test_non_contemporaneous_warn(self): samples = [ msprime.Sample(population=0, time=0), msprime.Sample(population=0, time=0), @@ -279,8 +279,21 @@ def test_non_contemporaneous(self): msprime.Sample(population=0, time=1.0), ] ts = msprime.simulate(samples=samples, Ne=1, mutation_rate=2, random_seed=12) - with pytest.raises(NotImplementedError): + with pytest.raises(ValueError, match="samples at non-zero times"): tsdate.date(ts, Ne=1, mutation_rate=2) + with pytest.raises(ValueError, match="samples at non-zero times"): + tsdate.build_prior_grid(ts, Ne=1) + + def test_non_contemporaneous(self): + samples = [ + msprime.Sample(population=0, time=0), + msprime.Sample(population=0, time=0), + msprime.Sample(population=0, time=0), + msprime.Sample(population=0, time=1.0), + ] + ts = msprime.simulate(samples=samples, Ne=1, mutation_rate=2, random_seed=12) + priors = tsdate.build_prior_grid(ts, Ne=1, allow_historical_samples=True) + tsdate.date(ts, priors=priors, mutation_rate=2) def test_no_mutation_times(self): ts = msprime.simulate(20, Ne=1, mutation_rate=1, random_seed=12) diff --git a/tsdate/base.py b/tsdate/base.py index d4130606..007277a9 100644 --- a/tsdate/base.py +++ b/tsdate/base.py @@ -95,6 +95,12 @@ def __init__( ] = (-np.arange(num_nodes - self.num_nonfixed) - 1) self.probability_space = LIN + def fixed_node_ids(self): + return np.where(self.row_lookup < 0)[0] + + def nonfixed_node_ids(self): + return 
np.where(self.row_lookup >= 0)[0] + def force_probability_space(self, probability_space): """ probability_space can be "logarithmic" or "linear": this function will force @@ -140,6 +146,9 @@ def normalize(self): else: raise RuntimeError("Probability space is not", LIN, "or", LOG) + def is_fixed(self, node_id): + return self.row_lookup[node_id] < 0 + def __getitem__(self, node_id): index = self.row_lookup[node_id] if index < 0: @@ -207,8 +216,7 @@ def fill_fixed(orig, fixed_data): new_obj.fixed_data = fill_fixed( self, grid_data if fixed_data is None else fixed_data ) - if probability_space is None: - new_obj.probability_space = self.probability_space - else: - new_obj.probability_space = probability_space + new_obj.probability_space = self.probability_space + if probability_space is not None: + new_obj.force_probability_space(probability_space) return new_obj diff --git a/tsdate/core.py b/tsdate/core.py index c9ba97bf..7418c054 100644 --- a/tsdate/core.py +++ b/tsdate/core.py @@ -151,7 +151,7 @@ def _lik(muts, span, dt, mutation_rate, normalize=True): """ ll = scipy.stats.poisson.pmf(muts, dt * mutation_rate * span) if normalize: - return ll / np.max(ll) + return ll / np.nanmax(ll) else: return ll @@ -258,15 +258,28 @@ def get_mut_lik_fixed_node(self, edge): mutations_on_edge = self.mut_edges[edge.id] child_time = self.ts.node(edge.child).time - assert child_time == 0 - # Temporary hack - we should really take a more precise likelihood - return self._lik( - mutations_on_edge, - edge.span, - self.timediff, - self.mut_rate, - normalize=self.normalize, - ) + if child_time == 0: + return self._lik( + mutations_on_edge, + edge.span, + self.timediff, + self.mut_rate, + normalize=self.normalize, + ) + else: + timediff = self.timepoints - child_time + 1e-8 + # Temporary hack - we should really take a more precise likelihood + likelihood = self._lik( + mutations_on_edge, + edge.span, + timediff, + self.mut_rate, + normalize=self.normalize, + ) + # Prevent child from being 
older than parent + likelihood[timediff < 0] = 0 + + return likelihood def get_mut_lik_lower_tri(self, edge): """ @@ -389,7 +402,7 @@ def get_fixed(self, arr, edge): return arr * liks def scale_geometric(self, fraction, value): - return value**fraction + return value ** fraction class LogLikelihoods(Likelihoods): @@ -429,7 +442,7 @@ def _lik(muts, span, dt, mutation_rate, normalize=True): """ ll = scipy.stats.poisson.logpmf(muts, dt * mutation_rate * span) if normalize: - return ll - np.max(ll) + return ll - np.nanmax(ll) else: return ll @@ -634,11 +647,22 @@ def inside_pass(self, *, normalize=True, cache_inside=False, progress=None): inside = self.priors.clone_with_new_data( # store inside matrix values grid_data=np.nan, fixed_data=self.lik.identity_constant ) + # It is possible that a sample node is non-fixed, in which case we want to + # provide an inside array that reflects the prior distribution + nonfixed_samples = np.intersect1d(inside.nonfixed_node_ids(), self.ts.samples()) + for u in nonfixed_samples: + # this is in the same probability space as the prior, so we should be + # OK just to copy the prior values straight in.
It's unclear to me (Yan) + # how/if they should be normalised, however + inside[u][:] = self.priors[u] + if cache_inside: g_i = np.full( (self.ts.num_edges, self.lik.grid_size), self.lik.identity_constant ) norm = np.full(self.ts.num_nodes, np.nan) + to_visit = np.zeros(self.ts.num_nodes, dtype=bool) + to_visit[inside.nonfixed_node_ids()] = True # Iterate through the nodes via groupby on parent node for parent, edges in tqdm( self.edges_by_parent_asc(), @@ -673,14 +697,23 @@ def inside_pass(self, *, normalize=True, cache_inside=False, progress=None): "dangling nodes: please simplify it" ) daughter_val = self.lik.scale_geometric( - spanfrac, self.lik.make_lower_tri(inside[edge.child]) + spanfrac, self.lik.make_lower_tri(inside_values) ) edge_lik = self.lik.get_inside(daughter_val, edge) val = self.lik.combine(val, edge_lik) + if np.all(val == 0): + raise ValueError if cache_inside: g_i[edge.id] = edge_lik - norm[parent] = np.max(val) if normalize else 1 + norm[parent] = np.max(val) if normalize else self.lik.identity_constant inside[parent] = self.lik.reduce(val, norm[parent]) + to_visit[parent] = False + + # There may be nodes that are not parents but are also not fixed (e.g. + # undated sample nodes). These need an identity normalization constant + for unfixed_unvisited in np.where(to_visit)[0]: + norm[unfixed_unvisited] = self.lik.identity_constant + if cache_inside: self.g_i = self.lik.reduce(g_i, norm[self.ts.tables.edges.child, None]) # Keep the results in this object @@ -897,34 +930,32 @@ def posterior_mean_var(ts, posterior, *, fixed_node_set=None): return ts, mn_post, vr_post -def constrain_ages_topo(ts, post_mn, eps, nodes_to_date=None, progress=False): +def constrain_ages_topo(ts, node_times, eps, progress=False): """ - If predicted node times violate topology, restrict node ages so that they - must be older than all their children. 
+ If node_times violate topology, return increased node_times so that each node is + guaranteed to be older than any of its children. """ - new_mn_post = np.copy(post_mn) - if nodes_to_date is None: - nodes_to_date = np.arange(ts.num_nodes, dtype=np.uint64) - nodes_to_date = nodes_to_date[~np.isin(nodes_to_date, ts.samples())] - - tables = ts.tables - parents = tables.edges.parent - nd_children = tables.edges.child[np.argsort(parents)] - parents = sorted(parents) - parents_unique = np.unique(parents, return_index=True) - parent_indices = parents_unique[1][np.isin(parents_unique[0], nodes_to_date)] - for index, nd in tqdm( - enumerate(sorted(nodes_to_date)), desc="Constrain Ages", disable=not progress + edges_parent = ts.edges_parent + edges_child = ts.edges_child + + new_node_times = np.copy(node_times) + # Traverse through the ARG, ensuring children come before parents. + # This can be done by iterating over groups of edges with the same parent + new_parent_edge_idx = np.where(np.diff(edges_parent) != 0)[0] + 1 + for edges_start, edges_end in tqdm( + zip( + itertools.chain([0], new_parent_edge_idx), + itertools.chain(new_parent_edge_idx, [len(edges_parent)]), + ), + desc="Constrain Ages", + disable=not progress, ): - if index + 1 != len(nodes_to_date): - children_index = np.arange(parent_indices[index], parent_indices[index + 1]) - else: - children_index = np.arange(parent_indices[index], ts.num_edges) - children = nd_children[children_index] - time = np.max(new_mn_post[children]) - if new_mn_post[nd] <= time: - new_mn_post[nd] = time + eps - return new_mn_post + parent = edges_parent[edges_start] + child_ids = edges_child[edges_start:edges_end] # May contain dups + oldest_child_time = np.max(new_node_times[child_ids]) + if oldest_child_time >= new_node_times[parent]: + new_node_times[parent] = oldest_child_time + eps + return new_node_times def date( @@ -1015,7 +1046,7 @@ def date( progress=progress, **kwargs ) - constrained =
constrain_ages_topo(tree_sequence, dates, eps, nds, progress) + constrained = constrain_ages_topo(tree_sequence, dates, eps, progress) tables = tree_sequence.dump_tables() tables.time_units = time_units tables.nodes.time = constrained @@ -1064,12 +1095,6 @@ def get_dates( :return: tuple(mn_post, posterior, timepoints, eps, nodes_to_date) """ - # Stuff yet to be implemented. These can be deleted once fixed - for sample in tree_sequence.samples(): - if tree_sequence.node(sample).time != 0: - raise NotImplementedError("Samples must all be at time 0") - fixed_nodes = set(tree_sequence.samples()) - # Default to not creating approximate priors unless ts has > 1000 samples approx_priors = False if tree_sequence.num_samples > 1000: @@ -1097,6 +1122,8 @@ def get_dates( ) priors = priors + fixed_nodes = set(priors.fixed_node_ids()) + if probability_space != base.LOG: liklhd = Likelihoods( tree_sequence, diff --git a/tsdate/prior.py b/tsdate/prior.py index 2c5b04c7..e48d6247 100644 --- a/tsdate/prior.py +++ b/tsdate/prior.py @@ -23,6 +23,7 @@ """ Routines and classes for creating priors and timeslices for use in tsdate """ +import itertools import logging import os from collections import defaultdict @@ -426,10 +427,10 @@ def __init__(self, tree_sequence, *, progress=False, allow_unary=False): self.ts = tree_sequence self.sample_node_set = set(self.ts.samples()) - if np.any(self.ts.tables.nodes.time[self.ts.samples()] != 0): - raise ValueError( - "The SpansBySamples class needs a tree seq with all samples at time 0" - ) + # if np.any(self.ts.tables.nodes.time[self.ts.samples()] != 0): + # raise ValueError( + # "The SpansBySamples class needs a tree seq with all samples at time 0" + # ) self.progress = progress # We will store the spans in here, and normalize them at the end @@ -947,29 +948,59 @@ def gamma_cdf(t_set, alpha, beta): return np.insert(t_set, 0, 0) -def fill_priors(node_parameters, timepoints, ts, Ne, *, prior_distr, progress=False): +def fill_priors( + 
node_parameters, + timepoints, + ts, + Ne, + *, + prior_distr, + nonfixed_sample_var=None, + progress=False, +): """ Take the alpha and beta values from the node_parameters array, which contains - one row for each node in the TS (including fixed nodes) - and fill out a NodeGridValues object with the prior values from the - gamma or lognormal distribution with those parameters. + one row for each node in the TS (including fixed nodes, although alpha and beta + are ignored for these nodes) and fill out a NodeGridValues object with the prior + values from the gamma or lognormal distribution with those parameters. + + For a description of `nonfixed_sample_var`, see the parameter description in + the `build_grid` function. TODO - what if there is an internal fixed node? Should we truncate """ if prior_distr == "lognorm": cdf_func = scipy.stats.lognorm.cdf - main_param = np.sqrt(node_parameters[:, PriorParams.field_index("beta")]) + shape_param = np.sqrt(node_parameters[:, PriorParams.field_index("beta")]) scale_param = np.exp(node_parameters[:, PriorParams.field_index("alpha")]) + + def shape_scale_from_mean_var(mean, var): + a, b = lognorm_approx(mean, var) + return np.sqrt(b), np.exp(a) + elif prior_distr == "gamma": cdf_func = scipy.stats.gamma.cdf - main_param = node_parameters[:, PriorParams.field_index("alpha")] - scale_param = 1 / node_parameters[:, PriorParams.field_index("beta")] + shape_param = node_parameters[:, PriorParams.field_index("alpha")] + scale_param = 1.0 / node_parameters[:, PriorParams.field_index("beta")] + + def shape_scale_from_mean_var(mean, var): + a, b = gamma_approx(mean, var) + return a, 1.0 / b + else: raise ValueError("prior distribution must be lognorm or gamma") - + samples = ts.samples() + if nonfixed_sample_var is None: + nonfixed_sample_var = {} + for u in nonfixed_sample_var.keys(): + if u not in samples: + raise ValueError(f"Node {u} in 'nonfixed_sample_var' is not a sample") datable_nodes = np.ones(ts.num_nodes, dtype=bool) - 
datable_nodes[ts.samples()] = False + datable_nodes[samples] = False + # Mark all nodes in nonfixed_sample_var as datable + datable_nodes[list(nonfixed_sample_var.keys())] = True datable_nodes = np.where(datable_nodes)[0] + prior_times = base.NodeGridValues( ts.num_nodes, datable_nodes[np.argsort(ts.tables.nodes.time[datable_nodes])].astype(np.int32), @@ -980,8 +1011,16 @@ def fill_priors(node_parameters, timepoints, ts, Ne, *, prior_distr, progress=Fa for node in tqdm( datable_nodes, desc="Assign Prior to Each Node", disable=not progress ): + if node in nonfixed_sample_var: + shape, scale = shape_scale_from_mean_var( + mean=ts.node(node).time, + var=nonfixed_sample_var[node], + ) + else: + shape = shape_param[node] + scale = scale_param[node] with np.errstate(divide="ignore", invalid="ignore"): - prior_node = cdf_func(timepoints, main_param[node], scale=scale_param[node]) + prior_node = cdf_func(timepoints, shape, scale=scale) # force age to be less than max value prior_node = np.divide(prior_node, np.max(prior_node)) # prior in each epoch @@ -991,6 +1030,63 @@ def fill_priors(node_parameters, timepoints, ts, Ne, *, prior_distr, progress=Fa return prior_times +def _truncate_priors(ts, priors, progress=False): + """ + Truncate priors for all nonfixed nodes + so they conform to the age of fixed nodes in the tree sequence + """ + fixed_nodes = priors.fixed_node_ids() + fixed_times = ts.nodes_time[fixed_nodes] + + grid_data = np.copy(priors.grid_data[:]) + timepoints = priors.timepoints + if np.max(fixed_times) >= np.max(timepoints): + raise ValueError("Fixed node times cannot be older than the oldest timepoint") + if priors.probability_space == "linear": + zero_value = 0 + elif priors.probability_space == "logarithmic": + zero_value = -np.inf + constrained_min_times = np.zeros_like(ts.nodes_time) + # Set the min times of fixed nodes to those in the tree sequence + constrained_min_times[fixed_nodes] = fixed_times + + # Traverse through the ARG, ensuring children come 
before parents. + # This can be done by iterating over groups of edges with the same parent + edges_parent = ts.edges_parent + edges_child = ts.edges_child + new_parent_edge_idx = np.where(np.diff(edges_parent) != 0)[0] + 1 + for edges_start, edges_end in tqdm( + zip( + itertools.chain([0], new_parent_edge_idx), + itertools.chain(new_parent_edge_idx, [len(edges_parent)]), + ), + desc="Trunc priors", + disable=not progress, + ): + parent = edges_parent[edges_start] + child_ids = edges_child[edges_start:edges_end] # May contain dups + oldest_child_time = np.max(constrained_min_times[child_ids]) + if oldest_child_time > constrained_min_times[parent]: + if priors.is_fixed(parent): + raise ValueError( + "Invalid fixed times: time for" + + f"fixed node {parent} is younger than some of its descendants" + ) + constrained_min_times[parent] = oldest_child_time + if constrained_min_times[parent] > 0 and not priors.is_fixed(parent): + nearest_time = np.argmin(np.abs(timepoints - constrained_min_times[parent])) + grid_data[priors.row_lookup[parent]][:nearest_time] = zero_value + + rowmax = grid_data[:, 1:].max(axis=1) + if priors.probability_space == "linear": + grid_data = grid_data / rowmax[:, np.newaxis] + elif priors.probability_space == "logarithmic": + grid_data = grid_data - rowmax[:, np.newaxis] + + priors.grid_data[:] = grid_data + return priors + + def build_grid( tree_sequence, Ne, @@ -999,10 +1095,12 @@ def build_grid( approximate_priors=False, approx_prior_size=None, prior_distribution="lognorm", + allow_historical_samples=None, + truncate_priors=None, + nonfixed_sample_var=None, eps=1e-6, # Parameters below undocumented progress=False, - allow_unary=False, ): """ Using the conditional coalescent, calculate the prior distribution for the age of @@ -1022,17 +1120,34 @@ def build_grid( :param int approx_prior_size: Number of samples from which to precalculate prior. Should only enter value if approximate_priors=True. 
If approximate_priors=True and no value specified, defaults to 1000. Default: None - :param string prior_distr: What distribution to use to approximate the conditional - coalescent prior. Can be "lognorm" for the lognormal distribution (generally a - better fit, but slightly slower to calculate) or "gamma" for the gamma - distribution (slightly faster, but a poorer fit for recent nodes). Default: - "lognorm" + :param string prior_distribution: What distribution to use to approximate the + conditional coalescent prior. Can be "lognorm" for the lognormal distribution + (generally a better fit, but slightly slower to calculate) or "gamma" for the + gamma distribution (slightly faster, but a poorer fit for recent nodes). + Default: "lognorm" + :param bool allow_historical_samples: should we allow historical samples (i.e. at + times > 0). This invalidates the assumptions of the conditional coalescent, but + may be acceptable if the historical samples are recent or if there are many + contemporaneous samples. Default: ``False`` + :param bool truncate_priors: If there are historical samples, should we truncate the + priors of all nodes which are their ancestors so that the probability of being + younger than the oldest descendant sample is zero. As long as historical + samples do not have ancestors that have been misassigned in the tree sequence + topology, this should give better results. Default: ``True`` + :param dict nonfixed_sample_var: is a dict mapping sample node IDs to a variance + value. Any nodes listed here will be treated as non-fixed nodes whose prior is + not calculated from the conditional coalescent but instead are allocated a prior + whose mean is the node time in the tree sequence and whose variance is the + value in this dictionary. This allows sample nodes to be treated as nonfixed + nodes, and therefore dated. If ``None`` (default) then all sample nodes are + treated as occurring at a fixed time (as if this were an empty dict). 
:param float eps: Specify minimum distance separating points in the time grid. Also specifies the error factor in time difference calculations. Default: 1e-6 :return: A prior object to pass to tsdate.date() containing prior values for inference and a discretised time grid :rtype: base.NodeGridValues Object """ + if Ne <= 0: raise ValueError("Parameter 'Ne' must be greater than 0") if approximate_priors: @@ -1043,20 +1158,18 @@ def build_grid( raise ValueError( "Can't set approx_prior_size if approximate_prior is False" ) + if truncate_priors is None: + truncate_priors = True + if allow_historical_samples is None: + allow_historical_samples = False - contmpr_ts, node_map = util.reduce_to_contemporaneous(tree_sequence) - if contmpr_ts.num_nodes != tree_sequence.num_nodes: - raise ValueError( - "Passed tree sequence is not simplified and/or contains " - "noncontemporaneous samples" - ) - span_data = SpansBySamples(contmpr_ts, progress=progress, allow_unary=allow_unary) + span_data = SpansBySamples(tree_sequence, progress=progress) base_priors = ConditionalCoalescentTimes( approx_prior_size, Ne, prior_distribution, progress=progress ) - base_priors.add(contmpr_ts.num_samples, approximate_priors) + base_priors.add(tree_sequence.num_samples, approximate_priors) for total_fixed in span_data.total_fixed_at_0_counts: # For missing data: trees vary in total fixed node count => have different priors if total_fixed > 0: @@ -1080,9 +1193,7 @@ def build_grid( else: raise ValueError("time_slices must be an integer or a numpy array of floats") - prior_params_contmpr = base_priors.get_mixture_prior_params(span_data) - # Map the nodes in the prior params back to the node ids in the original ts - prior_params = prior_params_contmpr[node_map, :] + prior_params = base_priors.get_mixture_prior_params(span_data) # Set all fixed nodes (i.e. 
samples) to have 0 variance priors = fill_priors( prior_params, @@ -1090,6 +1201,20 @@ def build_grid( tree_sequence, Ne, prior_distr=prior_distribution, + nonfixed_sample_var=nonfixed_sample_var, progress=progress, ) + if np.any(tree_sequence.nodes_time[tree_sequence.samples()] > 0): + if not allow_historical_samples: + raise ValueError( + "There are samples at non-zero times, invalidating the conditional " + "coalescent prior. You can set allow_historical_samples=True to carry " + "on regardless, calculating a prior as if all samples were " + "contemporaneous (reasonable if you only have a few ancient samples)" + ) + if ( + np.any(tree_sequence.nodes_time[priors.fixed_node_ids()] > 0) + and truncate_priors + ): + priors = _truncate_priors(tree_sequence, priors, progress=progress) return priors