From bf5aeb4d906d8dcbba9a677cbc0e6677addee6fe Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Tue, 7 Jan 2025 11:14:29 +0000 Subject: [PATCH 1/7] improve error message for `initial_params` --- src/sampler.jl | 35 ++++++++++++++++++++++++++++++----- test/sampler.jl | 25 +++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 5 deletions(-) diff --git a/src/sampler.jl b/src/sampler.jl index 40418114e..34671c928 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -157,14 +157,22 @@ By default, it returns an instance of [`SampleFromPrior`](@ref). initialsampler(spl::Sampler) = SampleFromPrior() function set_values!!( - varinfo::AbstractVarInfo, - initial_params::AbstractVector{<:Union{Real,Missing}}, - spl::AbstractSampler, -) + varinfo::AbstractVarInfo, initial_params::AbstractVector{T}, spl::AbstractSampler +) where {T} + if T === Any + throw( + ArgumentError( + "`initial_params` must be a vector of type `Union{Real,Missing}`. " * + "If `initial_params` is a vector of vectors, please flatten it first using `vcat`.", + ), + ) + end + flattened_param_vals = varinfo[spl] length(flattened_param_vals) == length(initial_params) || throw( DimensionMismatch( - "Provided initial value size ($(length(initial_params))) doesn't match the model size ($(length(flattened_param_vals)))", + "Provided initial value size ($(length(initial_params))) doesn't match " * + "the model size ($(length(flattened_param_vals))).", ), ) @@ -183,6 +191,23 @@ end function set_values!!( varinfo::AbstractVarInfo, initial_params::NamedTuple, spl::AbstractSampler ) + vars_in_varinfo = keys(varinfo) + for v in keys(initial_params) + if !(v in vars_in_varinfo) + for vv in vars_in_varinfo + if subsumes(VarName{v}(), vv) + throw( + ArgumentError( + "Variable $v not found in model, but it subsumes a variable ($vv) in the model. " * + "Please use AbstractVector for initial_params instead of NamedTuple.", + ), + ) + end + end + + throw(ArgumentError("Variable $v not found in the model.")) + end + end initial_params = NamedTuple(k => v for (k, v) in pairs(initial_params) if v !== missing) return update_values!!( varinfo, initial_params, map(k -> VarName{k}(), keys(initial_params)) diff --git a/test/sampler.jl b/test/sampler.jl index e5fe6dc98..3b5424671 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -178,5 +178,30 @@ @test c1[1].metadata.s.vals == c2[1].metadata.s.vals end end + + @testset "error handling" begin + # https://github.com/TuringLang/Turing.jl/issues/2452 + @model function constrained_uniform(n) + Z ~ Uniform(10, 20) + X = Vector{Float64}(undef, n) + for i in 1:n + X[i] ~ Uniform(0, Z) + end + end + + n = 2 + initial_z = 15 + initial_x = [0.2, 0.5] + model = constrained_uniform(n) + vi = VarInfo(model) + + @test_throws ArgumentError DynamicPPL.initialize_parameters!!( + vi, [initial_z, initial_x], DynamicPPL.SampleFromPrior(), model + ) + + @test_throws ArgumentError DynamicPPL.initialize_parameters!!( + vi, (X=initial_x, Z=initial_z), DynamicPPL.SampleFromPrior(), model + ) + end end end From 1d365296ff6c300cf3d92495caf42f23c95bb3a1 Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Tue, 7 Jan 2025 11:19:56 +0000 Subject: [PATCH 2/7] Update src/sampler.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/sampler.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/sampler.jl b/src/sampler.jl index 34671c928..4b6b695f3 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -204,7 +204,6 @@ function set_values!!( ) end end - throw(ArgumentError("Variable $v not found in the model.")) end end From cb6ecb74a3c95524e9e0a88739f81d5c97ab2f0d Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Tue, 7 Jan 2025 11:45:04 +0000 Subject: [PATCH 3/7] fix logic error --- src/sampler.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/sampler.jl b/src/sampler.jl index 4b6b695f3..148255a70 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -193,9 +193,10 @@ function set_values!!( ) vars_in_varinfo = keys(varinfo) for v in keys(initial_params) - if !(v in vars_in_varinfo) + vn = VarName{v}() + if !(vn in vars_in_varinfo) for vv in vars_in_varinfo - if subsumes(VarName{v}(), vv) + if subsumes(vn, vv) throw( ArgumentError( "Variable $v not found in model, but it subsumes a variable ($vv) in the model. " * From ba4182c257ef5bb1ae6971dc7f4cf728bc68f809 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Tue, 7 Jan 2025 13:41:06 +0000 Subject: [PATCH 4/7] apply Will's suggestions --- src/sampler.jl | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/src/sampler.jl b/src/sampler.jl index 148255a70..a81479429 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -157,17 +157,21 @@ By default, it returns an instance of [`SampleFromPrior`](@ref). initialsampler(spl::Sampler) = SampleFromPrior() function set_values!!( - varinfo::AbstractVarInfo, initial_params::AbstractVector{T}, spl::AbstractSampler -) where {T} - if T === Any - throw( - ArgumentError( - "`initial_params` must be a vector of type `Union{Real,Missing}`. " * - "If `initial_params` is a vector of vectors, please flatten it first using `vcat`.", - ), - ) - end + varinfo::AbstractVarInfo, initial_params::AbstractVector, spl::AbstractSampler +) + throw( + ArgumentError( + "`initial_params` must be a vector of type `Union{Real,Missing}`. " * + "If `initial_params` is a vector of vectors, please flatten it (e.g. using `vcat`) first.", + ), + ) +end +function set_values!!( + varinfo::AbstractVarInfo, + initial_params::AbstractVector{<:Union{Real,Missing}}, + spl::AbstractSampler, +) flattened_param_vals = varinfo[spl] length(flattened_param_vals) == length(initial_params) || throw( DimensionMismatch( @@ -199,7 +203,8 @@ function set_values!!( if subsumes(vn, vv) throw( ArgumentError( - "Variable $v not found in model, but it subsumes a variable ($vv) in the model. " * + "The current model does not contain variable $v, but there's ($vv) in the model. " * + "Using NamedTuple for initial_params is not supported for this model. " * "Please use AbstractVector for initial_params instead of NamedTuple.", ), ) From 072ed04cdfd2cbbb472d0d2c9fb01cdcccc98dec Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Tue, 7 Jan 2025 14:26:05 +0000 Subject: [PATCH 5/7] bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 60dbcdc81..2bf60214f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.33.0" +version = "0.33.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From c90f80a3e4dfd64ee2eec0beaad4776bfe6fb5e6 Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Tue, 7 Jan 2025 17:46:49 +0000 Subject: [PATCH 6/7] Update src/sampler.jl Co-authored-by: Penelope Yong --- src/sampler.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sampler.jl b/src/sampler.jl index a81479429..974828e8b 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -203,8 +203,8 @@ function set_values!!( if subsumes(vn, vv) throw( ArgumentError( - "The current model does not contain variable $v, but there's ($vv) in the model. " * - "Using NamedTuple for initial_params is not supported for this model. " * + "The current model contains sub-variables of $v, such as ($vv). " * + "Using NamedTuple for initial_params is not supported in such a case. " * "Please use AbstractVector for initial_params instead of NamedTuple.", ), ) From d267fc64f7f832fc2a71bb87c066aa6e643b11b7 Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Tue, 7 Jan 2025 17:47:10 +0000 Subject: [PATCH 7/7] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 2bf60214f..60dbcdc81 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.33.1" +version = "0.33.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"