diff --git a/tests/test_functions.py b/tests/test_functions.py
index 01280a0c..055e4436 100644
--- a/tests/test_functions.py
+++ b/tests/test_functions.py
@@ -799,13 +799,6 @@ def test_variational_nosize(self):
         with pytest.raises(ValueError, match="Must specify population size"):
             variational_dates(ts, mutation_rate=1)
 
-    def test_variational_toomanysizes(self):
-        ts = utility_functions.two_tree_mutation_ts()
-        Ne = 1
-        priors = tsdate.build_prior_grid(ts, Ne, np.array([0, 1.2, 2]))
-        with pytest.raises(ValueError, match="Cannot specify"):
-            variational_dates(ts, mutation_rate=1, population_size=Ne, priors=priors)
-
 
 class TestNodeGridValuesClass:
     def test_init(self):
diff --git a/tests/test_inference.py b/tests/test_inference.py
index 2d31a963..017b473e 100644
--- a/tests/test_inference.py
+++ b/tests/test_inference.py
@@ -419,13 +419,13 @@ def test_nonglobal_priors(self):
         priors = tsdate.prior.MixturePrior(ts, prior_distribution="gamma")
         grid = priors.make_parameter_grid(population_size=1)
         grid.grid_data[:] = [1.0, 0.0]  # noninformative prior
-        tsdate.date(
-            ts,
-            mutation_rate=5,
-            method="variational_gamma",
-            priors=grid,
-            global_prior=False,
-        )
+        with pytest.raises(ValueError, match="not yet implemented"):
+            tsdate.date(
+                ts,
+                mutation_rate=5,
+                method="variational_gamma",
+                priors=grid,
+            )
 
     def test_bad_arguments(self):
         ts = utility_functions.two_tree_mutation_ts()
@@ -437,6 +437,13 @@
                 method="variational_gamma",
                 max_iterations=-1,
             )
+        with pytest.raises(ValueError, match="must be a positive integer"):
+            tsdate.date(
+                ts,
+                mutation_rate=5,
+                method="variational_gamma",
+                global_prior=False,
+            )
 
     def test_match_central_moments(self):
         ts = msprime.simulate(8, mutation_rate=5, recombination_rate=5, random_seed=2)
diff --git a/tsdate/cli.py b/tsdate/cli.py
index e475a53a..76a9eeca 100644
--- a/tsdate/cli.py
+++ b/tsdate/cli.py
@@ -199,6 +199,34 @@ def tsdate_cli_parser():
             "but does not exactly minimize KL divergence in each EP update."
         ),
     )
+    parser.add_argument(
+        "--max-iterations",
+        type=int,
+        help=(
+            "The number of iterations used in the expectation propagation "
+            "algorithm. Default: 20"
+        ),
+        default=20,
+    )
+    parser.add_argument(
+        "--em-iterations",
+        type=int,
+        help=(
+            "The number of expectation-maximization iterations used to optimize the "
+            "global mixture prior at the end of each expectation propagation step. "
+            "Setting to zero disables optimization. Default: 10"
+        ),
+        default=10,
+    )
+    parser.add_argument(
+        "--global-prior",
+        type=int,
+        help=(
+            "The number of components in the i.i.d. mixture prior for node "
+            "ages. Default: 1"
+        ),
+        default=1,
+    )
     parser.set_defaults(runner=run_date)
 
     parser = subparsers.add_parser(
@@ -253,8 +281,11 @@ def run_date(args):
             method=args.method,
             eps=args.epsilon,
             progress=args.progress,
+            max_iterations=args.max_iterations,
             max_shape=args.max_shape,
             match_central_moments=args.match_central_moments,
+            em_iterations=args.em_iterations,
+            global_prior=args.global_prior,
         )
     else:
         params = dict(
diff --git a/tsdate/core.py b/tsdate/core.py
index 89d47bce..ae529981 100644
--- a/tsdate/core.py
+++ b/tsdate/core.py
@@ -40,6 +40,7 @@
 from . import approx
 from . import base
 from . import demography
+from . import mixture
 from . import prior
 from . import provenance
@@ -953,7 +954,7 @@ class ExpectationPropagation(InOutAlgorithms):
     Bayesian Inference"
     """
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, global_prior, **kwargs):
         super().__init__(*args, **kwargs)
 
         assert self.priors.probability_space == base.GAMMA_PAR
@@ -961,24 +962,29 @@ def __init__(self, *args, global_prior, **kwargs):
         assert self.lik.grid_size == 2
         assert self.priors.timepoints.size == 2
 
+        # global distribution of node ages
+        self.global_prior = global_prior.copy()
+
+        # messages passed from prior to nodes
+        self.prior_messages = np.zeros((self.ts.num_nodes, 2))
+
         # mutation likelihoods, as gamma natural parameters
         self.likelihoods = np.zeros((self.ts.num_edges, 2))
         for e in self.ts.edges():
             self.likelihoods[e.id] = self.lik.to_natural(e)
 
-        # messages passed from factors to nodes
+        # messages passed from edge likelihoods to nodes
         self.messages = np.zeros((self.ts.num_edges, 2, 2))
 
-        # normalizing constants from each factor
+        # normalizing constants from each edge likelihood
         self.log_partition = np.zeros(self.ts.num_edges)
 
         # the approximate posterior marginals
         self.posterior = np.zeros((self.ts.num_nodes, 2))
-        for n in self.priors.nonfixed_nodes:
-            self.posterior[n] = self.priors[n]
 
-        # edge traversal order
+        # edge, node traversal order
         self.edges, self.leaves = self.factorize(self.ts.edges(), self.fixednodes)
+        self.freenodes = self.priors.nonfixed_nodes
 
         # factors for edges leading from fixed nodes are invariant
         # and can be incorporated into the posterior beforehand
@@ -1012,7 +1018,7 @@ def factorize(edge_list, fixed_nodes):
 
     @staticmethod
    @numba.njit("f8(i4[:, :], f8[:, :], f8[:, :], f8[:, :, :], f8[:], f8[:], f8, b1)")
-    def propagate(
+    def propagate_likelihood(
        edges,
        likelihoods,
        posterior,
@@ -1044,13 +1050,20 @@
 
         assert max_shape >= 1.0
 
-        upper = max_shape - 1.0
-        lower = 1.0 / max_shape - 1.0
+        # Bound the shape parameter for the posterior and cavity distributions
+        # so that lower_cavi < lower_post < upper_post < upper_cavi.
+        upper_post = max_shape - 1.0
+        lower_post = 1.0 / max_shape - 1.0
+        upper_cavi = 2.0 * max_shape - 1.0
+        lower_cavi = 0.5 / max_shape - 1.0
 
         def cavity_damping(x, y):
+            assert upper_cavi > x[0] > lower_cavi
             d = 1.0
-            if (y[0] > 0.0) and (x[0] - y[0] < lower):
-                d = min(d, (x[0] - lower) / y[0])
+            if (y[0] > 0.0) and (x[0] - y[0] < lower_cavi):
+                d = min(d, (x[0] - lower_cavi) / y[0])
+            if (y[0] < 0.0) and (x[0] - y[0] > upper_cavi):
+                d = min(d, (x[0] - upper_cavi) / y[0])
             if (y[1] > 0.0) and (x[1] - y[1] < 0.0):
                 d = min(d, x[1] / y[1])
             assert 0.0 < d <= 1.0
@@ -1058,7 +1071,11 @@ def cavity_damping(x, y):
 
         def posterior_damping(x):
             assert x[0] > -1.0 and x[1] > 0.0
-            d = min(1.0, upper / abs(x[0])) if (x[0] > 0) else 1.0
+            d = 1.0
+            if x[0] > upper_post:
+                d = upper_post / x[0]
+            if x[0] < lower_post:
+                d = lower_post / x[0]
             assert 0.0 < d <= 1.0
             return d
@@ -1097,13 +1114,81 @@ def posterior_damping(x):
 
         return 0.0  # TODO, placeholder
 
-    def iterate(self, max_shape=1000, min_kl=True):
+    @staticmethod
+    @numba.njit("f8(i4[:], f8[:, :], f8[:, :], f8[:, :], f8[:], f8, i4, f8)")
+    def propagate_prior(
+        nodes, global_prior, posterior, messages, scale, max_shape, em_maxitt, em_reltol
+    ):
+        """Update free node posteriors with messages from the global mixture prior.
+
+        :param ndarray nodes: ids of nodes that should be updated
+        :param ndarray global_prior: rows are mixture components, columns are
+            zeroth, first, and second natural parameters of gamma mixture
+            components. Updated in place.
+        :param ndarray posterior: rows are nodes, columns are first and
+            second natural parameters of gamma posteriors. Updated in
+            place.
+        :param ndarray messages: rows are nodes, columns are first and
+            second natural parameters of prior messages. Updated in place.
+        :param ndarray scale: per-node scaling factors applied to the
+            messages. Updated in place.
+        :param float max_shape: the maximum allowed shape for node posteriors
+        :param int em_maxitt: the maximum number of EM iterations to use when
+            fitting the mixture model
+        :param float em_reltol: the termination criterion for relative change in
+            log-likelihood
+        """
+
+        if global_prior.shape[0] == 0:
+            return 0.0
+
+        assert max_shape >= 1.0
+
+        upper = max_shape - 1.0
+        lower = 1.0 / max_shape - 1.0
+
+        def posterior_damping(x):
+            assert x[0] > -1.0 and x[1] > 0.0
+            d = 1.0
+            if x[0] > upper:
+                d = upper / x[0]
+            if x[0] < lower:
+                d = lower / x[0]
+            assert 0.0 < d <= 1.0
+            return d
+
+        cavity = np.zeros(posterior.shape)
+        cavity[nodes] = posterior[nodes] - messages[nodes] * scale[nodes, np.newaxis]
+        global_prior, posterior[nodes] = mixture.fit_gamma_mixture(
+            global_prior, cavity[nodes], em_maxitt, em_reltol, False
+        )
+        messages[nodes] = (posterior[nodes] - cavity[nodes]) / scale[nodes, np.newaxis]
+
+        for n in nodes:
+            eta = posterior_damping(posterior[n])
+            posterior[n] *= eta
+            scale[n] *= eta
+
+        return 0.0
+
+    def iterate(self, em_maxitt=100, em_reltol=1e-6, max_shape=1000, min_kl=True):
         """
         Update edge factors from leaves to root then from root to leaves,
         and return approximate log marginal likelihood (TODO)
         """
 
-        self.propagate(
+        # prior update
+        self.propagate_prior(
+            self.freenodes,
+            self.global_prior,
+            self.posterior,
+            self.prior_messages,
+            self.scale,
+            max_shape,
+            em_maxitt,
+            em_reltol,
+        )
+
+        # rootward pass
+        self.propagate_likelihood(
             self.edges,
             self.likelihoods,
             self.posterior,
@@ -1113,7 +1198,9 @@
             max_shape,
             min_kl,
         )
-        self.propagate(
+
+        # leafward pass
+        self.propagate_likelihood(
             self.edges[::-1],
             self.likelihoods,
             self.posterior,
@@ -1381,7 +1468,8 @@ def variational_dates(
     max_iterations=20,
     max_shape=1000,
     match_central_moments=False,
-    global_prior=True,
+    global_prior=1,
+    em_iterations=10,
 ):
     """
     Infer dates for the nodes in a tree sequence using expectation propagation,
@@ -1390,6 +1478,11 @@
     which invokes this method and inserts the resulting node ages into the
     tree sequence.
 
+    An i.i.d. gamma mixture is used as a prior for each node, initialized
+    from the conditional coalescent and updated via expectation maximization
+    in each iteration. Node-specific priors specified on a grid of shape/rate
+    parameters are not yet supported with this method.
+
     :param ~tskit.TreeSequence tree_sequence: See :func:`date`.
     :param float mutation_rate: See :func:`date`.
     :param float population_size: See :func:`date`.
@@ -1397,10 +1490,8 @@
     :param bool progress: See :func:`date`.
     :param ~tsdate.base.NodeGridValues priors: the prior parameters for each
         node-to-be-dated, assuming a gamma prior on node
-        age and using shape/rate parameterization. If ``None`` (default), use
-        an iid prior derived from the conditional coalescent prior, tilted
-        according to population size, and assume the nodes to be dated are all
-        the non-sample nodes in the input tree sequence.
+        age and using shape/rate parameterization. If ``None`` (default),
+        node-specific priors are omitted and only a global mixture prior is used.
     :param int max_iterations: The number of iterations used in the expectation
         propagation algorithm. Default: 20.
     :param float max_shape: The maximum value for the shape parameter in the variational
@@ -1410,9 +1501,11 @@
         update matches mean and variance rather than expected gamma sufficient
         statistics. Faster with a similar accuracy, but does not exactly
         minimize Kullback-Leibler divergence. Default: False.
-    :param bool global_prior: If `True`, an iid prior is used for all nodes,
-        and is constructed by averaging gamma sufficient statistics over the free
-        nodes in `priors`. Default: True.
+    :param int global_prior: The number of components in the i.i.d. mixture prior
+        for node ages. Default: 1.
+    :param int em_iterations: The number of expectation maximization iterations used
+        to optimize the global mixture prior. Setting to zero disables optimization.
+        Default: 10.
 
     :return: a tuple ``(mn_post, va_post, posteriors, nodes_to_date)``, where:
         ``mn_post`` (:class:`~numpy.ndarray`) and ``va_post`` (:class:`~numpy.ndarray`)
@@ -1436,45 +1529,37 @@
     if mutation_rate is None:
         raise ValueError("Variational gamma method requires mutation rate")
 
+    if not (isinstance(global_prior, int) and global_prior > 0):
+        raise ValueError("'global_prior' must be a positive integer")
+
     # Default to not creating approximate priors unless ts has
     # greater than DEFAULT_APPROX_PRIOR_SIZE samples
     approx_priors = False
     if tree_sequence.num_samples > base.DEFAULT_APPROX_PRIOR_SIZE:
         approx_priors = True
 
-    if priors is None:
-        if population_size is None:
-            raise ValueError(
-                "Must specify population size if priors are not already "
-                "built using tsdate.build_parameter_grid()"
-            )
-        priors = prior.parameter_grid(
-            tree_sequence,
-            population_size=population_size,
-            progress=progress,
-            approximate_priors=approx_priors,
+    # TODO: support additional node-specific priors
+    if priors is not None:
+        raise ValueError(
+            "Node-specific priors not yet implemented with method 'variational_gamma'"
         )
-    else:
-        logging.info("Using user-specified priors")
-        if population_size is not None:
-            raise ValueError(
-                "Cannot specify population size in tsdate.date() or "
-                "tsdate.variational_dates() if specifying priors from "
-                "tsdate.build_parameter_grid()"
-            )
-        priors = priors
 
-    # convert priors to natural parameterization and average
-    for n in priors.nonfixed_nodes:
-        priors[n][0] -= 1.0
-        assert priors[n][0] > -1.0
-        assert priors[n][1] >= 0.0
-    if global_prior:
-        logging.info("Pooling node-specific priors into global prior")
-        priors.grid_data[:] = approx.average_gammas(
-            priors.grid_data[:, 0], priors.grid_data[:, 1]
+    # initialize global prior from coalescent
+    if population_size is None:
+        raise ValueError(
+            "Must specify population size if priors are not already "
+            "built using tsdate.build_parameter_grid()"
         )
+    priors = prior.parameter_grid(
+        tree_sequence,
+        population_size=population_size,
+        progress=progress,
+        approximate_priors=approx_priors,
+    )
+    prior_mixture = mixture.initialize_mixture(priors.grid_data, global_prior)
+    priors.grid_data[:] = [0.0, 0.0]
+
     # calculate mutation likelihoods per edge
     liklhd = VariationalLikelihoods(
         tree_sequence,
         mutation_rate,
@@ -1485,14 +1570,19 @@
     # match sufficient statistics or match central moments
     min_kl = not match_central_moments
 
-    dynamic_prog = ExpectationPropagation(priors, liklhd, progress=progress)
-    for _ in tqdm(
+    dynamic_prog = ExpectationPropagation(
+        priors, liklhd, progress=progress, global_prior=prior_mixture
+    )
+    for itt in tqdm(
         np.arange(max_iterations),
         desc="Expectation Propagation",
         disable=not progress,
     ):
-        dynamic_prog.iterate(max_shape=max_shape, min_kl=min_kl)
-
+        dynamic_prog.iterate(
+            em_maxitt=em_iterations if itt else 0,
+            max_shape=max_shape,
+            min_kl=min_kl,
+        )
 
     num_skipped = np.sum(np.isnan(dynamic_prog.log_partition))
     if num_skipped > 0:
         logging.info(f"Skipped {num_skipped} messages with invalid posterior updates.")
diff --git a/tsdate/mixture.py b/tsdate/mixture.py
new file mode 100644
index 00000000..51538b9f
--- /dev/null
+++ b/tsdate/mixture.py
@@ -0,0 +1,240 @@
+# MIT License
+#
+# Copyright (c) 2021-23 Tskit Developers
+# Copyright (c) 2020-21 University of Oxford
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+"""
+Mixture of gamma distributions that may be fit via EM to distribution-valued observations
+"""
+import numba
+import numpy as np
+
+from . import approx
+from . import hypergeo
+
+
+@numba.njit("UniTuple(f8[:], 4)(f8[:], f8[:], f8[:], f8, f8)")
+def _conditional_posterior(prior_logweight, prior_alpha, prior_beta, alpha, beta):
+    r"""
+    Return expectations of node age :math:`t` from the mixture model,
+
+    .. math::
+
+        Ga(t | a, b) \sum_j \pi_j w_j Ga(t | \alpha_j, \beta_j)
+
+    where :math:`a` and :math:`b` are variational parameters,
+    and :math:`\pi_j, \alpha_j, \beta_j` are prior weights and
+    parameters for a gamma mixture; and :math:`w_j` are fixed,
+    observation-specific weights. We use natural parameterization,
+    so that the shape parameter is :math:`\alpha + 1`.
+
+    The normalizing constants of the prior are assumed to have already
+    been integrated into ``prior_logweight``.
+
+    Returns the contribution from each component to the
+    posterior expectations of :math:`E[1]`, :math:`E[t]`, :math:`E[\log t]`,
+    and :math:`E[t \log t]`.
+
+    Note that :math:`E[1]` is *unnormalized* and *log-transformed*.
+    """
+
+    dim = prior_logweight.size
+    E = np.full(dim, -np.inf)  # E[1] (e.g. normalizing constant)
+    E_t = np.zeros(dim)  # E[t]
+    E_logt = np.zeros(dim)  # E[log(t)]
+    E_tlogt = np.zeros(dim)  # E[t * log(t)]
+    C = (alpha + 1) * np.log(beta) - hypergeo._gammaln(alpha + 1) if beta > 0 else 0.0
+    for i in range(dim):
+        post_alpha = prior_alpha[i] + alpha
+        post_beta = prior_beta[i] + beta
+        if (post_alpha <= -1) or (post_beta <= 0):  # skip node if divergent
+            E[:] = -np.inf
+            break
+        E[i] = C + (
+            +hypergeo._gammaln(post_alpha + 1)
+            - (post_alpha + 1) * np.log(post_beta)
+            + prior_logweight[i]
+        )
+        assert np.isfinite(E[i])
+        # TODO: option to use moments instead of sufficient statistics?
+        E_t[i] = (post_alpha + 1) / post_beta
+        E_logt[i] = hypergeo._digamma(post_alpha + 1) - np.log(post_beta)
+        E_tlogt[i] = E_t[i] * E_logt[i] + E_t[i] / (post_alpha + 1)
+
+    return E, E_t, E_logt, E_tlogt
+
+
+@numba.njit("f8(f8[:], f8[:], f8[:], f8[:], f8[:])")
+def _em_update(prior_weight, prior_alpha, prior_beta, alpha, beta):
+    """
+    Perform an expectation maximization step for parameters of mixture components,
+    given variational parameters `alpha`, `beta` for each node.
+
+    The maximization step is performed using Ye & Chen (2017) "Closed form
+    estimators for the gamma distribution ..."
+
+    ``prior_weight``, ``prior_alpha``, ``prior_beta`` are updated in place.
+    """
+    assert alpha.size == beta.size
+
+    dim = prior_weight.size
+    n = np.zeros(dim)
+    t = np.zeros(dim)
+    logt = np.zeros(dim)
+    tlogt = np.zeros(dim)
+
+    # incorporate prior normalizing constants into weights
+    prior_logweight = np.log(prior_weight)
+    for k in range(dim):
+        prior_logweight[k] += (prior_alpha[k] + 1) * np.log(
+            prior_beta[k]
+        ) - hypergeo._gammaln(prior_alpha[k] + 1)
+
+    # expectation step:
+    loglik = 0.0
+    for a, b in zip(alpha, beta):
+        E, E_t, E_logt, E_tlogt = _conditional_posterior(
+            prior_logweight, prior_alpha, prior_beta, a, b
+        )
+
+        # skip if posterior is improper
+        if np.any(np.isinf(E)):
+            continue
+
+        # convert evidence to posterior weights
+        norm_const = np.log(np.sum(np.exp(E - np.max(E)))) + np.max(E)
+        weight = np.exp(E - norm_const)
+
+        # weighted contributions to sufficient statistics
+        loglik += norm_const
+        n += weight
+        t += E_t * weight
+        logt += E_logt * weight
+        tlogt += E_tlogt * weight
+
+    # maximization step: update parameters in place
+    prior_weight[:] = n / np.sum(n)
+    prior_beta[:] = n**2 / (n * tlogt - t * logt)
+    prior_alpha[:] = n * t / (n * tlogt - t * logt) - 1.0
+
+    return loglik
+
+
+@numba.njit("f8[:](f8[:], f8[:], f8[:], f8[:], f8[:])")
+def _gamma_projection(prior_weight, prior_alpha, prior_beta, alpha, beta):
+    """
+    Given variational approximation to posterior: multiply by exact prior,
+    calculate sufficient statistics, and moment match to get new
+    approximate posterior.
+
+    Updates ``alpha`` and ``beta`` in-place.
+    """
+    assert alpha.size == beta.size
+
+    dim = prior_weight.size
+
+    # incorporate prior normalizing constants into weights
+    prior_logweight = np.log(prior_weight)
+    for k in range(dim):
+        prior_logweight[k] += (prior_alpha[k] + 1) * np.log(
+            prior_beta[k]
+        ) - hypergeo._gammaln(prior_alpha[k] + 1)
+
+    log_const = np.full(alpha.size, -np.inf)
+    for i in range(alpha.size):
+        E, E_t, E_logt, E_tlogt = _conditional_posterior(
+            prior_logweight, prior_alpha, prior_beta, alpha[i], beta[i]
+        )
+
+        # skip if posterior is improper for all components
+        if np.any(np.isinf(E)):
+            continue
+
+        norm = np.log(np.sum(np.exp(E - np.max(E)))) + np.max(E)
+        weight = np.exp(E - norm)
+        t = np.sum(weight * E_t)
+        logt = np.sum(weight * E_logt)
+        # tlogt = np.sum(weight * E_tlogt)
+        log_const[i] = norm
+        alpha[i], beta[i] = approx.approximate_gamma_kl(t, logt)
+        # beta[i] = 1 / (tlogt - t * logt)
+        # alpha[i] = t * beta[i] - 1.0
+
+    return log_const
+
+
+@numba.njit("Tuple((f8[:,:], f8[:,:]))(f8[:,:], f8[:,:], i4, f8, b1)")
+def fit_gamma_mixture(mixture, observations, max_iterations, tolerance, verbose):
+    """
+    Run EM until relative tolerance or maximum number of iterations is
+    reached. Then, perform expectation-propagation update and return new
+    variational parameters for the posterior approximation.
+ """ + + assert mixture.shape[1] == 3 + assert observations.shape[1] == 2 + + mix_weight, mix_alpha, mix_beta = mixture.T + alpha, beta = observations.T + + last = np.inf + for itt in range(max_iterations): + loglik = _em_update(mix_weight, mix_alpha, mix_beta, alpha, beta) + loglik /= float(alpha.size) + update = np.abs(loglik - last) + last = loglik + if verbose: + print("EM iteration:", itt, "mean(loglik):", np.round(loglik, 5)) + print(" -> weights:", mix_weight) + print(" -> alpha:", mix_alpha) + print(" -> beta:", mix_beta) + if update < np.abs(loglik) * tolerance: + break + + # conditional posteriors for each observation + # log_const = _gamma_projection(mix_weight, mix_alpha, mix_beta, alpha, beta) + _gamma_projection(mix_weight, mix_alpha, mix_beta, alpha, beta) + + new_mixture = np.zeros(mixture.shape) + new_observations = np.zeros(observations.shape) + new_observations[:, 0] = alpha + new_observations[:, 1] = beta + new_mixture[:, 0] = mix_weight + new_mixture[:, 1] = mix_alpha + new_mixture[:, 2] = mix_beta + + return new_mixture, new_observations + + +def initialize_mixture(parameters, num_components): + """initialize clusters by dividing nodes into equal groups""" + global_prior = np.empty((num_components, 3)) + num_nodes = parameters.shape[0] + age_classes = np.tile(np.arange(num_components), num_nodes // num_components + 1)[ + :num_nodes + ] + for k in range(num_components): + indices = np.equal(age_classes, k) + alpha, beta = approx.average_gammas( + parameters[indices, 0] - 1.0, parameters[indices, 1] + ) + global_prior[k] = [1.0 / num_components, alpha, beta] + return global_prior diff --git a/tsdate/prior.py b/tsdate/prior.py index a842a1ca..a7eda75e 100644 --- a/tsdate/prior.py +++ b/tsdate/prior.py @@ -30,6 +30,8 @@ import numba import numpy as np +import scipy.cluster +import scipy.special import scipy.stats import tskit from tqdm.auto import tqdm