Skip to content

Commit

Permalink
Fix DF indexing, fix dict mutability. (#131)
Browse files Browse the repository at this point in the history
  • Loading branch information
cpfiffer authored Sep 3, 2019
1 parent a97a085 commit 1b18d6c
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 9 deletions.
14 changes: 8 additions & 6 deletions src/chains.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,18 +54,20 @@ end
function Chains(
val::AbstractArray{A,3},
parameter_names::Vector{String} = map(i->"Param$i", 1:size(val, 2)),
name_map = copy(DEFAULT_MAP);
name_map_original = copy(DEFAULT_MAP);
start::Int=1,
thin::Int=1,
evidence = missing,
info::NamedTuple=NamedTuple(),
sorted::Bool=true) where {A<:Union{Real, Union{Missing, Real}}}

# If we received an array of pairs, convert it to a dictionary.
if typeof(name_map) <: Array
name_map = Dict(name_map)
elseif typeof(name_map) <: NamedTuple
name_map = _namedtuple2dict(name_map)
name_map = if typeof(name_map_original) <: Dict
# Copying can avoid state mutation.
deepcopy(name_map_original)
elseif typeof(name_map_original) <: Array
Dict(deepcopy(name_map_original))
elseif typeof(name_map_original) <: NamedTuple
_namedtuple2dict(name_map_original)
end

# Make sure that we have a :parameters index.
Expand Down
6 changes: 3 additions & 3 deletions src/summarize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ function Base.show(io::IO, c::ChainDataFrame)
end

Base.getindex(c::ChainDataFrame, args...) = getindex(c.df, args...)
Base.getindex(c::ChainDataFrame, s::Union{Symbol, Vector{Symbol}}) = c.df[s]
Base.getindex(c::ChainDataFrame, s::Union{Symbol, Vector{Symbol}}) = c.df[:, s]
Base.isequal(cs1::Vector{ChainDataFrame}, cs2::Vector{ChainDataFrame}) = isequal.(cs1, cs2)
Base.isequal(c1::ChainDataFrame, c2::ChainDataFrame) = isequal(c1, c2)

Expand Down Expand Up @@ -75,13 +75,13 @@ end
function Base.getindex(c::ChainDataFrame,
s1::Vector{Symbol},
s2::Union{Symbol, Vector{Symbol}})
return c.df[map(x -> x in s1, c.df.parameters), s2]
return c.df[map(x -> x in s1, c.df[:, :parameters]), s2]
end

function Base.getindex(c::ChainDataFrame,
s1::Symbol,
s2::Union{Symbol, Vector{Symbol}})
return c.df[c.df.parameters .== s1, s2]
return c.df[c.df[:, :parameters] .== s1, s2]
end


Expand Down

0 comments on commit 1b18d6c

Please sign in to comment.