Skip to content

Commit

Permalink
Fix Array conversion of ChainDataFrame (#188)
Browse files Browse the repository at this point in the history
* Drop column with parameter names with undefined values in Array conversion

* Allow to specify lags as AbstractVector

* Always forward sections keyword argument to summarize

* Test plotting with UnicodePlots

* Update Project.toml
  • Loading branch information
devmotion authored Feb 27, 2020
1 parent c37c1e3 commit 806ccf2
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 62 deletions.
4 changes: 0 additions & 4 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@ julia:
- 1.3
- nightly

# deactivate plot display
env:
- GKSwstype=nul

matrix:
allow_failures:
- julia: nightly
Expand Down
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
keywords = ["markov chain monte carlo", "probablistic programming"]
license = "MIT"
desc = " Chain types and utility functions for MCMC simulations."
version = "3.0.2"
version = "3.0.3"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down Expand Up @@ -33,8 +33,10 @@ julia = "^1"
[extras]
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
KernelDensity = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"

[targets]
test = ["KernelDensity", "StatsPlots", "Test", "DataFrames"]
test = ["DataFrames", "KernelDensity", "Logging", "StatsPlots", "Test", "UnicodePlots"]
4 changes: 2 additions & 2 deletions src/plot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ const supportedplots = push!(collect(keys(translationdict)), :mixeddensity, :cor
end

if st == :autocorplot
lags = 0:(maxlag === nothing ? round(Int, 10 * log10(length(range(c)))) : maxlag)
ac = autocor(c, lags=collect(lags); showall=true)
lags = 0:(maxlag === nothing ? round(Int, 10 * log10(length(range(c)))) : maxlag)
ac = autocor(c, lags=lags; showall=true)
ac_mat = convert(Array, ac)
val = colordim == :parameter ? ac_mat[:, :, i]' : ac_mat[i, :, :]
_AutocorPlot(lags, val)
Expand Down
12 changes: 7 additions & 5 deletions src/stats.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ The `digits` keyword may be a(n)
- `Dict`, with a similar structure as `NamedTuple`. `Dict(mean => 2, std => 3)` would set `mean` to two digits and `std` to three digits.
"""
function autocor(chn::Chains;
lags::Vector=[1, 5, 10, 50],
lags::AbstractVector{<:Integer}=[1, 5, 10, 50],
demean::Bool=true,
relative::Bool=true,
showall=false,
Expand All @@ -33,6 +33,7 @@ function autocor(chn::Chains;
return summarize(chn, funs...;
func_names = func_names,
showall = showall,
sections = sections,
append_chains = append_chains,
name = "Autocorrelation",
digits=digits)
Expand Down Expand Up @@ -200,6 +201,7 @@ function hpd(chn::Chains; alpha::Real=0.05,
return summarize(chn, u, l;
func_names = labels,
showall=showall,
sections=sections,
name="HPD",
digits=digits)
end
Expand Down Expand Up @@ -235,9 +237,9 @@ function quantile(chn::Chains;
end

return summarize(chn, funs...;
sections=sections,
func_names=func_names,
showall=showall,
sections=sections,
name="Quantiles",
digits=digits,
append_chains=append_chains,
Expand Down Expand Up @@ -285,7 +287,7 @@ function ess(chn::Chains;
# Misc allocations.
m = n_chain_orig * 2
maxlag = min(maxlag, n-1)
lags = collect(0:maxlag)
lags = 0:maxlag

# Preallocate B, W, varhat, and Rhat vectors for each param.
B = Vector(undef, length(param))
Expand Down Expand Up @@ -411,9 +413,9 @@ function summarystats(chn::Chains;

# Summarize.
summary_df = summarize(chn, funs...;
sections=sections,
func_names=func_names,
showall=showall,
sections=sections,
name="Summary Statistics",
additional_df = ess_df,
digits=digits,
Expand Down Expand Up @@ -450,9 +452,9 @@ function mean(chn::Chains;

# Summarize.
summary_df = summarize(chn, funs...;
sections=sections,
func_names=func_names,
showall=showall,
sections=sections,
name="Mean",
digits=digits)

Expand Down
14 changes: 8 additions & 6 deletions src/summarize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,21 +163,23 @@ function Base.lastindex(c::ChainDataFrame, i::Integer)
end

function Base.convert(::Type{Array}, c::C) where C<:ChainDataFrame
arr = Array{Float64, 2}(undef, c.n_rows, c.n_cols)
ks = collect(keys(c.nt))
T = promote_eltype_namedtuple_tail(c.nt)
arr = Array{T, 2}(undef, c.n_rows, c.n_cols - 1)

for i in 2:c.n_cols
arr[:, i] = c.nt[ks[i]]
for (i, k) in enumerate(Iterators.drop(keys(c.nt), 1))
arr[:, i] = c.nt[k]
end

return arr
end

Base.convert(::Type{Array{ChainDataFrame,1}}, cs::Array{ChainDataFrame,1}) = cs
function Base.convert(::Type{Array}, cs::Array{C,1}) where C<:ChainDataFrame
arrs = [convert(Array, cs[j]) for j in 1:length(cs)]
return cat(arrs..., dims = 3)
return mapreduce((x, y) -> cat(x, y; dims = Val(3)), cs) do c
reshape(convert(Array, c), Val(3))
end
end

"""
# Summarize a Chains object formatted as a DataFrame
Expand Down
12 changes: 12 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,18 @@ Base.@pure function merge_union_types(names::Tuple{Vararg{Symbol}}, a::Type{<:Na
return Tuple{types...}
end

# promote element types of the tail of a NamedTuple
function promote_eltype_namedtuple_tail(::NamedTuple{k,v}) where {k,v}
return promote_eltype_tuple_type(Base.tuple_type_tail(v))
end

# promote element types of a tuple
promote_eltype_tuple_type(::Type{Tuple{}}) = Any
promote_eltype_tuple_type(::Type{Tuple{T}}) where T = T
function promote_eltype_tuple_type(t::Type{<:Tuple})
Base.promote_eltype(Base.tuple_type_head(t), promote_eltype_tuple_type(Base.tuple_type_tail(t)))
end

"""
cskip(x)
Expand Down
81 changes: 38 additions & 43 deletions test/plot_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ using Test
using StatsPlots
using MCMCChains

import Logging

unicodeplots()

n_iter = 500
n_name = 3
n_chain = 2
Expand All @@ -11,60 +15,51 @@ val = hcat(val, rand(1:2, n_iter, 1, n_chain))

chn = Chains(val)

@testset "Plotting tests" begin
# Silence all warnings.
level = Logging.min_enabled_level(Logging.current_logger())
Logging.disable_logging(Logging.Warn)

@testset "Plotting tests" begin
# plotting singe plotting types
println("traceplot")
ps_trace = traceplot(chn, 1)
@test isa(ps_trace, Plots.Plot)

display(traceplot(chn, 1))
println()
println("meanplot")
ps_mean = meanplot(chn, 1)
@test isa(ps_mean, Plots.Plot)

display(meanplot(chn, 1))
println()
println("density")
ps_density = density(chn, 1)
@test isa(ps_density, Plots.Plot)

ps_density = density(chn, 1, append_chains=true)
@test isa(ps_density, Plots.Plot)

display(density(chn, 1))
display(density(chn, 1, append_chains=true))
println()

println("autocorplot")
ps_autocor = autocorplot(chn, 1)
@test isa(ps_autocor, Plots.Plot)

display(autocorplot(chn, 1))
println()
#ps_contour = plot(chn, :contour)

println("histogram")
ps_hist = histogram(chn, 1)
@test isa(ps_hist, Plots.Plot)

println("mixeddensity")
ps_mixed = mixeddensity(chn, 1)
@test isa(ps_mixed, Plots.Plot)

display(histogram(chn, 1))
println()

println("\nmixeddensity")
display(mixeddensity(chn, 1))

# plotting combinations
ps_trace_mean = plot(chn)
@test isa(ps_trace_mean, Plots.Plot)

ps_trace_mean = plot(chn, append_chains=true)
@test isa(ps_trace_mean, Plots.Plot)

savefig("demo-plot.png")

ps_mixed_auto = plot(chn, seriestype = (:mixeddensity, :autocorplot))
@test isa(ps_mixed_auto, Plots.Plot)
display(plot(chn))
display(plot(chn, append_chains=true))
display(plot(chn, seriestype = (:mixeddensity, :autocorplot)))

# Test plotting using colordim keyword
p_colordim = plot(chn, colordim = :parameter)
@test isa(p_colordim, Plots.Plot)

display(plot(chn, colordim = :parameter))

# Test if plotting a sub-set work.s
p_subset = plot(chn, 2)
@test isa(p_subset, Plots.Plot)

p_subset_colordim = plot(chn, 2, colordim = :parameter)
@test isa(p_subset_colordim, Plots.Plot)

rm("demo-plot.png")
display(plot(chn, 2))
display(plot(chn, 2, colordim = :parameter))
println()
end

# Reset log level.
Logging.disable_logging(level)

0 comments on commit 806ccf2

Please sign in to comment.