Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use fill adjoint from ChainRules #202

Closed

Conversation

DhairyaLGandhi
Copy link

There might be reasons to overload it here, but the current implementation would override the generic fill adjoint, which can cause breakages. If you have an MWE, we can write a more targeted adjoint, and I 100% agree that our definitions should not assume numbers as eltypes of arrays. Presently, this causes SciML/NeuralPDE.jl#412

cc @ChrisRackauckas

@DhairyaLGandhi DhairyaLGandhi changed the title use fill from ChainRules use fill adjoint from ChainRules Oct 20, 2021
Copy link

@ChrisRackauckas ChrisRackauckas left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At least, this isn't where this definition should live.

@DhairyaLGandhi
Copy link
Author

Agreed, this also assumes pullback(x, dims::Int...) ... end and fails if one passes dims::Tuple. That's basically what caused the NeuralPDE breakage.

@devmotion
Copy link
Member

Haha yeah this whole package is just full of type piracy and hacks that were accumulated over time. The more we can remove or move to the packages it actually belongs to, the better.

Let's run the tests and see if it breaks something (fingers crossed that they actually pass currently, it seems recent Zygote upgrades broke tests in DynamicPPL...).

@DhairyaLGandhi
Copy link
Author

I would suggest fixing the errant adjoint anyway since it can't handle many other cases of interest. xref JuliaDiff/ChainRules.jl#537 so I suspect the tests wouldn't pass. We can of course handle it with a different rule.

@devmotion
Copy link
Member

Even though I agree it really should not be part of DistributionsAD we can't just remove it if this in turn breaks Turing or other downstream packages.

Agreed, this also assumes pullback(x, dims::Int...) ... end and fails if one passes dims::Tuple. That's basically what caused the NeuralPDE breakage.

Then I guess a temporary workaround would be to just fix the dims::Tuple case and handle it correctly in the adjoint.

@devmotion
Copy link
Member

So it seems some recent changes (last tests on master passed end of August with ReverseDiff and Tracker) in some upstream dependencies broke Tracker and ReverseDiff support. However, also Zygote errors, before tests were aborted mainly due to errors such as https://github.com/TuringLang/DistributionsAD.jl/pull/202/checks?check_run_id=3956438041#step:5:364 when testing arrays of distributions which potentially could be caused by the removal of the fill adjoint.

@ChrisRackauckas
Copy link

Then I guess a temporary workaround would be to just fix the dims::Tuple case and handle it correctly in the adjoint.

A better thing would be to narrow this dispatch. Which case is it actually fitting? Making it Any is clearly incorrect.

@devmotion
Copy link
Member

I assume it can be restricted to e.g. fill(d::Distribution, dims::Int, dims2::Int...) and possibly fill(d::Distribution, dims::Tuple{Int,Vararg{Int}}) (if we have to handle tuples) since it is used mainly to make AD work with the filldist and arraydist product distributions. This is still type piracy and should not exist here but at least better than the current implementation. It was added in #19 originally.

@mohamed82008
Copy link
Member

This adjoint wasn't exactly one of my finest works. I agree it was a horrible idea in retrospect.

@ChrisRackauckas
Copy link

I think restrict it and merge, and then upstream the fix later

@devmotion
Copy link
Member

I wonder if the adjoint is only needed in the tests due to

f_arraydist =...,) -> arraydist(fill(d.f...), n...))
and similar lines. I can't find any occurrences of fill(::Distribution, dims...) in the package, filldist always uses FillArrays.Fill and the implementation of arraydist does not use fill (unsurprisingly). So maybe we can just move the adjoint to the tests? Should still improve it a bit and only cover Distributions (if it's only needed in the tests we don't have to handle Tuples).

@ChrisRackauckas
Copy link

I think adding it to the tests is good. SciML/NeuralPDE.jl#412 is showing pretty ample evidence that this adjoint is pretty breaking downstream, so its removal is at least not bad. @DhairyaLGandhi update the PR?

@mcabbott
Copy link

f_arraydist = (θ...,) -> arraydist(fill(d.f(θ...), n...))

Is there a MWE of the problem this causes? I got lost in the tests here.

(I see that CI complains with this PR that some results are wrong, but they don't give errors.)

If the definition in ChainRules is not correct, then we should fix it, as it may cause problems we haven't thought of elsewhere.

@devmotion
Copy link
Member

There's already an issue regarding fill with non-numbers: JuliaDiff/ChainRules.jl#537

Yeah, it's a bit unclear currently what test failures are actually caused by this PR and what by changes in upstream packages such as Zygote, ChainRules, ReverseDiff, Distributions etc. since also some Tracker and ReverseDiff tests error that passed on the master branch the last time they were run. I guess I complain too much in this issue and about AD in general lately but this package and in general AD support is just a mess and immensely time consuming to maintain. Nothing changed in this package but now different things are broken 🤷

@ChrisRackauckas
Copy link

ChrisRackauckas commented Oct 21, 2021

IMO, the tests here should move to ChainRules, or it should get downstream tested (@oxinabox). These are all pretty core and it should be an issue if they are broken, and the solution shouldn't be type piracy fixes. We can slowly move to fix all of that, but for now can we at least remove the known incorrect adjoint 😅.

@DhairyaLGandhi
Copy link
Author

Moving it to tests would still define it in the tests, so it could make the DistributionsAD tests brittle indirectly. Either way, I want to see what CI says. I don't know if downstream testing is sufficient (better to do it than not, for sure). @devmotion would you mind triggering CI

@devmotion
Copy link
Member

devmotion commented Oct 21, 2021

There's only a single ChainRules definition left, everything else I already moved to Distributions and StatsFuns. Therefore I don't think ChainRules can help with running tests. The main problem are

  • the fixes and workarounds for Zygote, Tracker, and ReverseDiff which should be moved to the respective packages if they are needed and useful and
  • the alternative AD-friendlier distributions such as TuringUniform, TuringMvNormal etc. which ideally should not be needed if the originals in Distributions are made AD-friendlier (currently used to intercept calls to e.g. MvNormal or Uniform which is another source of type piracy).

@devmotion
Copy link
Member

Either way, I want to see what CI says.

It will still fail, I checked it locally some minutes ago.

@devmotion
Copy link
Member

Closed in favour of #203 which contains some additional fixes that seem sufficient for tests to pass locally.

@devmotion devmotion closed this Oct 21, 2021
@DhairyaLGandhi
Copy link
Author

Okay then the answer is to use a restricted definition. +1 to make the regular distributions AD-able.

I agree that our adjoints should not be assuming too much about what specific arguments are passed to them.

@DhairyaLGandhi DhairyaLGandhi deleted the dg/neuraladapter branch October 21, 2021 14:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants