diff --git a/ext/DynamicPPLReverseDiffExt.jl b/ext/DynamicPPLReverseDiffExt.jl
index 2586408c9..a1970bd71 100644
--- a/ext/DynamicPPLReverseDiffExt.jl
+++ b/ext/DynamicPPLReverseDiffExt.jl
@@ -25,11 +25,15 @@ function LogDensityProblemsAD.ADgradient(
     )
 end
 
-function DynamicPPL.setmodel(f::LogDensityProblemsAD.ReverseDiffLogDensity{L,Nothing}, model::DynamicPPL.Model) where {L}
+function DynamicPPL.setmodel(
+    f::LogDensityProblemsAD.ReverseDiffLogDensity{L,Nothing}, model::DynamicPPL.Model
+) where {L}
     return Accessors.@set f.ℓ = setmodel(f.ℓ, model)
 end
 
-function DynamicPPL.setmodel(f::LogDensityProblemsAD.ReverseDiffLogDensity{L,C}, model::DynamicPPL.Model) where {L,C}
+function DynamicPPL.setmodel(
+    f::LogDensityProblemsAD.ReverseDiffLogDensity{L,C}, model::DynamicPPL.Model
+) where {L,C}
     new_f = LogDensityProblemsAD.ADgradient(Val(:ReverseDiff), f.ℓ; compile=Val(true)) # TODO: without an input, can get an error
     return Accessors.@set new_f.ℓ = setmodel(f.ℓ, model)
 end
diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl
index 763bd9fd8..890308cc9 100644
--- a/test/logdensityfunction.jl
+++ b/test/logdensityfunction.jl
@@ -11,7 +11,6 @@ using Test, DynamicPPL, LogDensityProblems
     ∇ℓ = LogDensityProblems.ADgradient(:ReverseDiff, ℓ; compile=Val(false))
     @test DynamicPPL.getmodel(∇ℓ) == model
     @test getmodel(DynamicPPL.setmodel(∇ℓ, model)) == model
-
     ∇ℓ = LogDensityProblems.ADgradient(:ReverseDiff, ℓ; compile=Val(true))
     new_∇ℓ = DynamicPPL.setmodel(∇ℓ, model)
     @test DynamicPPL.getmodel(new_∇ℓ) == model