From 90d99ca37b646bf19377f570f96dbb04e3688274 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Mon, 18 Nov 2024 12:18:17 +0000 Subject: [PATCH] fix test error --- test/ext/DynamicPPLMCMCChainsExt.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/test/ext/DynamicPPLMCMCChainsExt.jl b/test/ext/DynamicPPLMCMCChainsExt.jl index 4284edcfb..8cdcbfd92 100644 --- a/test/ext/DynamicPPLMCMCChainsExt.jl +++ b/test/ext/DynamicPPLMCMCChainsExt.jl @@ -75,10 +75,12 @@ end DynamicPPL.LogDensityFunction(m_lin_reg, DynamicPPL.VarInfo(m_lin_reg)), AdvancedHMC.NUTS(0.65), MCMCThreads(), - 200, + 1000, 2; chain_type=MCMCChains.Chains, param_names=[:β], + discard_initial=100, + n_adapt=100, ) m_lin_reg_test = linear_reg(xs_test, fill(missing, length(ys_test))) predictions = DynamicPPL.predict(m_lin_reg_test, chain_lin_reg) @@ -158,9 +160,10 @@ end [simple_linear1, simple_linear2, simple_linear3, simple_linear4] m = model(x, y) chain = sample( - DynamicPPL.LogDensityFunction(m, DynamicPPL.VarInfo(m)), + DynamicPPL.LogDensityFunction(m), AdvancedHMC.NUTS(0.65), - 1000; + 400; + initial_params = rand(4), chain_type=MCMCChains.Chains, param_names=param_names[model], discard_initial=100,