Skip to content

Commit

Permalink
Merge branch 'master' into torfjelde/code-warntype
Browse files Browse the repository at this point in the history
  • Loading branch information
torfjelde authored Nov 11, 2024
2 parents f6115c7 + d6e2147 commit c794ff4
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 5 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ on:
push:
branches:
- master
- backport-*
pull_request:
branches:
- master
- backport-*
merge_group:
types: [checks_requested]

Expand Down
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.30.3"
version = "0.30.4"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -66,7 +66,7 @@ Requires = "1"
ReverseDiff = "1"
Test = "1.6"
ZygoteRules = "0.2"
julia = "~1.6.6, 1.7.3"
julia = "1.10"

[extras]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
6 changes: 3 additions & 3 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -941,10 +941,10 @@ end
"""
float_type_with_fallback(x)
Return type corresponding to `float(typeof(x))` if possible; otherwise return `Real`.
Return type corresponding to `float(typeof(x))` if possible; otherwise return `float(Real)`.
"""
float_type_with_fallback(::Type) = Real
float_type_with_fallback(::Type{Union{}}) = Real
float_type_with_fallback(::Type) = float(Real)
float_type_with_fallback(::Type{Union{}}) = float(Real)
float_type_with_fallback(::Type{T}) where {T<:Real} = float(T)

"""
Expand Down
15 changes: 15 additions & 0 deletions test/turing/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -342,4 +342,19 @@
model = state_space(y, length(t))
@test size(sample(model, NUTS(; adtype=AutoReverseDiff(true)), n), 1) == n
end

if Threads.nthreads() > 1
@testset "DynamicPPL#684: OrderedDict with multiple types when multithreaded" begin
@model function f(x)
ns ~ filldist(Normal(0, 2.0), 3)
m ~ Uniform(0, 1)
return x ~ Normal(m, 1)
end
model = f(1)
chain = sample(model, NUTS(), MCMCThreads(), 10, 2)
loglikelihood(model, chain)
logprior(model, chain)
logjoint(model, chain)
end
end
end

0 comments on commit c794ff4

Please sign in to comment.