From ced5d8b771dd32ccb7eb96d73fc8563033089e81 Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Tue, 8 Aug 2023 11:57:15 -0700 Subject: [PATCH] Mixture prior works More streamlined numerical checks Initialize gamma mixture from conditional coalescent prior Add pdf Update mixture.py to use natural parameterization WIP Moved fully into numba Cleanup Cleanup More debugging WIP Working wording Add missing constant to loglikelihood Skip prior update completely instead of components Skip prior update completely instead of components Remove verbose; use logweights in conditional posterior Move mixture initialization to function Docstrings and CLI Remove some debugging inserts Remove preemptive reference Fix tests --- tests/test_functions.py | 7 -- tests/test_inference.py | 21 ++-- tsdate/cli.py | 31 ++++++ tsdate/core.py | 202 +++++++++++++++++++++++---------- tsdate/mixture.py | 240 ++++++++++++++++++++++++++++++++++++++++ tsdate/prior.py | 2 + 6 files changed, 433 insertions(+), 70 deletions(-) create mode 100644 tsdate/mixture.py diff --git a/tests/test_functions.py b/tests/test_functions.py index 01280a0c..055e4436 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -799,13 +799,6 @@ def test_variational_nosize(self): with pytest.raises(ValueError, match="Must specify population size"): variational_dates(ts, mutation_rate=1) - def test_variational_toomanysizes(self): - ts = utility_functions.two_tree_mutation_ts() - Ne = 1 - priors = tsdate.build_prior_grid(ts, Ne, np.array([0, 1.2, 2])) - with pytest.raises(ValueError, match="Cannot specify"): - variational_dates(ts, mutation_rate=1, population_size=Ne, priors=priors) - class TestNodeGridValuesClass: def test_init(self): diff --git a/tests/test_inference.py b/tests/test_inference.py index 2d31a963..017b473e 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -419,13 +419,13 @@ def test_nonglobal_priors(self): priors = tsdate.prior.MixturePrior(ts, prior_distribution="gamma") grid = priors.make_parameter_grid(population_size=1) grid.grid_data[:] = [1.0, 0.0] # noninformative prior - tsdate.date( - ts, - mutation_rate=5, - method="variational_gamma", - priors=grid, - global_prior=False, - ) + with pytest.raises(ValueError, match="not yet implemented"): + tsdate.date( + ts, + mutation_rate=5, + method="variational_gamma", + priors=grid, + ) def test_bad_arguments(self): ts = utility_functions.two_tree_mutation_ts() @@ -437,6 +437,13 @@ def test_bad_arguments(self): method="variational_gamma", max_iterations=-1, ) + with pytest.raises(ValueError, match="must be a positive integer"): + tsdate.date( + ts, + mutation_rate=5, + method="variational_gamma", + global_prior=False, + ) def test_match_central_moments(self): ts = msprime.simulate(8, mutation_rate=5, recombination_rate=5, random_seed=2) diff --git a/tsdate/cli.py b/tsdate/cli.py index e475a53a..76a9eeca 100644 --- a/tsdate/cli.py +++ b/tsdate/cli.py @@ -199,6 +199,34 @@ def tsdate_cli_parser(): "but does not exactly minimize KL divergence in each EP update." ), ) + parser.add_argument( + "--max-iterations", + type=int, + help=( + "The number of iterations used in the expectation propagation " + "algorithm. Default: 20" + ), + default=20, + ) + parser.add_argument( + "--em-iterations", + type=int, + help=( + "The number of expectation-maximization iterations used to optimize the " + "global mixture prior at the end of each expectation propagation step. " + "Setting to zero disables optimization. Default: 10" + ), + default=10, + ) + parser.add_argument( + "--global-prior", + type=int, + help=( + "The number of components in the i.i.d. mixture prior for node " + "ages. Default: 1" + ), + default=1, + ) parser.set_defaults(runner=run_date) parser = subparsers.add_parser( @@ -253,8 +281,11 @@ def run_date(args): method=args.method, eps=args.epsilon, progress=args.progress, + max_iterations=args.max_iterations, max_shape=args.max_shape, match_central_moments=args.match_central_moments, + em_iterations=args.em_iterations, + global_prior=args.global_prior, ) else: params = dict( diff --git a/tsdate/core.py b/tsdate/core.py index 89d47bce..ae529981 100644 --- a/tsdate/core.py +++ b/tsdate/core.py @@ -40,6 +40,7 @@ from . import approx from . import base from . import demography +from . import mixture from . import prior from . import provenance @@ -953,7 +954,7 @@ class ExpectationPropagation(InOutAlgorithms): Bayesian Inference" """ - def __init__(self, *args, **kwargs): + def __init__(self, *args, global_prior, **kwargs): super().__init__(*args, **kwargs) assert self.priors.probability_space == base.GAMMA_PAR @@ -961,24 +962,29 @@ def __init__(self, *args, **kwargs): assert self.lik.grid_size == 2 assert self.priors.timepoints.size == 2 + # global distribution of node ages + self.global_prior = global_prior.copy() + + # messages passed from prior to nodes + self.prior_messages = np.zeros((self.ts.num_nodes, 2)) + # mutation likelihoods, as gamma natural parameters self.likelihoods = np.zeros((self.ts.num_edges, 2)) for e in self.ts.edges(): self.likelihoods[e.id] = self.lik.to_natural(e) - # messages passed from factors to nodes + # messages passed from edge likelihoods to nodes self.messages = np.zeros((self.ts.num_edges, 2, 2)) - # normalizing constants from each factor + # normalizing constants from each edge likelihood self.log_partition = np.zeros(self.ts.num_edges) # the approximate posterior marginals self.posterior = np.zeros((self.ts.num_nodes, 2)) - for n in self.priors.nonfixed_nodes: - self.posterior[n] = self.priors[n] - # edge traversal order + # edge, node traversal order self.edges, self.leaves = self.factorize(self.ts.edges(), self.fixednodes) + self.freenodes = self.priors.nonfixed_nodes # factors for edges leading from fixed nodes are invariant # and can be incorporated into the posterior beforehand @@ -1012,7 +1018,7 @@ def factorize(edge_list, fixed_nodes): @staticmethod @numba.njit("f8(i4[:, :], f8[:, :], f8[:, :], f8[:, :, :], f8[:], f8[:], f8, b1)") - def propagate( + def propagate_likelihood( edges, likelihoods, posterior, @@ -1044,13 +1050,20 @@ def propagate( assert max_shape >= 1.0 - upper = max_shape - 1.0 - lower = 1.0 / max_shape - 1.0 + # Bound the shape parameter for the posterior and cavity distributions + # so that lower_cavi < lower_post < upper_post < upper_cavi. + upper_post = max_shape - 1.0 + lower_post = 1.0 / max_shape - 1.0 + upper_cavi = 2.0 * max_shape - 1.0 + lower_cavi = 0.5 / max_shape - 1.0 def cavity_damping(x, y): + assert upper_cavi > x[0] > lower_cavi d = 1.0 - if (y[0] > 0.0) and (x[0] - y[0] < lower): - d = min(d, (x[0] - lower) / y[0]) + if (y[0] > 0.0) and (x[0] - y[0] < lower_cavi): + d = min(d, (x[0] - lower_cavi) / y[0]) + if (y[0] < 0.0) and (x[0] - y[0] > upper_cavi): + d = min(d, (x[0] - upper_cavi) / y[0]) if (y[1] > 0.0) and (x[1] - y[1] < 0.0): d = min(d, x[1] / y[1]) assert 0.0 < d <= 1.0 @@ -1058,7 +1071,11 @@ def cavity_damping(x, y): def posterior_damping(x): assert x[0] > -1.0 and x[1] > 0.0 - d = min(1.0, upper / abs(x[0])) if (x[0] > 0) else 1.0 + d = 1.0 + if x[0] > upper_post: + d = upper_post / x[0] + if x[0] < lower_post: + d = lower_post / x[0] assert 0.0 < d <= 1.0 return d @@ -1097,13 +1114,81 @@ def posterior_damping(x): return 0.0 # TODO, placeholder - def iterate(self, max_shape=1000, min_kl=True): + @staticmethod + @numba.njit("f8(i4[:], f8[:, :], f8[:, :], f8[:, :], f8[:], f8, i4, f8)") + def propagate_prior( + nodes, global_prior, posterior, messages, scale, max_shape, em_maxitt, em_reltol + ): + """TODO + + :param ndarray nodes: ids of nodes that should be updated + :param ndarray global_prior: rows are mixture components, columns are + zeroth, first, and second natural parameters of gamma mixture + components. Updated in place. + :param ndarray posterior: rows are nodes, columns are first and + second natural parameters of gamma posteriors. Updated in + place. + :param ndarray messages: rows are edges, columns are first and + second natural parameters of prior messages. Updated in place. + :param float max_shape: the maximum allowed shape for node posteriors + :param int em_maxitt: the maximum number of EM iterations to use when + fitting the mixture model + :param int em_reltol: the termination criterion for relative change in + log-likelihood + """ + + if global_prior.shape[0] == 0: + return 0.0 + + assert max_shape >= 1.0 + + upper = max_shape - 1.0 + lower = 1.0 / max_shape - 1.0 + + def posterior_damping(x): + assert x[0] > -1.0 and x[1] > 0.0 + d = 1.0 + if x[0] > upper: + d = upper / x[0] + if x[0] < lower: + d = lower / x[0] + assert 0.0 < d <= 1.0 + return d + + cavity = np.zeros(posterior.shape) + cavity[nodes] = posterior[nodes] - messages[nodes] * scale[nodes, np.newaxis] + global_prior, posterior[nodes] = mixture.fit_gamma_mixture( + global_prior, cavity[nodes], em_maxitt, em_reltol, False + ) + messages[nodes] = (posterior[nodes] - cavity[nodes]) / scale[nodes, np.newaxis] + + for n in nodes: + eta = posterior_damping(posterior[n]) + posterior[n] *= eta + scale[n] *= eta + + return 0.0 + + def iterate(self, em_maxitt=100, em_reltol=1e-6, max_shape=1000, min_kl=True): """ Update edge factors from leaves to root then from root to leaves, and return approximate log marginal likelihood (TODO) """ - self.propagate( + # prior update + self.propagate_prior( + self.freenodes, + self.global_prior, + self.posterior, + self.prior_messages, + self.scale, + max_shape, + em_maxitt, + em_reltol, + ) + + # rootward pass + self.propagate_likelihood( self.edges, self.likelihoods, self.posterior, @@ -1113,7 +1198,9 @@ def iterate(self, max_shape=1000, min_kl=True): max_shape, min_kl, ) - self.propagate( + + # leafward pass + self.propagate_likelihood( self.edges[::-1], self.likelihoods, self.posterior, @@ -1381,7 +1468,8 @@ def variational_dates( max_iterations=20, max_shape=1000, match_central_moments=False, - global_prior=True, + global_prior=1, + em_iterations=10, ): """ Infer dates for the nodes in a tree sequence using expectation propagation, @@ -1390,6 +1478,11 @@ def variational_dates( which invokes this method and inserts the resulting node ages into the tree sequence. + An i.i.d. gamma mixture is used as a prior for each node, that is + initialized from the conditional coalescent and updated via expectation + maximization in each iteration. In addition, node-specific priors may be + specified via a grid of shape/rate parameters. + :param ~tskit.TreeSequence tree_sequence: See :func:`date`. :param float mutation_rate: See :func:`date`. :param float population_size: See :func:`date`. @@ -1397,10 +1490,8 @@ def variational_dates( :param bool progress: See :func:`date`. :param ~tsdate.base.NodeGridValues priors: the prior parameters for each node-to-be-dated, assuming a gamma prior on node - age and using shape/rate parameterization. If ``None`` (default), use - an iid prior derived from the conditional coalescent prior, tilted - according to population size, and assume the nodes to be dated are all - the non-sample nodes in the input tree sequence. + age and using shape/rate parameterization. If ``None`` (default), node + specific priors are omitted and only a global mixture prior is used. :param int max_iterations: The number of iterations used in the expectation propagation algorithm. Default: 20. :param float max_shape: The maximum value for the shape parameter in the variational @@ -1410,9 +1501,11 @@ def variational_dates( update matches mean and variance rather than expected gamma sufficient statistics. Faster with a similar accuracy, but does not exactly minimize Kullback-Leibler divergence. Default: False. - :param bool global_prior: If `True`, an iid prior is used for all nodes, - and is constructed by averaging gamma sufficient statistics over the free - nodes in `priors`. Default: True. + :param int global_prior: The number of components in the i.i.d. mixture prior + for node ages. Default: 1. + :param int em_iterations: The number of expectation maximization iterations used + to optimize the global mixture prior. Setting to zero disables optimization. + Default: 10. :return: a tuple ``(mn_post, va_post, posteriors, nodes_to_date)``, where: ``mn_post`` (:class:`~numpy.ndarray`) and ``va_post`` (:class:`~numpy.ndarray`) @@ -1436,45 +1529,37 @@ def variational_dates( if mutation_rate is None: raise ValueError("Variational gamma method requires mutation rate") + if not (isinstance(global_prior, int) and global_prior > 0): + raise ValueError("'global_prior' must be a positive integer") + # Default to not creating approximate priors unless ts has # greater than DEFAULT_APPROX_PRIOR_SIZE samples approx_priors = False if tree_sequence.num_samples > base.DEFAULT_APPROX_PRIOR_SIZE: approx_priors = True - if priors is None: - if population_size is None: - raise ValueError( - "Must specify population size if priors are not already " - "built using tsdate.build_parameter_grid()" - ) - priors = prior.parameter_grid( - tree_sequence, - population_size=population_size, - progress=progress, - approximate_priors=approx_priors, + # TODO: support additional node-specific priors + if priors is not None: + raise ValueError( + "Node-specific priors not yet implemented with method 'variational_gamma'" ) - else: - logging.info("Using user-specified priors") - if population_size is not None: - raise ValueError( - "Cannot specify population size in tsdate.date() or " - "tsdate.variational_dates() if specifying priors from " - "tsdate.build_parameter_grid()" - ) - priors = priors - # convert priors to natural parameterization and average - for n in priors.nonfixed_nodes: - priors[n][0] -= 1.0 - assert priors[n][0] > -1.0 - assert priors[n][1] >= 0.0 - if global_prior: - logging.info("Pooling node-specific priors into global prior") - priors.grid_data[:] = approx.average_gammas( - priors.grid_data[:, 0], priors.grid_data[:, 1] + # initialize global prior from coalescent + if population_size is None: + raise ValueError( + "Must specify population size if priors are not already " + "built using tsdate.build_parameter_grid()" ) + priors = prior.parameter_grid( + tree_sequence, + population_size=population_size, + progress=progress, + approximate_priors=approx_priors, + ) + prior_mixture = mixture.initialize_mixture(priors.grid_data, global_prior) + priors.grid_data[:] = [0.0, 0.0] + # calculate mutation likelihoods per edge liklhd = VariationalLikelihoods( tree_sequence, mutation_rate, @@ -1485,14 +1570,19 @@ def variational_dates( # match sufficient statistics or match central moments min_kl = not match_central_moments - dynamic_prog = ExpectationPropagation(priors, liklhd, progress=progress) - for _ in tqdm( + dynamic_prog = ExpectationPropagation( + priors, liklhd, progress=progress, global_prior=prior_mixture + ) + for itt in tqdm( np.arange(max_iterations), desc="Expectation Propagation", disable=not progress, ): - dynamic_prog.iterate(max_shape=max_shape, min_kl=min_kl) - + dynamic_prog.iterate( + em_maxitt=em_iterations if itt else 0, + max_shape=max_shape, + min_kl=min_kl, + ) num_skipped = np.sum(np.isnan(dynamic_prog.log_partition)) if num_skipped > 0: logging.info(f"Skipped {num_skipped} messages with invalid posterior updates.") diff --git a/tsdate/mixture.py b/tsdate/mixture.py new file mode 100644 index 00000000..51538b9f --- /dev/null +++ b/tsdate/mixture.py @@ -0,0 +1,240 @@ +# MIT License +# +# Copyright (c) 2021-23 Tskit Developers +# Copyright (c) 2020-21 University of Oxford +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" +Mixture of gamma distributions that may be fit via EM to distribution-valued observations +""" +import numba +import numpy as np + +from . import approx +from . import hypergeo + + +@numba.njit("UniTuple(f8[:], 4)(f8[:], f8[:], f8[:], f8, f8)") +def _conditional_posterior(prior_logweight, prior_alpha, prior_beta, alpha, beta): + r""" + Return expectations of node age :math:`t` from the mixture model, + + ..math:: + + Ga(t | a, b) \sum_j \pi_j w_j Ga(t | \alpha_j, \beta_j) + + where :math:`a` and :math:`b` are variational parameters, + and :math:`\pi_j, \alpha_j, \beta_j` are prior weights and + parameters for a gamma mixture; and :math:`w_j` are fixed, + observation-specific weights. We use natural parameterization, + so that the shape parameter is :math:`\alpha + 1`. + + TODO: + The normalizing constants of the prior are assumed to have already + been integrated into `prior_weight`. + + Returns the contribution from each component to the + posterior expectations of :math:`E[1]`, :math:`E[t]`, :math:`E[log t]`, + and :math:`E[t log t]`. + + Note that :math:`E[1]` is *unnormalized* and *log-transformed*. + """ + + dim = prior_logweight.size + E = np.full(dim, -np.inf) # E[1] (e.g. normalizing constant) + E_t = np.zeros(dim) # E[t] + E_logt = np.zeros(dim) # E[log(t)] + E_tlogt = np.zeros(dim) # E[t * log(t)] + C = (alpha + 1) * np.log(beta) - hypergeo._gammaln(alpha + 1) if beta > 0 else 0.0 + for i in range(dim): + post_alpha = prior_alpha[i] + alpha + post_beta = prior_beta[i] + beta + if (post_alpha <= -1) or (post_beta <= 0): # skip node if divergent + E[:] = -np.inf + break + E[i] = C + ( + +hypergeo._gammaln(post_alpha + 1) + - (post_alpha + 1) * np.log(post_beta) + + prior_logweight[i] + ) + assert np.isfinite(E[i]) + # TODO: option to use moments instead of sufficient statistics? + E_t[i] = (post_alpha + 1) / post_beta + E_logt[i] = hypergeo._digamma(post_alpha + 1) - np.log(post_beta) + E_tlogt[i] = E_t[i] * E_logt[i] + E_t[i] / (post_alpha + 1) + + return E, E_t, E_logt, E_tlogt + + +@numba.njit("f8(f8[:], f8[:], f8[:], f8[:], f8[:])") +def _em_update(prior_weight, prior_alpha, prior_beta, alpha, beta): + """ + Perform an expectation maximization step for parameters of mixture components, + given variational parameters `alpha`, `beta` for each node. + + The maximization step is performed using Ye & Chen (2017) "Closed form + estimators for the gamma distribution ..." + + ``prior_weight``, ``prior_alpha``, ``prior_beta`` are updated in place. + """ + assert alpha.size == beta.size + + dim = prior_weight.size + n = np.zeros(dim) + t = np.zeros(dim) + logt = np.zeros(dim) + tlogt = np.zeros(dim) + + # incorporate prior normalizing constants into weights + prior_logweight = np.log(prior_weight) + for k in range(dim): + prior_logweight[k] += (prior_alpha[k] + 1) * np.log( + prior_beta[k] + ) - hypergeo._gammaln(prior_alpha[k] + 1) + + # expectation step: + loglik = 0.0 + for a, b in zip(alpha, beta): + E, E_t, E_logt, E_tlogt = _conditional_posterior( + prior_logweight, prior_alpha, prior_beta, a, b + ) + + # skip if posterior is improper + if np.any(np.isinf(E)): + continue + + # convert evidence to posterior weights + norm_const = np.log(np.sum(np.exp(E - np.max(E)))) + np.max(E) + weight = np.exp(E - norm_const) + + # weighted contributions to sufficient statistics + loglik += norm_const + n += weight + t += E_t * weight + logt += E_logt * weight + tlogt += E_tlogt * weight + + # maximization step: update parameters in place + prior_weight[:] = n / np.sum(n) + prior_beta[:] = n**2 / (n * tlogt - t * logt) + prior_alpha[:] = n * t / (n * tlogt - t * logt) - 1.0 + + return loglik + + +@numba.njit("f8[:](f8[:], f8[:], f8[:], f8[:], f8[:])") +def _gamma_projection(prior_weight, prior_alpha, prior_beta, alpha, beta): + """ + Given variational approximation to posterior: multiply by exact prior, + calculate sufficient statistics, and moment match to get new + approximate posterior. + + Updates ``alpha`` and ``beta`` in-place. + """ + assert alpha.size == beta.size + + dim = prior_weight.size + + # incorporate prior normalizing constants into weights + prior_logweight = np.log(prior_weight) + for k in range(dim): + prior_logweight[k] += (prior_alpha[k] + 1) * np.log( + prior_beta[k] + ) - hypergeo._gammaln(prior_alpha[k] + 1) + + log_const = np.full(alpha.size, -np.inf) + for i in range(alpha.size): + E, E_t, E_logt, E_tlogt = _conditional_posterior( + prior_logweight, prior_alpha, prior_beta, alpha[i], beta[i] + ) + + # skip if posterior is improper for all components + if np.any(np.isinf(E)): + continue + + norm = np.log(np.sum(np.exp(E - np.max(E)))) + np.max(E) + weight = np.exp(E - norm) + t = np.sum(weight * E_t) + logt = np.sum(weight * E_logt) + # tlogt = np.sum(weight * E_tlogt) + log_const[i] = norm + alpha[i], beta[i] = approx.approximate_gamma_kl(t, logt) + # beta[i] = 1 / (tlogt - t * logt) + # alpha[i] = t * beta[i] - 1.0 + + return log_const + + +@numba.njit("Tuple((f8[:,:], f8[:,:]))(f8[:,:], f8[:,:], i4, f8, b1)") +def fit_gamma_mixture(mixture, observations, max_iterations, tolerance, verbose): + """ + Run EM until relative tolerance or maximum number of iterations is + reached. Then, perform expectation-propagation update and return new + variational parameters for the posterior approximation. + """ + + assert mixture.shape[1] == 3 + assert observations.shape[1] == 2 + + mix_weight, mix_alpha, mix_beta = mixture.T + alpha, beta = observations.T + + last = np.inf + for itt in range(max_iterations): + loglik = _em_update(mix_weight, mix_alpha, mix_beta, alpha, beta) + loglik /= float(alpha.size) + update = np.abs(loglik - last) + last = loglik + if verbose: + print("EM iteration:", itt, "mean(loglik):", np.round(loglik, 5)) + print(" -> weights:", mix_weight) + print(" -> alpha:", mix_alpha) + print(" -> beta:", mix_beta) + if update < np.abs(loglik) * tolerance: + break + + # conditional posteriors for each observation + # log_const = _gamma_projection(mix_weight, mix_alpha, mix_beta, alpha, beta) + _gamma_projection(mix_weight, mix_alpha, mix_beta, alpha, beta) + + new_mixture = np.zeros(mixture.shape) + new_observations = np.zeros(observations.shape) + new_observations[:, 0] = alpha + new_observations[:, 1] = beta + new_mixture[:, 0] = mix_weight + new_mixture[:, 1] = mix_alpha + new_mixture[:, 2] = mix_beta + + return new_mixture, new_observations + + +def initialize_mixture(parameters, num_components): + """initialize clusters by dividing nodes into equal groups""" + global_prior = np.empty((num_components, 3)) + num_nodes = parameters.shape[0] + age_classes = np.tile(np.arange(num_components), num_nodes // num_components + 1)[ + :num_nodes + ] + for k in range(num_components): + indices = np.equal(age_classes, k) + alpha, beta = approx.average_gammas( + parameters[indices, 0] - 1.0, parameters[indices, 1] + ) + global_prior[k] = [1.0 / num_components, alpha, beta] + return global_prior diff --git a/tsdate/prior.py b/tsdate/prior.py index a842a1ca..a7eda75e 100644 --- a/tsdate/prior.py +++ b/tsdate/prior.py @@ -30,6 +30,8 @@ import numba import numpy as np +import scipy.cluster +import scipy.special import scipy.stats import tskit from tqdm.auto import tqdm