diff --git a/src/flatten.jl b/src/flatten.jl index acd862ad..a305d1b5 100644 --- a/src/flatten.jl +++ b/src/flatten.jl @@ -68,7 +68,7 @@ const flattened_dists = [ Bernoulli, Truncated, ] for T in flattened_dists - @eval toflatten(::T) = true + @eval toflatten(::$T) = true end toflatten(::Distribution) = false for T in flattened_dists