[WIP] Mixture Models #79

Closed
wants to merge 6 commits into from

@sivapvarma sivapvarma commented May 11, 2020

WIP to fix issue #55.

Right now I am using Zygote.Buffer, which adds Zygote as a dependency; that is probably not desirable. I am not sure how to use Zygote.Buffer without adding Zygote as a dependency.

Any comments are welcome.

src/mixturemodels.jl (outdated)
@mohamed82008
Member

Thanks for the PR @sivapvarma! Please implement the above function in a non-mutating way; it is possible. Then add a MixtureModel to the tests.

@devmotion
Member

Thanks for the PR, but IMO this should really be fixed in Distributions. The implementation there (which is mainly copied here) should just use map and logsumexp, and then it will probably work with AD. I think we don't want to use Zygote.Buffer and depend on Zygote in DistributionsAD.
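
As a rough illustration of the map-plus-logsumexp approach suggested here, a non-mutating mixture log-density might look like the sketch below. The helper name `mixture_logpdf` is hypothetical and is not the code from this PR or from Distributions; importing `logsumexp` from LogExpFunctions is also an assumption (StatsFuns exports a function of the same name).

```julia
using Distributions                 # MixtureModel, components, probs, logpdf
using LogExpFunctions: logsumexp    # assumption: any equivalent logsumexp would do

# Hypothetical helper for illustration; not the implementation from this PR.
function mixture_logpdf(d::MixtureModel, x::Real)
    pri = probs(d)         # prior weights of the components
    cs = components(d)     # the component distributions
    # Keep only components with nonzero weight so log(0) never enters the sum.
    indices = findall(!iszero, pri)
    # No array is mutated: map builds the vector of log-terms in one pass.
    lp = map(i -> logpdf(cs[i], x) + log(pri[i]), indices)
    return logsumexp(lp)
end

d = MixtureModel([Normal(-1.0, 0.5), Normal(2.0, 1.5)], [0.3, 0.7])
mixture_logpdf(d, 0.4) ≈ logpdf(d, 0.4)   # compare against the library's own logpdf
```

Because nothing is written into a preallocated array, a sketch like this does not need Zygote.Buffer; whether it differentiates cleanly with a given AD backend would still have to be tested.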

@sivapvarma
Author

@devmotion @mohamed82008 I see what you mean now: this can be done without Zygote.Buffer in a non-mutating way. Let me do that and remove the Zygote dependency; then we can decide whether this should go in Distributions or not.

@sivapvarma
Author

sivapvarma commented May 11, 2020

Done. No Zygote dependency anymore. But when I precompile the package, I get warnings about methods being overwritten. What is the solution to this?

I still have to add the tests.

@sivapvarma sivapvarma requested a review from mohamed82008 May 11, 2020 23:04
@devmotion
Member

> But when I precompile the package, I get warnings about methods being overwritten. What is the solution to this?

This is expected since you overwrite the definitions in Distributions. IMO the correct solution would be to fix the logpdf implementation in Distributions instead of adding a second one in DistributionsAD.
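
To make the cause of the warning concrete, here is a small self-contained sketch. The module `Upstream` and function `f` are hypothetical stand-ins for Distributions and its `logpdf` method; the point is only that defining a method with an identical signature replaces the existing one rather than adding a new one.

```julia
# Hypothetical stand-in for the upstream package.
module Upstream
f(x::Int) = "upstream"
end

before = Upstream.f(1)

# A second definition with exactly the same signature, made from outside the
# module, replaces the original method. This is the situation the
# "method overwritten" warning reports during precompilation.
Upstream.f(x::Int) = "downstream"

after = Upstream.f(1)
```

After the second definition only one method of `f` remains, which is why fixing the original definition upstream avoids the problem entirely.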

Comment on lines +2 to +7
# using the formula below for numerical stability
#
# logpdf(d, x) = log(sum_i pri[i] * pdf(cs[i], x))
#              = log(sum_i pri[i] * exp(logpdf(cs[i], x)))
#              = log(sum_i exp(logpri[i] + logpdf(cs[i], x)))

Member

The comment seems a bit useless now that logsumexp is used instead of the manual implementation of the logsumexp trick.

Suggested change
- # using the formula below for numerical stability
- #
- # logpdf(d, x) = log(sum_i pri[i] * pdf(cs[i], x))
- #              = log(sum_i pri[i] * exp(logpdf(cs[i], x)))
- #              = log(sum_i exp(logpri[i] + logpdf(cs[i], x)))
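
For context on what the logsumexp trick in that code comment buys, the sketch below compares a naive evaluation with a hand-written stable one. `stable_logsumexp` is a hypothetical helper written only for illustration; a package-provided `logsumexp` should be preferred in real code.

```julia
# Hand-written log-sum-exp for illustration only.
function stable_logsumexp(v)
    m = maximum(v)
    # If the maximum is not finite (for example all entries are -Inf),
    # return it directly to avoid computing Inf - Inf.
    isfinite(m) || return m
    # Shifting by the maximum keeps the largest exponent at exp(0) = 1.
    return m + log(sum(exp.(v .- m)))
end

lp = [-1000.0, -1001.0]           # log-terms far below the Float64 underflow limit

naive = log(sum(exp.(lp)))        # exp underflows to 0.0 for both entries
stable = stable_logsumexp(lp)     # stays finite
```

Here the naive form loses all information because both exponentials underflow, while the shifted form evaluates to -1000 + log(1 + exp(-1)).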

# = log(sum_i pri[i] * exp(logpdf(cs[i], x)))
# = log(sum_i exp(logpri[i] + logpdf(cs[i], x)))

K = ncomponents(d)
Member

K is not used anywhere.

Suggested change
- K = ncomponents(d)

Comment on lines 11 to 15
lp = map(eachindex(pri)) do i
    if pri[i] > 0.0
        return logpdf(component(d, i), x) + log(pri[i])
    end
end
Member

This won't work correctly since lp[i] will be nothing if pri[i] <= 0.0. It may be better to use something like:

Suggested change
- lp = map(eachindex(pri)) do i
-     if pri[i] > 0.0
-         return logpdf(component(d, i), x) + log(pri[i])
-     end
- end
+ indices = findall(!iszero, pri)
+ lp = map(indices) do i
+     return logpdf(component(d, i), x) + log(pri[i])
+ end

Author

I originally wanted to do a list comprehension like

lp = [ logpdf(component(d, i), x) + log(pri[i]) for i in eachindex(pri) if pri[i] > 0.0 ]

which does not leave any nothing entries in lp. This somehow does not work with Zygote.

Then I tried to do the same thing with map and assumed the nothing entries would be filtered out. I should have tested it in a REPL first. Thanks for catching this.
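
The behavior discussed in this thread can be checked in a REPL with a small sketch like the one below. The weight vector `pri` is made-up illustration data, and `log(pri[i])` stands in for the full log-term so the example needs no packages.

```julia
pri = [0.5, 0.0, 0.5]   # made-up weights; the second component has zero weight

# An `if` without `else` evaluates to `nothing` when the condition is false,
# and `map` always returns one entry per input element.
with_gaps = map(eachindex(pri)) do i
    if pri[i] > 0.0
        return log(pri[i])
    end
end

# Selecting the nonzero indices first yields a dense vector of floats.
indices = findall(!iszero, pri)
dense = map(i -> log(pri[i]), indices)
```

`with_gaps` keeps three entries with `nothing` in the middle, so passing it to logsumexp would fail, whereas `dense` contains only the two valid log-weights.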

@mohamed82008
Member

@sivapvarma we still need tests. I agree with @devmotion that if this implementation is strictly better than, or at least as good as, the one in Distributions.jl even without AD, then it needs to go there.

@devmotion
Member

Will be fixed upstream by JuliaStats/Distributions.jl#1308.

@devmotion
Member

The upstream PR is merged and part of Distributions 0.25.0.

@devmotion devmotion closed this May 2, 2021
3 participants