Skip to content

Commit

Permalink
Mark Sampling context as not needing derivatives
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Nov 19, 2023
1 parent a1d1b35 commit a27a73b
Show file tree
Hide file tree
Showing 6 changed files with 13 additions and 6 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Expand All @@ -25,9 +24,11 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[weakdeps]
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"

[extensions]
DynamicPPLMCMCChainsExt = ["MCMCChains"]
DynamicPPLMCMCEnzymeCoreExt = ["EnzymeCore"]

[compat]
AbstractMCMC = "5"
Expand All @@ -54,3 +55,4 @@ julia = "1.6"

[extras]
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
3 changes: 3 additions & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,9 @@ end
@require MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" include(
"../ext/DynamicPPLMCMCChainsExt.jl"
)
@require EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" include(
"../ext/DynamicPPLEnzymeCoreExt.jl"
)
end
end

Expand Down
3 changes: 0 additions & 3 deletions src/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,6 @@ struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractConte
context::C
end

using EnzymeCore
@inline EnzymeCore.EnzymeRules.inactive_type(::Type{T}) where {T <: SamplingContext} = true

function SamplingContext(
rng::Random.AbstractRNG=Random.default_rng(), sampler::AbstractSampler=SampleFromPrior()
)
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Expand Down
7 changes: 5 additions & 2 deletions test/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ using DynamicPPL:
hasconditioned_nested,
getconditioned_nested

using EnzymeCore

# Dummy context to test nested behaviors.
struct ParentContext{C<:AbstractContext} <: AbstractContext
context::C
Expand Down Expand Up @@ -76,8 +78,8 @@ end

@testset "NodeTrait" begin
@testset "$context" for context in contexts
# Every `context` should have a `NodeTrait`.
@test NodeTrait(context) isa NodeTrait
# Every `context` type should not be differentiated.
@test EnzymeCore.EnzymeRules.inactive_type(typeof(context))
end
end

Expand Down Expand Up @@ -252,6 +254,7 @@ end
@test SamplingContext(Random.default_rng(), DefaultContext()) == context
@test SamplingContext(SampleFromPrior(), DefaultContext()) == context
@test SamplingContext(SampleFromPrior(), DefaultContext()) == context
@test EnzymeCore.EnzymeRules.inactive_type(typeof(context))
end

@testset "FixedContext" begin
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ include("test_util.jl")
if GROUP == "All" || GROUP == "DynamicPPL"
@testset "interface" begin
include("utils.jl")
include("enzyme.jl")
include("compiler.jl")
include("varinfo.jl")
include("simple_varinfo.jl")
Expand Down

0 comments on commit a27a73b

Please sign in to comment.