Check that the correct AD backend is being used #2291
Merged
5 commits
b9da0a6 Add ADTypeCheckContext (mhauru)
17eb8a9 Check ADType use in optimisation (mhauru)
c0cab10 Use ADTypeCheckContext with hmc tests (mhauru)
44dd017 using A: A instead of import A (mhauru)
15701c2 More robust ADTypeCheckContext checks for Zygote (mhauru)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,270 @@ | ||
module ADUtils | ||
|
||
using ForwardDiff: ForwardDiff | ||
using ReverseDiff: ReverseDiff | ||
using Test: Test | ||
using Tracker: Tracker | ||
using Turing: Turing | ||
using Turing: DynamicPPL | ||
using Zygote: Zygote | ||
|
||
export ADTypeCheckContext | ||
|
||
"""Element types that are always valid for a VarInfo regardless of ADType.""" | ||
const always_valid_eltypes = (AbstractFloat, AbstractIrrational, Integer, Rational) | ||
|
||
"""A dictionary mapping ADTypes to the element types they use.""" | ||
const eltypes_by_adtype = Dict( | ||
Turing.AutoForwardDiff => (ForwardDiff.Dual,), | ||
Turing.AutoReverseDiff => ( | ||
ReverseDiff.TrackedArray, | ||
ReverseDiff.TrackedMatrix, | ||
ReverseDiff.TrackedReal, | ||
ReverseDiff.TrackedStyle, | ||
ReverseDiff.TrackedType, | ||
ReverseDiff.TrackedVecOrMat, | ||
ReverseDiff.TrackedVector, | ||
), | ||
# Zygote.Dual is actually the same as ForwardDiff.Dual, so can't distinguish between the | ||
# two by element type. However, we have other checks for Zygote, see check_adtype. | ||
Turing.AutoZygote => (Zygote.Dual,), | ||
Turing.AutoTracker => ( | ||
Tracker.Tracked, | ||
Tracker.TrackedArray, | ||
Tracker.TrackedMatrix, | ||
Tracker.TrackedReal, | ||
Tracker.TrackedStyle, | ||
Tracker.TrackedVecOrMat, | ||
Tracker.TrackedVector, | ||
), | ||
) | ||
|
||
""" | ||
AbstractWrongADBackendError | ||
|
||
An abstract error thrown when we seem to be using a different AD backend than expected. | ||
""" | ||
abstract type AbstractWrongADBackendError <: Exception end | ||
|
||
""" | ||
WrongADBackendError | ||
|
||
An error thrown when we seem to be using a different AD backend than expected. | ||
""" | ||
struct WrongADBackendError <: AbstractWrongADBackendError | ||
actual_adtype::Type | ||
expected_adtype::Type | ||
end | ||
|
||
function Base.showerror(io::IO, e::WrongADBackendError) | ||
return print( | ||
io, "Expected to use $(e.expected_adtype), but using $(e.actual_adtype) instead." | ||
) | ||
end | ||
|
||
""" | ||
IncompatibleADTypeError | ||
|
||
An error thrown when an element type is encountered that is unexpected for the given ADType. | ||
""" | ||
struct IncompatibleADTypeError <: AbstractWrongADBackendError | ||
valtype::Type | ||
adtype::Type | ||
end | ||
|
||
function Base.showerror(io::IO, e::IncompatibleADTypeError) | ||
return print( | ||
io, | ||
"Incompatible ADType: Did not expect element of type $(e.valtype) with $(e.adtype)", | ||
) | ||
end | ||
|
||
""" | ||
ADTypeCheckContext{ADType,ChildContext} | ||
|
||
A context for checking that the expected ADType is being used. | ||
|
||
Evaluating a model with this context will check that the types of values in a `VarInfo` are | ||
compatible with the ADType of the context. If the check fails, an `IncompatibleADTypeError` | ||
is thrown. | ||
|
||
For instance, evaluating a model with | ||
`ADTypeCheckContext(AutoForwardDiff(), child_context)` | ||
would throw an error if within the model a type associated with e.g. ReverseDiff was | ||
encountered. | ||
|
||
As a current shortcoming, this context cannot detect the case where ForwardDiff is used | ||
although Zygote was expected; see `check_adtype`. | ||
""" | ||
struct ADTypeCheckContext{ADType,ChildContext<:DynamicPPL.AbstractContext} <: | ||
DynamicPPL.AbstractContext | ||
child::ChildContext | ||
|
||
function ADTypeCheckContext(adbackend, child) | ||
adtype = adbackend isa Type ? adbackend : typeof(adbackend) | ||
if !any(adtype <: k for k in keys(eltypes_by_adtype)) | ||
throw(ArgumentError("Unsupported ADType: $adtype")) | ||
end | ||
return new{adtype,typeof(child)}(child) | ||
end | ||
end | ||
|
||
adtype(_::ADTypeCheckContext{ADType}) where {ADType} = ADType | ||
|
||
DynamicPPL.NodeTrait(::ADTypeCheckContext) = DynamicPPL.IsParent() | ||
DynamicPPL.childcontext(c::ADTypeCheckContext) = c.child | ||
function DynamicPPL.setchildcontext(c::ADTypeCheckContext, child) | ||
return ADTypeCheckContext(adtype(c), child) | ||
end | ||
|
||
""" | ||
valid_eltypes(context::ADTypeCheckContext) | ||
|
||
Return the element types that are valid for the ADType of `context` as a tuple. | ||
""" | ||
function valid_eltypes(context::ADTypeCheckContext) | ||
context_at = adtype(context) | ||
for at in keys(eltypes_by_adtype) | ||
if context_at <: at | ||
return (eltypes_by_adtype[at]..., always_valid_eltypes...) | ||
end | ||
end | ||
# This should never be reached due to the check in the inner constructor. | ||
throw(ArgumentError("Unsupported ADType: $(adtype(context))")) | ||
end | ||
|
||
""" | ||
check_adtype(context::ADTypeCheckContext, vi::DynamicPPL.AbstractVarInfo) | ||
|
||
Check that the element types in `vi` are compatible with the ADType of `context`. | ||
|
||
When Zygote is being used, we also more explicitly check that `adtype(context)` is | ||
`AutoZygote`. This is because Zygote uses the same element type as ForwardDiff, so we can't | ||
discriminate between the two based on element type alone. This function will still fail to | ||
catch cases where Zygote is supposed to be used, but ForwardDiff is used instead. | ||
|
||
Throw an `IncompatibleADTypeError` if an incompatible element type is encountered, or | ||
`WrongADBackendError` if Zygote is used unexpectedly. | ||
""" | ||
function check_adtype(context::ADTypeCheckContext, vi::DynamicPPL.AbstractVarInfo) | ||
Zygote.hook(vi) do _ | ||
if !(adtype(context) <: Turing.AutoZygote) | ||
throw(WrongADBackendError(Turing.AutoZygote, adtype(context))) | ||
end | ||
end | ||
|
||
valids = valid_eltypes(context) | ||
for val in vi[:] | ||
valtype = typeof(val) | ||
if !any(valtype .<: valids) | ||
throw(IncompatibleADTypeError(valtype, adtype(context))) | ||
end | ||
end | ||
return nothing | ||
end | ||
|
||
# A bunch of tilde_assume/tilde_observe methods that just call the same method on the child | ||
# context, and then call check_adtype on the result before returning the results from the | ||
# child context. | ||
|
||
function DynamicPPL.tilde_assume(context::ADTypeCheckContext, right, vn, vi) | ||
value, logp, vi = DynamicPPL.tilde_assume( | ||
DynamicPPL.childcontext(context), right, vn, vi | ||
) | ||
check_adtype(context, vi) | ||
return value, logp, vi | ||
end | ||
|
||
function DynamicPPL.tilde_assume(rng, context::ADTypeCheckContext, sampler, right, vn, vi) | ||
value, logp, vi = DynamicPPL.tilde_assume( | ||
rng, DynamicPPL.childcontext(context), sampler, right, vn, vi | ||
) | ||
check_adtype(context, vi) | ||
return value, logp, vi | ||
end | ||
|
||
function DynamicPPL.tilde_observe(context::ADTypeCheckContext, right, left, vi) | ||
logp, vi = DynamicPPL.tilde_observe(DynamicPPL.childcontext(context), right, left, vi) | ||
check_adtype(context, vi) | ||
return logp, vi | ||
end | ||
|
||
function DynamicPPL.tilde_observe(context::ADTypeCheckContext, sampler, right, left, vi) | ||
logp, vi = DynamicPPL.tilde_observe( | ||
DynamicPPL.childcontext(context), sampler, right, left, vi | ||
) | ||
check_adtype(context, vi) | ||
return logp, vi | ||
end | ||
|
||
function DynamicPPL.dot_tilde_assume(context::ADTypeCheckContext, right, left, vn, vi) | ||
value, logp, vi = DynamicPPL.dot_tilde_assume( | ||
DynamicPPL.childcontext(context), right, left, vn, vi | ||
) | ||
check_adtype(context, vi) | ||
return value, logp, vi | ||
end | ||
|
||
function DynamicPPL.dot_tilde_assume( | ||
rng, context::ADTypeCheckContext, sampler, right, left, vn, vi | ||
) | ||
value, logp, vi = DynamicPPL.dot_tilde_assume( | ||
rng, DynamicPPL.childcontext(context), sampler, right, left, vn, vi | ||
) | ||
check_adtype(context, vi) | ||
return value, logp, vi | ||
end | ||
|
||
function DynamicPPL.dot_tilde_observe(context::ADTypeCheckContext, right, left, vi) | ||
logp, vi = DynamicPPL.dot_tilde_observe( | ||
DynamicPPL.childcontext(context), right, left, vi | ||
) | ||
check_adtype(context, vi) | ||
return logp, vi | ||
end | ||
|
||
function DynamicPPL.dot_tilde_observe(context::ADTypeCheckContext, sampler, right, left, vi) | ||
logp, vi = DynamicPPL.dot_tilde_observe( | ||
DynamicPPL.childcontext(context), sampler, right, left, vi | ||
) | ||
check_adtype(context, vi) | ||
return logp, vi | ||
end | ||
|
||
# Check that the ADTypeCheckContext works as expected. | ||
Test.@testset "ADTypeCheckContext" begin | ||
Turing.@model test_model() = x ~ Turing.Normal(0, 1) | ||
tm = test_model() | ||
adtypes = ( | ||
Turing.AutoForwardDiff(), | ||
Turing.AutoReverseDiff(), | ||
Turing.AutoZygote(), | ||
Turing.AutoTracker(), | ||
) | ||
for actual_adtype in adtypes | ||
sampler = Turing.HMC(0.1, 5; adtype=actual_adtype) | ||
for expected_adtype in adtypes | ||
if ( | ||
actual_adtype == Turing.AutoForwardDiff() && | ||
expected_adtype == Turing.AutoZygote() | ||
) | ||
# TODO(mhauru) We are currently unable to check this case. | ||
continue | ||
end | ||
contextualised_tm = DynamicPPL.contextualize( | ||
tm, ADTypeCheckContext(expected_adtype, tm.context) | ||
) | ||
Test.@testset "Expected: $expected_adtype, Actual: $actual_adtype" begin | ||
if actual_adtype == expected_adtype | ||
# Check that this does not throw an error. | ||
Turing.sample(contextualised_tm, sampler, 2) | ||
else | ||
Test.@test_throws AbstractWrongADBackendError Turing.sample( | ||
contextualised_tm, sampler, 2 | ||
) | ||
end | ||
end | ||
end | ||
end | ||
end | ||
|
||
end |
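The element-type check in `check_adtype` can be tried out in isolation. The following is a minimal sketch in plain Julia, not the code from this PR: `FakeDual`, `FakeTracked`, the `:forward`/`:reverse` labels, `IncompatibleEltypeError` and `check_eltypes` are hypothetical stand-ins for the real AD number types, ADTypes and `VarInfo` handling.

```julia
# Hypothetical stand-ins for backend-specific number types
# (not the real ForwardDiff.Dual or ReverseDiff.TrackedReal).
struct FakeDual
    value::Float64
end
struct FakeTracked
    value::Float64
end

# Types accepted under every backend, mirroring `always_valid_eltypes`.
const ALWAYS_VALID = (AbstractFloat, AbstractIrrational, Integer, Rational)

# Hypothetical backend labels mapped to the element types they may introduce.
const ELTYPES_BY_BACKEND = Dict(
    :forward => (FakeDual,),
    :reverse => (FakeTracked,),
)

# Hypothetical error type playing the role of `IncompatibleADTypeError`.
struct IncompatibleEltypeError <: Exception
    valtype::Type
    backend::Symbol
end

valid_eltypes(backend::Symbol) = (ELTYPES_BY_BACKEND[backend]..., ALWAYS_VALID...)

function check_eltypes(backend::Symbol, values)
    valids = valid_eltypes(backend)
    for val in values
        valtype = typeof(val)
        # Broadcasting `<:` over the tuple tests `valtype` against each allowed type.
        if !any(valtype .<: valids)
            throw(IncompatibleEltypeError(valtype, backend))
        end
    end
    return nothing
end

# Plain numbers and the backend's own type pass the check.
check_eltypes(:forward, Any[1.0, 2, FakeDual(3.0)])

# A value from a different backend is rejected.
caught = try
    check_eltypes(:forward, Any[FakeTracked(1.0)])
    nothing
catch e
    e
end
```

Like the PR's check, this approach only works when each backend introduces a distinguishable element type, which is why Zygote needs the separate hook-based check.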
Yeah there's no way you can do this here unfortunately 😕
Possibly crazy idea, but it might make sense to specifically overload the adjoint computation for a given backend which doesn't use types, e.g.
? Could do the same with Enzyme.jl and Tapir.jl, both of which would suffer from the same issue
I like this, and implemented it. Turns out there's `Zygote.hook` for exactly these sorts of things.
I thought a bit about switching to doing all checks like this, and giving up on the element type approach, but that seemed like it would get more complicated than is worth at this point. We now catch all cases except if one uses ForwardDiff when Zygote is expected, which seems sufficient to me.
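For readers unfamiliar with `Zygote.hook`, here is a small sketch of the mechanism, assuming Zygote is installed. It is an illustration rather than the PR's code: the function `f` and the `hook_fired` flag are made up for this example. The hook leaves the value untouched in the forward pass and runs only when Zygote computes a gradient, which is what lets `check_adtype` notice that Zygote is in use.

```julia
using Zygote: Zygote

# Hypothetical flag that records whether Zygote's backward pass touched `x`.
const hook_fired = Ref(false)

function f(x)
    # `Zygote.hook` returns `x` unchanged; the do-block is applied to the
    # gradient of `x` only when Zygote differentiates through this call.
    x = Zygote.hook(x) do x̄
        hook_fired[] = true
        return x̄  # pass the gradient through unmodified
    end
    return x^2
end

f(3.0)                          # plain evaluation: the hook does not run
fired_after_plain_call = hook_fired[]

grad = Zygote.gradient(f, 3.0)  # differentiation with Zygote: the hook runs
fired_after_gradient = hook_fired[]
```

In the PR the hook body throws `WrongADBackendError` instead of setting a flag, so the error surfaces only when Zygote differentiates a model whose context expects another backend.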
Very, very nice @mhauru :)