diff --git a/src/common.jl b/src/common.jl index 6423118..d5fecf3 100644 --- a/src/common.jl +++ b/src/common.jl @@ -223,23 +223,21 @@ end function _predictrBCM(node::GPSplitNode, x::AbstractMatrix) μ = zeros(size(x,1)) - t = zeros(size(x,1)) - β = zeros(size(x,1)) gp = leftGP(node) s = diag(kernelmatrix(gp.kernel, x, x)) .+ getnoise(gp) + C = deepcopy(1 ./ s) + for (k,c) in enumerate(children(node)) μ_, t_ = _predictPoE(c, x) - β_ = 0.5 * (log.(s) - log.(inv.(t_))) - t[:] += β_.*t_ - μ[:] += β_ .* t_ .* μ_ - β[:] += β_ + s_ = 1 ./ t_ + β_ = 0.5 * (log.(s) - log.(s_)) + C += (β_ .* t_) - (β_ ./ s) + μ += μ_ .* (β_ .* t_) end - z = (1 .- β) ./ s - t += z - t[t .<= 0] .= 1e-8 - return μ ./ t, t + + return μ ./ C, C end function predict(node::GPSplitNode, x::AbstractMatrix)