From dfadb21de4209a25b35a0da817bb611285a773c2 Mon Sep 17 00:00:00 2001 From: Siva Prasad Varma Date: Mon, 11 May 2020 13:13:50 -0700 Subject: [PATCH 1/4] mixture model logpdf using Buffer --- src/DistributionsAD.jl | 2 ++ src/mixturemodels.jl | 46 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+) create mode 100644 src/mixturemodels.jl diff --git a/src/DistributionsAD.jl b/src/DistributionsAD.jl index ca6342af..f3b4ecb0 100644 --- a/src/DistributionsAD.jl +++ b/src/DistributionsAD.jl @@ -36,6 +36,7 @@ import Distributions: MvNormal, BetaBinomial, Erlang import ZygoteRules +import Zygote export TuringScalMvNormal, TuringDiagMvNormal, @@ -50,6 +51,7 @@ export TuringScalMvNormal, include("common.jl") include("univariate.jl") include("multivariate.jl") +include("mixturemodels.jl") include("mvcategorical.jl") include("matrixvariate.jl") include("flatten.jl") diff --git a/src/mixturemodels.jl b/src/mixturemodels.jl new file mode 100644 index 00000000..844af073 --- /dev/null +++ b/src/mixturemodels.jl @@ -0,0 +1,46 @@ +function _mixlogpdf1(d::AbstractMixtureModel, x) + # using the formula below for numerical stability + # + # logpdf(d, x) = log(sum_i pri[i] * pdf(cs[i], x)) + # = log(sum_i pri[i] * exp(logpdf(cs[i], x))) + # = log(sum_i exp(logpri[i] + logpdf(cs[i], x))) + # = m + log(sum_i exp(logpri[i] + logpdf(cs[i], x) - m)) + # + # m is chosen to be the maximum of logpri[i] + logpdf(cs[i], x) + # such that the argument of exp is in a reasonable range + # + + K = ncomponents(d) + p = probs(d) + # use Buffer to avoid mutating arrays. + # lp = Vector{eltype(p)}(undef, K) + lp = Zygote.Buffer(p, K) + m = -Inf # m <- the maximum of log(p(cs[i], x)) + log(pri[i]) + @inbounds for i in eachindex(p) + pi = p[i] + if pi > 0.0 + # lp[i] <- log(p(cs[i], x)) + log(pri[i]) + lp_i = logpdf(component(d, i), x) + log(pi) + # zygote seems to have trouble here. + # Mutating arrays is not supported + lp[i] = lp_i + if lp_i > m + m = lp_i + end + end + end + v = 0.0 + @inbounds for i = 1:K + if p[i] > 0.0 + v += exp(lp[i] - m) + end + end + return m + log(v) +end + + + +Distributions.logpdf(d::UnivariateMixture{Continuous}, x::Real) = _mixlogpdf1(d, x) +Distributions.logpdf(d::UnivariateMixture{Discrete}, x::Int) = _mixlogpdf1(d, x) + +Distributions._logpdf(d::MultivariateMixture, x::AbstractVector) = _mixlogpdf1(d, x) \ No newline at end of file From 57c75b40c19d4797b544e4510cb42a993cf2f792 Mon Sep 17 00:00:00 2001 From: Siva Prasad Varma Date: Mon, 11 May 2020 13:46:46 -0700 Subject: [PATCH 2/4] add zygote to dependencies --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index 9d5cd6e5..bfbb84a9 100644 --- a/Project.toml +++ b/Project.toml @@ -20,6 +20,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] From f5fa9051f724b98a8719c9db612665f346bc449f Mon Sep 17 00:00:00 2001 From: Siva Prasad Varma Date: Mon, 11 May 2020 15:53:17 -0700 Subject: [PATCH 3/4] use map do block, remove zygote dependency --- Project.toml | 1 - src/DistributionsAD.jl | 1 - src/mixturemodels.jl | 36 +++++++----------------------------- 3 files changed, 7 insertions(+), 31 deletions(-) diff --git a/Project.toml b/Project.toml index bfbb84a9..9d5cd6e5 100644 --- a/Project.toml +++ b/Project.toml @@ -20,7 +20,6 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] diff --git a/src/DistributionsAD.jl b/src/DistributionsAD.jl index f3b4ecb0..ab89d693 100644 --- a/src/DistributionsAD.jl +++ b/src/DistributionsAD.jl @@ -36,7 +36,6 @@ import Distributions: MvNormal, BetaBinomial, Erlang import ZygoteRules -import Zygote export TuringScalMvNormal, TuringDiagMvNormal, diff --git a/src/mixturemodels.jl b/src/mixturemodels.jl index 844af073..ca942a59 100644 --- a/src/mixturemodels.jl +++ b/src/mixturemodels.jl @@ -4,43 +4,21 @@ function _mixlogpdf1(d::AbstractMixtureModel, x) # logpdf(d, x) = log(sum_i pri[i] * pdf(cs[i], x)) # = log(sum_i pri[i] * exp(logpdf(cs[i], x))) # = log(sum_i exp(logpri[i] + logpdf(cs[i], x))) - # = m + log(sum_i exp(logpri[i] + logpdf(cs[i], x) - m)) - # - # m is chosen to be the maximum of logpri[i] + logpdf(cs[i], x) - # such that the argument of exp is in a reasonable range - # K = ncomponents(d) - p = probs(d) - # use Buffer to avoid mutating arrays. - # lp = Vector{eltype(p)}(undef, K) - lp = Zygote.Buffer(p, K) - m = -Inf # m <- the maximum of log(p(cs[i], x)) + log(pri[i]) - @inbounds for i in eachindex(p) - pi = p[i] - if pi > 0.0 - # lp[i] <- log(p(cs[i], x)) + log(pri[i]) - lp_i = logpdf(component(d, i), x) + log(pi) - # zygote seems to have trouble here. - # Mutating arrays is not supported - lp[i] = lp_i - if lp_i > m - m = lp_i + pri = probs(d) + + lp = map(eachindex(pri)) do i + if pri[i] > 0.0 + return logpdf(component(d, i), x) + log(pri[i]) end end - end - v = 0.0 - @inbounds for i = 1:K - if p[i] > 0.0 - v += exp(lp[i] - m) - end - end - return m + log(v) + + return logsumexp(lp) end Distributions.logpdf(d::UnivariateMixture{Continuous}, x::Real) = _mixlogpdf1(d, x) Distributions.logpdf(d::UnivariateMixture{Discrete}, x::Int) = _mixlogpdf1(d, x) - Distributions._logpdf(d::MultivariateMixture, x::AbstractVector) = _mixlogpdf1(d, x) \ No newline at end of file From 2326374b9fb7276e980ddd6581e78acf8162371c Mon Sep 17 00:00:00 2001 From: Siva Prasad Varma Date: Tue, 12 May 2020 08:48:11 -0700 Subject: [PATCH 4/4] filter out nothing from lp --- src/mixturemodels.jl | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/mixturemodels.jl b/src/mixturemodels.jl index ca942a59..e30d8750 100644 --- a/src/mixturemodels.jl +++ b/src/mixturemodels.jl @@ -5,13 +5,10 @@ function _mixlogpdf1(d::AbstractMixtureModel, x) # = log(sum_i pri[i] * exp(logpdf(cs[i], x))) # = log(sum_i exp(logpri[i] + logpdf(cs[i], x))) - K = ncomponents(d) pri = probs(d) - - lp = map(eachindex(pri)) do i - if pri[i] > 0.0 - return logpdf(component(d, i), x) + log(pri[i]) - end + indices = findall(!iszero, pri) + lp = map(indices) do i + return logpdf(component(d, i), x) + log(pri[i]) end return logsumexp(lp)