
Incompatibility between rand(MvNormal()) and AutoDiff #813

Closed
theogf opened this issue Jan 15, 2019 · 16 comments

@theogf
Contributor

theogf commented Jan 15, 2019

Hello, it is unfortunately not possible to use automatic differentiation with (at least) the MvNormal distribution.
The following code will fail at rand(p) due to an incorrect conversion to Float64:

using Distributions, ForwardDiff
function foo(mu)
    p = MvNormal(mu)                # mu is a vector of Dual numbers under ForwardDiff
    sum(rand(p) for _ in 1:100)     # fails here: the samples are forced to Float64
end
ForwardDiff.gradient(foo, rand(10))
@matbesancon
Member

@theogf PR welcome on that

@matbesancon matbesancon added the bug label May 8, 2019
@theogf
Contributor Author

theogf commented May 8, 2019

So I found the source of the error, but it's in common.jl, so I'm not sure whether the change would be breaking:

Base.eltype(s::Sampleable{F,Continuous}) where {F} = Float64

This says that, whatever the type of the distribution is, eltype will return Float64 for a continuous sampleable. Since eltype is called many times during MvNormal sampling, one always ends up with Float64 samples. I don't know how other distributions depend on eltype, but a quick fix would be overloading it for MvNormal types.
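
For illustration, a minimal sketch of that quick fix, assuming the mean vector is stored in the field μ (hypothetical; not necessarily the change that was eventually merged):

using Distributions

# Hypothetical overload: let eltype follow the element type of the parameters
# instead of the hard-coded Float64 from common.jl.
Base.eltype(d::MvNormal) = eltype(d.μ)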

@matbesancon
Member

matbesancon commented May 8, 2019 via email

@matbesancon
Member

@theogf did #882 close this? The error I get is because the function does not return a scalar, which ForwardDiff.gradient requires. It seems fixed with:

julia> function foo(mu)
           p = MvNormal(mu)
           sum(sum(rand(p) for _ in 1:100))
       end

@andreasnoack
Member

Is it really reasonable to expect rand to be differentiable? Making the sampler structs parametric introduces a lot of complications since the samplers are usually written for a specific precision. It also touches on the issue of the variate type vs the parameter type. For most distributions, the variate type doesn't follow the parameter type.

@mschauer
Member

mschauer commented Jul 7, 2021

Yes, I think this could be important; differentiability of a + b*randn(rng) conditional on the RNG state is practical (keyword: normalising flows).
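
For concreteness, a minimal sketch of that idea (names made up for illustration): with the RNG state fixed, the sample is a deterministic function of the location a and scale b, so ForwardDiff can differentiate straight through it.

using ForwardDiff, Random

# With a fixed seed the draws are identical on every call, so the output is a
# deterministic, differentiable function of (a, b).
function sample_sum(θ)
    a, b = θ
    rng = MersenneTwister(42)
    sum(a + b * randn(rng) for _ in 1:100)
end

ForwardDiff.gradient(sample_sum, [0.0, 1.0])
# ≈ [100.0, sum of the 100 standard normal draws]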

@devmotion
Member

Related: TuringLang/DistributionsAD.jl#123

I think in many cases it is not important to implement the samplers in a differentiable way but it would be useful to add custom adjoints, probably based on ChainRulesCore.
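
As a sketch of what such a rule could look like for a simple location-scale case (hypothetical function name, not an actual Distributions.jl or DistributionsAD.jl definition):

using ChainRulesCore, Random

# Hypothetical sampler: x = μ + σ*z with z ~ N(0, 1).
locscale_rand(rng::AbstractRNG, μ, σ) = μ + σ * randn(rng)

function ChainRulesCore.rrule(::typeof(locscale_rand), rng::AbstractRNG, μ, σ)
    z = randn(rng)
    x = μ + σ * z
    function locscale_rand_pullback(Δ)
        # Reparameterisation: ∂x/∂μ = 1, ∂x/∂σ = z; no tangent for the RNG.
        return NoTangent(), NoTangent(), Δ, Δ * z
    end
    return x, locscale_rand_pullback
end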

@andreasnoack
Member

andreasnoack commented Jul 7, 2021

But does this generalize beyond location/scale? I don't think e.g. GammaGDSampler is differentiable in the shape parameter. Notice that #1024 is a consequence of the complexity that a parametric struct introduces.

@mschauer
Member

mschauer commented Jul 8, 2021

Every y = model(rng, x) where you observe y and want to know x awakens the desire to take derivatives of model with respect to x; that is not restricted to means and Gaussianity. So if it's possible (implying a differentiable dependence on x for a fixed Random.seed!), you would like to allow it.

@andreasnoack
Member

andreasnoack commented Jul 8, 2021

It would be great if we could figure out a way to handle this that isn't generally broken and only sometimes (even if very often) works. Recall that the GammaGDSampler is currently broken for some inputs because of the type parameters, and it didn't take too much effort to find an example where the variates aren't continuous in the shape parameter:

julia> _seed
165

julia> tmp = [rand(Random.MersenneTwister(_seed), Distributions.GammaGDSampler(Gamma(_s, 0.1))) for _s in s];

[Figure "gammasampler": the variates plotted against the shape parameter, showing a discontinuity]

I generally think you should be very reluctant to allow loose signatures for methods that exploit details about floating point numbers, such as their precision.

I'm wondering if, instead, we could define AD rules for rand methods. They could (hopefully?) be restricted to the parameters for which we know the variates are differentiable given a seed. For the Gamma it might actually be a problem that the distribution only has a single type parameter, since it excludes the possibility of restricting Duals to the scale parameter.

@devmotion
Member

I'm wondering if instead, we could define AD rules for rand methods.

This is what I had in mind when I said that it might be better to add custom adjoints/AD rules instead of trying to make the sampling algorithm itself differentiable.

For the Gamma it might actually be a problem the distribution only has a single type parameter since it excludes the possibility of restricting Duals to the scale parameter.

This would only be a problem for AD systems that operate with special number types such as Dual but not for e.g. Zygote.
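
A toy illustration of that difference (unrelated to the Distributions.jl internals): a method restricted to Float64 cannot accept ForwardDiff's Dual numbers, but Zygote differentiates it without changing the input type.

using Zygote

g(x::Float64) = 3.0 * x^2        # signature excludes Dual numbers

Zygote.gradient(g, 2.0)          # (12.0,): works, the primal stays Float64
# ForwardDiff.derivative(g, 2.0) # would throw a MethodError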

@mschauer
Member

mschauer commented Jul 8, 2021

That single type parameter shouldn’t pose a problem, one can promote the other parameter to dual too.

@andreasnoack
Member

That single type parameter shouldn’t pose a problem, one can promote the other parameter to dual too.

The point I tried to make was that you'd have to restrict the argument type for the shape parameter to Float64 and not allow Duals.

@mschauer
Member

mschauer commented Jul 8, 2021

It's not differentiable in the shape, but it can run on Duals with partials equal to 0, no? I don't think it is the responsibility of a package to prevent Dual inputs to non-differentiable functions.

@andreasnoack
Member

andreasnoack commented Jul 9, 2021

Let me elaborate: roughly speaking, there are two kinds of floating point methods:

  1. "Core", such as (+)(::Float64, ::Float64), exp(::Float32), and exp(::Float64).
  2. "Compositions", such as logistic(::Real), norm(::Vector{<:Number}), and most user-defined methods.

The first group exploits details of the computer representation of the numbers, such as calling an LLVM intrinsic or a libm function with a specific precision. However, the group also includes native Julia implementations that evaluate series to a specific precision, so many of the definitions in SpecialFunctions fall into this group. I'm arguing that some/most samplers are part of this group as well. The GammaGDSampler is such an example; see e.g.

if a <= 3.686
    b = 0.463 + s + 0.178s2
    σ = 1.235
    c = 0.195/s - 0.079 + 0.16s
elseif a <= 13.022
    b = 1.654 + 0.0076s2
    σ = 1.68/s + 0.275
    c = 0.062/s + 0.024
else
    b = 1.77
    σ = 0.75
    c = 0.1515/s
end
and
return s.q0 + 0.5*t*t*(v*@horner(v,
    0.333333333,
    -0.249999949,
    0.199999867,
    -0.1666774828,
    0.142873973,
    -0.124385581,
    0.110368310,
    -0.112750886,
    0.10408986))

My argument is that whenever methods evaluate to a specific precision like that, we'd have to restrict the signature to ensure correctness.

The second group is composed of "Core" methods, and I completely agree that such definitions should have as loose a signature as possible, to allow for as many number types as possible. Regarding AD, we need rules for the "Core" group for AD to work, and the beauty is that AD then automatically works for the second group, provided we have used sufficiently loose signatures in the method definitions.

What we are currently doing is treating the sampler as a "Composition" method. I'm arguing that this isn't sound and that we'd have to make it a "Core" method and define AD rules for it. Specifically, we only need to consider the version for scale == 1 as the "Core" method. The scaling can be handled as a "Composition", which is why I said that it might be better to split the type parameter in two.
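
As a sketch of that split (hypothetical names, not the actual Distributions.jl API): the unit-scale sampler plays the role of the precision-specific "Core" method with a restricted signature, and the scaling is a generic "Composition" that stays differentiable in the scale.

using Distributions, Random

# "Core": precision-specific sampling of Gamma(α, 1); restricted to Float64
# and a candidate for a hand-written AD rule.
rand_unit_gamma(rng::AbstractRNG, α::Float64) = rand(rng, Gamma(α, 1.0))

# "Composition": generic in the scale θ, so AD can flow through the scaling
# without touching the sampler's internals.
rand_gamma(rng::AbstractRNG, α::Float64, θ::Real) = θ * rand_unit_gamma(rng, α)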

@mschauer
Member

mschauer commented Jul 9, 2021

I think we are on the same page, and I agree with your line of reasoning; only, in isolation I wouldn't restrict

julia> f(v) = (v*@horner(v,
           0.333333333,
           -0.249999949,
           0.199999867,
           -0.1666774828,
           0.142873973,
           -0.124385581,
           0.110368310,
           -0.112750886,
           0.10408986))

to Float64 because it does exactly the right thing for e.g.

julia> using IntervalArithmetic

julia> f(1.0..(1+eps()))
[0.236851, 0.236852]

at the right precision.
