diff --git a/src/chains.jl b/src/chains.jl index c5966c51..6138bbbc 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -54,18 +54,20 @@ end function Chains( val::AbstractArray{A,3}, parameter_names::Vector{String} = map(i->"Param$i", 1:size(val, 2)), - name_map = copy(DEFAULT_MAP); + name_map_original = copy(DEFAULT_MAP); start::Int=1, thin::Int=1, evidence = missing, info::NamedTuple=NamedTuple(), sorted::Bool=true) where {A<:Union{Real, Union{Missing, Real}}} - # If we received an array of pairs, convert it to a dictionary. - if typeof(name_map) <: Array - name_map = Dict(name_map) - elseif typeof(name_map) <: NamedTuple - name_map = _namedtuple2dict(name_map) + name_map = if typeof(name_map_original) <: Dict + # Copying can avoid state mutation. + deepcopy(name_map_original) + elseif typeof(name_map_original) <: Array + Dict(deepcopy(name_map_original)) + elseif typeof(name_map_original) <: NamedTuple + _namedtuple2dict(name_map_original) end # Make sure that we have a :parameters index. diff --git a/src/summarize.jl b/src/summarize.jl index d043dcb5..2d87a140 100644 --- a/src/summarize.jl +++ b/src/summarize.jl @@ -47,7 +47,7 @@ function Base.show(io::IO, c::ChainDataFrame) end Base.getindex(c::ChainDataFrame, args...) = getindex(c.df, args...) -Base.getindex(c::ChainDataFrame, s::Union{Symbol, Vector{Symbol}}) = c.df[s] +Base.getindex(c::ChainDataFrame, s::Union{Symbol, Vector{Symbol}}) = c.df[:, s] Base.isequal(cs1::Vector{ChainDataFrame}, cs2::Vector{ChainDataFrame}) = isequal.(cs1, cs2) Base.isequal(c1::ChainDataFrame, c2::ChainDataFrame) = isequal(c1, c2) @@ -75,13 +75,13 @@ end function Base.getindex(c::ChainDataFrame, s1::Vector{Symbol}, s2::Union{Symbol, Vector{Symbol}}) - return c.df[map(x -> x in s1, c.df.parameters), s2] + return c.df[map(x -> x in s1, c.df[:, :parameters]), s2] end function Base.getindex(c::ChainDataFrame, s1::Symbol, s2::Union{Symbol, Vector{Symbol}}) - return c.df[c.df.parameters .== s1, s2] + return c.df[c.df[:, :parameters] .== s1, s2] end