-
Notifications
You must be signed in to change notification settings - Fork 31
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Mixture Models #79
Conversation
Thanks for the PR @sivapvarma! Please implement the above function in a non-mutating way, it's possible. Then add a |
Thanks for the PR but IMO this should really be fixed in Distributions. The implementation there (which is mainly copied here) should just use |
@devmotion @mohamed82008 I see what you mean now, this can be done without Zygote.Buffer in a non-mutating way. Let me do that and remove Zygote dependency, then we can decide if this should go in Distributions or not. |
Done. No Zygote dependency anymore. But when I precompile the package, I get warnings about method overwritten. What is the solution to this ? I still have to add the tests. |
This is expected since you overwrite the definitions in Distributions. IMO the correct solution would be to fix the logpdf implementation in Distributions instead of adding a second one in DistributionsAD. |
# using the formula below for numerical stability | ||
# | ||
# logpdf(d, x) = log(sum_i pri[i] * pdf(cs[i], x)) | ||
# = log(sum_i pri[i] * exp(logpdf(cs[i], x))) | ||
# = log(sum_i exp(logpri[i] + logpdf(cs[i], x))) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The comment seems a bit useless now that logsumexp
is used instead of the manual implementation of the logsumexp trick.
# using the formula below for numerical stability | |
# | |
# logpdf(d, x) = log(sum_i pri[i] * pdf(cs[i], x)) | |
# = log(sum_i pri[i] * exp(logpdf(cs[i], x))) | |
# = log(sum_i exp(logpri[i] + logpdf(cs[i], x))) |
src/mixturemodels.jl
Outdated
# = log(sum_i pri[i] * exp(logpdf(cs[i], x))) | ||
# = log(sum_i exp(logpri[i] + logpdf(cs[i], x))) | ||
|
||
K = ncomponents(d) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
K
is not used anywhere.
K = ncomponents(d) |
src/mixturemodels.jl
Outdated
lp = map(eachindex(pri)) do i | ||
if pri[i] > 0.0 | ||
return logpdf(component(d, i), x) + log(pri[i]) | ||
end | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This won't work correctly since lp[i]
will be nothing
if pri[i] <= 0.0
. Maybe better use something like
lp = map(eachindex(pri)) do i | |
if pri[i] > 0.0 | |
return logpdf(component(d, i), x) + log(pri[i]) | |
end | |
end | |
indices = findall(!iszero, pri) | |
lp = map(indices) do i | |
return logpdf(component(d, i), x) + log(pri[i]) | |
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I originally wanted to do a list comprehension like
lp = [ logpdf(component(d, i), x) + log(pri[i]) for i in eachindex(pri) if pri[i] > 0.0 ]
which does not have the nothing
's in lp
. This some how does not work with Zygote.
Then I tried to do the samething with map
and assumed the nothing
's would be filtered out. I should have tested it out in a REPL first. Thanks for catching this.
@sivapvarma we still need tests. I agree with @devmotion that if this is a strictly better implementation or at least as good as the one in Distributions.jl even without AD, then it needs to go there. |
Will be fixed upstream by JuliaStats/Distributions.jl#1308. |
The upstream PR is merged and part of Distributions 0.25.0. |
WIP to fix issue #55.
Right now I am using Zygote.Buffer which adds Zygote as a dependency, which is probably not desirable. I am not sure how to use Zygote.Buffer without adding Zygote dep.
Any comments are welcome.