From 01d56d52d145e85fbe96444b7b1b96730a46791a Mon Sep 17 00:00:00 2001 From: Paulina Martin Date: Thu, 8 Jul 2021 16:05:57 -0500 Subject: [PATCH] Correct input variables of _ViolinPlot --- src/plot.jl | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/src/plot.jl b/src/plot.jl index c95e6798..b9bf96dd 100644 --- a/src/plot.jl +++ b/src/plot.jl @@ -190,7 +190,7 @@ end @recipe function f( chains::Chains; - sections = chains.name_map[:parameters], + sections::Vector{Symbol} = chains.name_map[:parameters], combined = true ) @@ -198,17 +198,16 @@ end total_chains = 0 if st == :violinplot if combined - parameters = string.(sections) - val = Array(chains)[:, ] + n_iter, n_parameters = size(Array(chains)) + parameters = string.(repeat(sections, inner = n_iter)) + val = vec(Array(chains)) total_chains = Integer(size(chains.value.data)[3]) _ViolinPlot(parameters, val, total_chains) elseif combined == false + n_parameters = length(sections) chain_arr = Array(chains, append_chains = false) - parameters = ["param $(sections[i]).Chain $j" - for i in 1:length(sections) - for j in 1:length(chain_arr)] val_vec = [chain_arr[j][:,i] - for i in 1:length(sections) + for i in 1:n_parameters for j in 1:length(chain_arr)] n_iter = length(val_vec[1]) total_chains = length(val_vec) @@ -216,7 +215,12 @@ end for i in 1:total_chains val[:,i] = val_vec[:][i] end - _ViolinPlot(parameters, val[:,], total_chains) + val = vec(val) + parameters_names = ["param $(sections[i]).Chain $j" + for i in 1:n_parameters + for j in 1:length(chain_arr)] + parameters = string.(repeat(parameters_names, inner = n_iter)) + _ViolinPlot(parameters, val, total_chains) else error("Symbol names are interpreted as parameter names, only compatible with ", "`colordim = :chain`") @@ -228,7 +232,7 @@ end @series begin seriestype := :violin xaxis --> "Parameter" - size --> (150*p.total_chains, 500) + size --> (200*p.total_chains, 500) p.parameters, p.val end