diff --git a/src/DistributionsAD.jl b/src/DistributionsAD.jl index c930648..81d705e 100644 --- a/src/DistributionsAD.jl +++ b/src/DistributionsAD.jl @@ -54,6 +54,7 @@ include("arraydist.jl") include("filldist.jl") include("univariate.jl") include("multivariate.jl") +include("mixturemodels.jl") include("mvcategorical.jl") include("matrixvariate.jl") include("flatten.jl") diff --git a/src/mixturemodels.jl b/src/mixturemodels.jl new file mode 100644 index 0000000..e30d875 --- /dev/null +++ b/src/mixturemodels.jl @@ -0,0 +1,21 @@ +function _mixlogpdf1(d::AbstractMixtureModel, x) + # using the formula below for numerical stability + # + # logpdf(d, x) = log(sum_i pri[i] * pdf(cs[i], x)) + # = log(sum_i pri[i] * exp(logpdf(cs[i], x))) + # = log(sum_i exp(logpri[i] + logpdf(cs[i], x))) + + pri = probs(d) + indices = findall(!iszero, pri) + lp = map(indices) do i + return logpdf(component(d, i), x) + log(pri[i]) + end + + return logsumexp(lp) +end + + + +Distributions.logpdf(d::UnivariateMixture{Continuous}, x::Real) = _mixlogpdf1(d, x) +Distributions.logpdf(d::UnivariateMixture{Discrete}, x::Int) = _mixlogpdf1(d, x) +Distributions._logpdf(d::MultivariateMixture, x::AbstractVector) = _mixlogpdf1(d, x) \ No newline at end of file