Skip to content

Commit

Permalink
Special-case ReshapeTransform for singleton inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
penelopeysm committed Oct 25, 2024
1 parent bd9f465 commit db391a2
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.30.1"
version = "0.30.2"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
11 changes: 9 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -286,8 +286,15 @@ function (f::ReshapeTransform)(x)
if size(x) != f.input_size
throw(DimensionMismatch("Expected input of size $(f.input_size), got $(size(x))"))
end
# The call to `tovec` is only needed in case `x` is a scalar.
return reshape(tovec(x), f.output_size)
if f.output_size == ()
# Specially handle the case where x is a singleton array, see
# https://github.com/JuliaDiff/ReverseDiff.jl/issues/265 and
# https://github.com/TuringLang/DynamicPPL.jl/issues/698
return x[]
else
# The call to `tovec` is only needed in case `x` is a scalar.
return reshape(tovec(x), f.output_size)
end
end

function (inv_f::Bijectors.Inverse{<:ReshapeTransform})(x)
Expand Down

0 comments on commit db391a2

Please sign in to comment.