
Replace internal AD backend types with ADTypes #2047

Merged: 15 commits merged into master, Nov 16, 2023
Conversation

devmotion
Member

This PR is a draft proposal for replacing our internal AD backend types with ADTypes.

Needs support of ADTypes in LogDensityProblemsAD: tpapp/LogDensityProblemsAD.jl#17
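For context, the change maps Turing's old internal AD backend types onto the corresponding ADTypes.jl singletons. A minimal sketch of the correspondence (the old `Turing.Essential` names are the pre-PR internal API; the exact mapping shown here is illustrative, not the final implementation):

```julia
using ADTypes

# Turing.Essential.ForwardDiffAD{chunksize}()  ->
adtype_fd = ADTypes.AutoForwardDiff(; chunksize = 0)

# Turing.Essential.ReverseDiffAD{compile}()    ->
adtype_rd = ADTypes.AutoReverseDiff(true)   # `true` requests a compiled tape

# Turing.Essential.ZygoteAD()                  ->
adtype_zy = ADTypes.AutoZygote()
```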

@torfjelde
Member

Love it!

@torfjelde
Member

We'll also need to make equivalent changes in AdvancedVI, I believe.

@yebai
Member

yebai commented Jul 19, 2023

We'll also need to make equivalent changes in AdvancedVI, I believe.

I think @Red-Portal already did it in the AdvancedVI rewrite PR.

@Red-Portal
Member

Hi, yes that's already up to date!

@devmotion
Member Author

I opened tpapp/LogDensityProblemsAD.jl#21.

Comment on lines +49 to +52
AutoForwardDiff,
AutoTracker,
AutoZygote,
AutoReverseDiff,
Member Author

Do we actually want to export these types? Or should we tell users to use ADTypes.AutoForwardDiff etc. (in particular when these types would be used in other packages such as e.g. AdvancedVI as well).
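The two options under discussion look like this in user code (illustrative):

```julia
# Option 1: Turing re-exports the ADTypes singletons, so users can write
# `AutoForwardDiff()` directly after `using Turing`.

# Option 2: no re-export; users depend on ADTypes and qualify the name:
using ADTypes
adtype = ADTypes.AutoForwardDiff()
```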

- f_rd = LogDensityProblemsAD.ADgradient(Turing.Essential.ReverseDiffAD{false}(), f)
- f_rd_compiled = LogDensityProblemsAD.ADgradient(Turing.Essential.ReverseDiffAD{true}(), f)
+ f_rd = LogDensityProblemsAD.ADgradient(Turing.Essential.ReverseDiffAD(false), f)
+ f_rd_compiled = LogDensityProblemsAD.ADgradient(Val(:ReverseDiff), f; compile=Val(true), x=θ) # need to compile with non-zero inputs
Member

@devmotion I had to pass in θ for the result to be correct, started a PR tpapp/LogDensityProblemsAD.jl#22
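The need to pass θ stems from how ReverseDiff's compiled tapes work: a tape records the operations performed for the specific input it was compiled with, so any control flow taken at a zero input gets frozen into the tape. A small standalone sketch with ReverseDiff.jl:

```julia
using ReverseDiff

# The branch below depends on the input value; a compiled tape permanently
# follows whichever branch was taken at compile time.
f(x) = x[1] > 0 ? sum(abs2, x) : sum(x)

# Compile with a representative (non-zero) input:
tape = ReverseDiff.compile(ReverseDiff.GradientTape(f, [1.0, 2.0]))

# Correct here, because [3.0, 4.0] takes the same branch as the compile input;
# it would be silently wrong for inputs taking the other branch, e.g. [-1.0, 0.0].
ReverseDiff.gradient!(zeros(2), tape, [3.0, 4.0])
```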

Member Author

I think we should not add a keyword argument but just continue overloading ADgradient for TuringLogDensityFunction. I wouldn't want users to have to deal with such a keyword argument.

  ) where AD
-     return HMC{AD}(ϵ, n_leapfrog, metricT, space)
+     return HMC(ϵ, n_leapfrog, metricT, space; adtype = adtype)
Member

@devmotion, we can probably consider removing the global AD flag ADBACKEND, and always specify the autodiff backend in inference algorithms. What are your thoughts?

Member Author

And set the default AD type for each algorithm/sampler to a specific default such as AutoForwardDiff, you mean? Or would you like users to always specify the AD type explicitly?

Member

Yes, we can set the default AD type to AutoForwardDiff for all algorithms and allow users to override them via keyword arguments. That way, we don't need to maintain a global AD flag and can remove the messy code around it. But that should probably be done in a separate PR.
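A hypothetical sketch of that design, where each sampler carries its AD backend as a field instead of consulting a global `ADBACKEND` flag (`MySampler` and its fields are stand-ins, not Turing's actual types):

```julia
using ADTypes: AbstractADType, AutoForwardDiff, AutoReverseDiff

# Hypothetical sampler storing its own AD backend:
struct MySampler
    ϵ::Float64
    adtype::AbstractADType
end
MySampler(ϵ; adtype::AbstractADType = AutoForwardDiff()) = MySampler(ϵ, adtype)

MySampler(0.05)                              # ForwardDiff by default
MySampler(0.05; adtype = AutoReverseDiff())  # explicit per-sampler override
```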

Member Author

I would like to get rid of the global flag 👍 If we make it in a separate PR we should maybe wait with tagging a breaking release until this follow-up PR is merged as well, to avoid two breaking releases in a row.

Contributor

Pull Request Test Coverage Report for Build 6887921533

  • 0 of 41 (0.0%) changed or added relevant lines in 5 files are covered.
  • 1 unchanged line in 1 file lost coverage.
  • Overall coverage remained the same at 0.0%

File                        Covered lines  Changed/Added lines  %
src/mcmc/Inference.jl       0              1                    0.0%
ext/TuringDynamicHMCExt.jl  0              2                    0.0%
src/mcmc/sghmc.jl           0              4                    0.0%
src/essential/ad.jl         0              12                   0.0%
src/mcmc/hmc.jl             0              22                   0.0%

File with coverage reduction    New missed lines  %
ext/TuringDynamicHMCExt.jl      1                 0.0%

Totals: change from base Build 6829598411: 0.0%; covered lines: 0; relevant lines: 1421

💛 - Coveralls


codecov bot commented Nov 16, 2023

Codecov Report

Attention: 41 lines in your changes are missing coverage. Please review.

Comparison is base (d4a7975) 0.00% compared to head (da44611) 0.00%.

File                        Patch %  Missing lines
src/mcmc/hmc.jl             0.00%    22 ⚠️
src/essential/ad.jl         0.00%    12 ⚠️
src/mcmc/sghmc.jl           0.00%    4 ⚠️
ext/TuringDynamicHMCExt.jl  0.00%    2 ⚠️
src/mcmc/Inference.jl       0.00%    1 ⚠️
Additional details and impacted files
@@          Coverage Diff           @@
##           master   #2047   +/-   ##
======================================
  Coverage    0.00%   0.00%           
======================================
  Files          21      21           
  Lines        1435    1421   -14     
======================================
+ Misses       1435    1421   -14     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

kwargs...
) where AD
return HMC{AD}(ϵ, n_leapfrog; kwargs...)
function HMC(ϵ::Float64, n_leapfrog::Int, ::Type{metricT}, space::Tuple; adtype::ADTypes.AbstractADType = ADBackend()) where {metricT <: AHMC.AbstractMetric}
Member Author

Should we remove this constructor and only support metricT as keyword argument? Or make all arguments keyword arguments?

Generally, these HMC constructors are quite messy...

Member

@yebai yebai Nov 16, 2023


We plan to deprecate and then remove this old interface once the AbstractMCMC-based externalsampler interface works for Gibbs.

@yebai
Member

yebai commented Nov 16, 2023

CI errors about chain resume/save are unrelated to this PR.

@yebai yebai marked this pull request as ready for review November 16, 2023 11:39
Member

@yebai yebai left a comment


Thanks @sunxd3 @devmotion -- good work!

@yebai yebai merged commit 6649f10 into master Nov 16, 2023
11 of 13 checks passed
@yebai yebai deleted the dw/adtypes branch November 16, 2023 12:23
yebai added a commit to TuringLang/docs that referenced this pull request Dec 19, 2023
* Update autodiff.jmd following adaptation of `ADTypes`

TuringLang/Turing.jl#2047

* Update tutorials/docs-10-using-turing-autodiff/autodiff.jmd

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Using `ADTypes` for ad doc

* Update tutorials/docs-10-using-turing-autodiff/autodiff.jmd

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update tutorials/docs-10-using-turing-autodiff/autodiff.jmd

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Hong Ge <[email protected]>

* Update tutorials/docs-10-using-turing-autodiff/autodiff.jmd

* Update tutorials/docs-10-using-turing-autodiff/autodiff.jmd

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Apply suggestions from code review

* Update tutorials/docs-10-using-turing-autodiff/autodiff.jmd

* Update tutorials/docs-10-using-turing-autodiff/autodiff.jmd

* Update tutorials/docs-10-using-turing-autodiff/autodiff.jmd

* Update tutorials/docs-10-using-turing-autodiff/autodiff.jmd

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Hong Ge <[email protected]>
@storopoli
Member

So the tape gets compiled by default now when using ReverseDiff?
Is Turing.setrdcache(true) no longer needed?

@yebai
Member

yebai commented Dec 20, 2023

You need to pass AutoReverseDiff(true) (true enables compiled tape, false disables it) to individual sampling algorithms, see more at https://github.com/TuringLang/Turing.jl/blob/master/HISTORY.md
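For example, assuming the `adtype` keyword on `HMC` shown in this PR's diff (the model here is an arbitrary placeholder):

```julia
using Turing
using ADTypes: AutoReverseDiff

@model function demo(y)
    μ ~ Normal(0, 1)
    y ~ Normal(μ, 1)
end

# Compiled tape (previously: global ReverseDiff backend + Turing.setrdcache(true)):
chain = sample(demo(1.5), HMC(0.05, 10; adtype = AutoReverseDiff(true)), 100)

# Plain (non-compiled) ReverseDiff:
chain = sample(demo(1.5), HMC(0.05, 10; adtype = AutoReverseDiff(false)), 100)
```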

@storopoli
Member

Yes, but it's still not clear. The example only has adtype=AutoForwardDiff(; chunksize) in the sampler constructor.
I had no idea that AutoReverseDiff took positional arguments. And why doesn't AutoForwardDiff take any positional arguments?

Also, https://github.com/SciML/ADTypes.jl doesn't have docs, so it is even harder to figure it out.
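For reference, the two constructors simply have different shapes (as of the ADTypes version used here; a quick illustration):

```julia
using ADTypes: AutoForwardDiff, AutoReverseDiff

# AutoForwardDiff configures its chunk size via a keyword argument:
AutoForwardDiff(; chunksize = 4)

# AutoReverseDiff instead takes a positional Bool that toggles tape compilation:
AutoReverseDiff(true)    # compiled tape
AutoReverseDiff(false)   # uncompiled
AutoReverseDiff()        # same as AutoReverseDiff(false)
```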

@yebai
Member

yebai commented Dec 20, 2023

@sunxd3, maybe add a few concrete examples for each popular autodiff backend to HISTORY.md and docstrings?

@sunxd3
Member

sunxd3 commented Dec 20, 2023

Yeah, it is confusing; I'll open a PR.

@storopoli
Member

storopoli commented Dec 20, 2023

I can do that if you want to, at least the docs part. I am hours away from my holiday break.

EDIT: here (https://turinglang.org/v0.30/docs/using-turing/autodiff)
EDIT 2: already done here (TuringLang/docs#430) huh? Oh that's great then.

@sunxd3
Member

sunxd3 commented Dec 20, 2023

@storopoli yep, the tutorial is updated. I just updated the release information too

@ya0 ya0 mentioned this pull request Jan 18, 2024