-
Notifications
You must be signed in to change notification settings - Fork 31
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
use fill
adjoint from ChainRules
#202
Conversation
fill
from ChainRulesfill
adjoint from ChainRules
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
At least, this isn't where this definition should live.
Agreed, this also assumes |
Haha yeah this whole package is just full of type piracy and hacks that were accumulated over time. The more we can remove or move to the packages it actually belongs to, the better. Let's run the tests and see if it breaks something (fingers crossed that they actually pass currently, it seems recent Zygote upgrades broke tests in DynamicPPL...). |
I would suggest fixing the errant adjoint anyway since it can't handle many other cases of interest. xref JuliaDiff/ChainRules.jl#537 so I suspect the tests wouldn't pass. We can of course handle it with a different rule. |
Even though I agree it really should not be part of DistributionsAD we can't just remove it if this in turn breaks Turing or other downstream packages.
Then I guess a temporary workaround would be to just fix the |
So it seems some recent changes (last tests on master passed end of August with ReverseDiff and Tracker) in some upstream dependencies broke Tracker and ReverseDiff support. However, also Zygote errors, before tests were aborted mainly due to errors such as https://github.com/TuringLang/DistributionsAD.jl/pull/202/checks?check_run_id=3956438041#step:5:364 when testing arrays of distributions which potentially could be caused by the removal of the |
A better thing would be to narrow this dispatch. Which case is it actually fitting? Making it |
I assume it can be restricted to e.g. |
This adjoint wasn't exactly one of my finest works. I agree it was a horrible idea in retrospect. |
I think restrict it and merge, and then upstream the fix later |
I wonder if the adjoint is only needed in the tests due to DistributionsAD.jl/test/ad/distributions.jl Line 477 in 44a57e9
fill(::Distribution, dims...) in the package, filldist always uses FillArrays.Fill and the implementation of arraydist does not use fill (unsurprisingly). So maybe we can just move the adjoint to the tests? Should still improve it a bit and only cover Distribution s (if it's only needed in the tests we don't have to handle Tuple s).
|
I think adding it to the tests is good. SciML/NeuralPDE.jl#412 is showing pretty ample evidence that this adjoint is pretty breaking downstream, so its removal is at least not bad. @DhairyaLGandhi update the PR? |
Is there a MWE of the problem this causes? I got lost in the tests here. (I see that CI complains with this PR that some results are wrong, but they don't give errors.) If the definition in ChainRules is not correct, then we should fix it, as it may cause problems we haven't thought of elsewhere. |
There's already an issue regarding Yeah, it's a bit unclear currently what test failures are actually caused by this PR and what by changes in upstream packages such as Zygote, ChainRules, ReverseDiff, Distributions etc. since also some Tracker and ReverseDiff tests error that passed on the master branch the last time they were run. I guess I complain too much in this issue and about AD in general lately but this package and in general AD support is just a mess and immensely time consuming to maintain. Nothing changed in this package but now different things are broken 🤷 |
IMO, the tests here should move to ChainRules, or it should get downstream tested (@oxinabox). These are all pretty core and it should be an issue if they are broken, and the solution shouldn't be type piracy fixes. We can slowly move to fix all of that, but for now can we at least remove the known incorrect adjoint 😅. |
Moving it to tests would still define it in the tests, so it could make the DistributionsAD tests brittle indirectly. Either way, I want to see what CI says. I don't know if downstream testing is sufficient (better to do it than not, for sure). @devmotion would you mind triggering CI |
There's only a single ChainRules definition left, everything else I already moved to Distributions and StatsFuns. Therefore I don't think ChainRules can help with running tests. The main problem are
|
It will still fail, I checked it locally some minutes ago. |
Closed in favour of #203 which contains some additional fixes that seem sufficient for tests to pass locally. |
Okay then the answer is to use a restricted definition. +1 to make the regular distributions AD-able. I agree that our adjoints should not be assuming too much about what specific arguments are passed to them. |
There might be reasons to overload it here, but the current implementation would override the generic
fill
adjoint, which can cause breakages. If you have an MWE, we can write a more targeted adjoint, and I 100% agree that our definitions should not assume numbers as eltypes of arrays. Presently, this causes SciML/NeuralPDE.jl#412cc @ChrisRackauckas