Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check that the correct AD backend is being used #2291

Merged
merged 5 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions test/mcmc/hmc.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module HMCTests

using ..Models: gdemo_default
using ..ADUtils: ADTypeCheckContext
#using ..Models: gdemo
using ..NumericalTests: check_gdemo, check_numerical
using Distributions: Bernoulli, Beta, Categorical, Dirichlet, Normal, Wishart, sample
Expand Down Expand Up @@ -321,6 +322,15 @@ using Turing
# KS will compare the empirical CDFs, which seems like a reasonable thing to do here.
@test pvalue(ApproximateTwoSampleKSTest(vec(results), vec(results_prior))) > 0.001
end

@testset "Check ADType" begin
alg = HMC(0.1, 10; adtype=adbackend)
m = DynamicPPL.contextualize(
gdemo_default, ADTypeCheckContext(adbackend, gdemo_default.context)
)
# These will error if the adbackend being used is not the one set.
sample(rng, m, alg, 10)
end
end

end
14 changes: 13 additions & 1 deletion test/optimisation/Optimisation.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module OptimisationTests

using ..Models: gdemo, gdemo_default
using ..ADUtils: ADTypeCheckContext
using Distributions
using Distributions.FillArrays: Zeros
using DynamicPPL: DynamicPPL
Expand Down Expand Up @@ -140,7 +141,6 @@ using Turing
gdemo_default, OptimizationOptimJL.LBFGS(); initial_params=true_value
)
m3 = maximum_likelihood(gdemo_default, OptimizationOptimJL.Newton())
# TODO(mhauru) How can we check that the adtype is actually AutoReverseDiff?
m4 = maximum_likelihood(
gdemo_default, OptimizationOptimJL.BFGS(); adtype=AutoReverseDiff()
)
Expand Down Expand Up @@ -616,6 +616,18 @@ using Turing
@assert vcat(get_a[:a], get_b[:b]) == result.values.array
@assert get(result, :c) == (; :c => Array{Float64}[])
end

@testset "ADType" begin
Random.seed!(222)
for adbackend in (AutoReverseDiff(), AutoForwardDiff(), AutoTracker())
m = DynamicPPL.contextualize(
gdemo_default, ADTypeCheckContext(adbackend, gdemo_default.context)
)
# These will error if the adbackend being used is not the one set.
maximum_likelihood(m; adtype=adbackend)
maximum_a_posteriori(m; adtype=adbackend)
end
end
end

end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import Turing

include(pkgdir(Turing) * "/test/test_utils/models.jl")
include(pkgdir(Turing) * "/test/test_utils/numerical_tests.jl")
include(pkgdir(Turing) * "/test/test_utils/ad_utils.jl")

Turing.setprogress!(false)

Expand Down
270 changes: 270 additions & 0 deletions test/test_utils/ad_utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
module ADUtils

using ForwardDiff: ForwardDiff
using ReverseDiff: ReverseDiff
using Test: Test
using Tracker: Tracker
using Turing: Turing
using Turing: DynamicPPL
using Zygote: Zygote

export ADTypeCheckContext

"""Element types that are always valid for a VarInfo regardless of ADType."""
const always_valid_eltypes = (AbstractFloat, AbstractIrrational, Integer, Rational)

"""A dictionary mapping ADTypes to the element types they use."""
const eltypes_by_adtype = Dict(
Turing.AutoForwardDiff => (ForwardDiff.Dual,),
Turing.AutoReverseDiff => (
ReverseDiff.TrackedArray,
ReverseDiff.TrackedMatrix,
ReverseDiff.TrackedReal,
ReverseDiff.TrackedStyle,
ReverseDiff.TrackedType,
ReverseDiff.TrackedVecOrMat,
ReverseDiff.TrackedVector,
),
# Zygote.Dual is actually the same as ForwardDiff.Dual, so can't distinguish between the
# two by element type. However, we have other checks for Zygote, see check_adtype.
Turing.AutoZygote => (Zygote.Dual,),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah there's no way you can do this here unfortunately 😕

Possibly crazy idea, but it might make sense to specifically overload the adjoint computation for a given backend which doesn't use types, e.g.

Zygote.@adjoint function check_adtype(...)
    # Should only be hit if we're using Zygote.jl.
    ...
end

? Could do the same with Enzyme.jl and Tapir.jl, both of which would suffer from the same issue

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this, and implemented it. Turns out there's Zygote.hook for exactly these sorts of things.

I thought a bit about switching to doing all checks like this, and giving up on the element type approach, but that seemed like it would get more complicated than is worth at this point. We now catch all cases except if one uses ForwardDiff when Zygote is expected, which seems sufficient to me.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very, very nice @mhauru :)

Turing.AutoTracker => (
Tracker.Tracked,
Tracker.TrackedArray,
Tracker.TrackedMatrix,
Tracker.TrackedReal,
Tracker.TrackedStyle,
Tracker.TrackedVecOrMat,
Tracker.TrackedVector,
),
)

"""
AbstractWrongADBackendError

An abstract error thrown when we seem to be using a different AD backend than expected.
"""
abstract type AbstractWrongADBackendError <: Exception end

"""
WrongADBackendError

An error thrown when we seem to be using a different AD backend than expected.
"""
struct WrongADBackendError <: AbstractWrongADBackendError
actual_adtype::Type
expected_adtype::Type
end

function Base.showerror(io::IO, e::WrongADBackendError)
return print(
io, "Expected to use $(e.expected_adtype), but using $(e.actual_adtype) instead."
)
end

"""
IncompatibleADTypeError

An error thrown when an element type is encountered that is unexpected for the given ADType.
"""
struct IncompatibleADTypeError <: AbstractWrongADBackendError
valtype::Type
adtype::Type
end

function Base.showerror(io::IO, e::IncompatibleADTypeError)
return print(
io,
"Incompatible ADType: Did not expect element of type $(e.valtype) with $(e.adtype)",
)
end

"""
ADTypeCheckContext{ADType,ChildContext}

A context for checking that the expected ADType is being used.

Evaluating a model with this context will check that the types of values in a `VarInfo` are
compatible with the ADType of the context. If the check fails, an `IncompatibleADTypeError`
is thrown.

For instance, evaluating a model with
`ADTypeCheckContext(AutoForwardDiff(), child_context)`
would throw an error if within the model a type associated with e.g. ReverseDiff was
encountered.

As a current short-coming, this context can not distinguish between ForwardDiff and Zygote.
"""
struct ADTypeCheckContext{ADType,ChildContext<:DynamicPPL.AbstractContext} <:
DynamicPPL.AbstractContext
child::ChildContext

function ADTypeCheckContext(adbackend, child)
adtype = adbackend isa Type ? adbackend : typeof(adbackend)
if !any(adtype <: k for k in keys(eltypes_by_adtype))
throw(ArgumentError("Unsupported ADType: $adtype"))
end
return new{adtype,typeof(child)}(child)
end
end

# Recover the expected ADType from the context's first type parameter.
adtype(::ADTypeCheckContext{ADType}) where {ADType} = ADType

# DynamicPPL context interface: this context is a parent wrapping a single child context.
DynamicPPL.NodeTrait(::ADTypeCheckContext) = DynamicPPL.IsParent()
DynamicPPL.childcontext(context::ADTypeCheckContext) = context.child
function DynamicPPL.setchildcontext(context::ADTypeCheckContext, child)
    # Rebuild the wrapper around the new child while keeping the same expected ADType.
    expected = adtype(context)
    return ADTypeCheckContext(expected, child)
end

"""
valid_eltypes(context::ADTypeCheckContext)

Return the element types that are valid for the ADType of `context` as a tuple.
"""
function valid_eltypes(context::ADTypeCheckContext)
context_at = adtype(context)
for at in keys(eltypes_by_adtype)
if context_at <: at
return (eltypes_by_adtype[at]..., always_valid_eltypes...)
end
end
# This should never be reached due to the check in the inner constructor.
throw(ArgumentError("Unsupported ADType: $(adtype(context))"))
end

"""
check_adtype(context::ADTypeCheckContext, vi::DynamicPPL.VarInfo)

Check that the element types in `vi` are compatible with the ADType of `context`.

When Zygote is being used, we also more explicitly check that `adtype(context)` is
`AutoZygote`. This is because Zygote uses the same element type as ForwardDiff, so we can't
discriminate between the two based on element type alone. This function will still fail to
catch cases where Zygote is supposed to be used, but ForwardDiff is used instead.

Throw an `IncompatibleADTypeError` if an incompatible element type is encountered, or
`WrongADBackendError` if Zygote is used unexpectedly.
"""
function check_adtype(context::ADTypeCheckContext, vi::DynamicPPL.AbstractVarInfo)
Zygote.hook(vi) do _
if !(adtype(context) <: Turing.AutoZygote)
throw(WrongADBackendError(Turing.AutoZygote, adtype(context)))
end
end

valids = valid_eltypes(context)
for val in vi[:]
valtype = typeof(val)
if !any(valtype .<: valids)
throw(IncompatibleADTypeError(valtype, adtype(context)))
end
end
return nothing
end

# A bunch of tilde_assume/tilde_observe methods that just call the same method on the child
# context, and then call check_adtype on the result before returning the results from the
# child context.

# Delegate to the child context, then check the element types held by the returned VarInfo.
function DynamicPPL.tilde_assume(context::ADTypeCheckContext, right, vn, vi)
    child = DynamicPPL.childcontext(context)
    value, logp, new_vi = DynamicPPL.tilde_assume(child, right, vn, vi)
    check_adtype(context, new_vi)
    return value, logp, new_vi
end

# Same as above, for the variant that also receives an RNG and a sampler.
function DynamicPPL.tilde_assume(rng, context::ADTypeCheckContext, sampler, right, vn, vi)
    child = DynamicPPL.childcontext(context)
    value, logp, new_vi = DynamicPPL.tilde_assume(rng, child, sampler, right, vn, vi)
    check_adtype(context, new_vi)
    return value, logp, new_vi
end

# Delegate to the child context, then check the element types held by the returned VarInfo.
function DynamicPPL.tilde_observe(context::ADTypeCheckContext, right, left, vi)
    child = DynamicPPL.childcontext(context)
    logp, new_vi = DynamicPPL.tilde_observe(child, right, left, vi)
    check_adtype(context, new_vi)
    return logp, new_vi
end

# Same as above, for the variant that also receives a sampler.
function DynamicPPL.tilde_observe(context::ADTypeCheckContext, sampler, right, left, vi)
    child = DynamicPPL.childcontext(context)
    logp, new_vi = DynamicPPL.tilde_observe(child, sampler, right, left, vi)
    check_adtype(context, new_vi)
    return logp, new_vi
end

# Delegate to the child context, then check the element types held by the returned VarInfo.
function DynamicPPL.dot_tilde_assume(context::ADTypeCheckContext, right, left, vn, vi)
    child = DynamicPPL.childcontext(context)
    value, logp, new_vi = DynamicPPL.dot_tilde_assume(child, right, left, vn, vi)
    check_adtype(context, new_vi)
    return value, logp, new_vi
end

# Same as above, for the variant that also receives an RNG and a sampler.
function DynamicPPL.dot_tilde_assume(
    rng, context::ADTypeCheckContext, sampler, right, left, vn, vi
)
    child = DynamicPPL.childcontext(context)
    value, logp, new_vi = DynamicPPL.dot_tilde_assume(
        rng, child, sampler, right, left, vn, vi
    )
    check_adtype(context, new_vi)
    return value, logp, new_vi
end

# Delegate to the child context, then check the element types held by the returned VarInfo.
function DynamicPPL.dot_tilde_observe(context::ADTypeCheckContext, right, left, vi)
    child = DynamicPPL.childcontext(context)
    logp, new_vi = DynamicPPL.dot_tilde_observe(child, right, left, vi)
    check_adtype(context, new_vi)
    return logp, new_vi
end

# Same as above, for the variant that also receives a sampler.
function DynamicPPL.dot_tilde_observe(context::ADTypeCheckContext, sampler, right, left, vi)
    child = DynamicPPL.childcontext(context)
    logp, new_vi = DynamicPPL.dot_tilde_observe(child, sampler, right, left, vi)
    check_adtype(context, new_vi)
    return logp, new_vi
end

# Self-test: ADTypeCheckContext must accept the matching backend and reject the others.
Test.@testset "ADTypeCheckContext" begin
    Turing.@model test_model() = x ~ Turing.Normal(0, 1)
    tm = test_model()
    adtypes = (
        Turing.AutoForwardDiff(),
        Turing.AutoReverseDiff(),
        Turing.AutoZygote(),
        Turing.AutoTracker(),
    )
    for actual_adtype in adtypes
        sampler = Turing.HMC(0.1, 5; adtype=actual_adtype)
        for expected_adtype in adtypes
            # TODO(mhauru) We are currently unable to check this case.
            undetectable =
                actual_adtype == Turing.AutoForwardDiff() &&
                expected_adtype == Turing.AutoZygote()
            undetectable && continue

            check_context = ADTypeCheckContext(expected_adtype, tm.context)
            contextualised_tm = DynamicPPL.contextualize(tm, check_context)
            Test.@testset "Expected: $expected_adtype, Actual: $actual_adtype" begin
                if actual_adtype != expected_adtype
                    # A mismatch must surface as one of the wrong-backend errors.
                    Test.@test_throws AbstractWrongADBackendError Turing.sample(
                        contextualised_tm, sampler, 2
                    )
                else
                    # Check that this does not throw an error.
                    Turing.sample(contextualised_tm, sampler, 2)
                end
            end
        end
    end
end

end
Loading