Skip to content

Commit

Permalink
work around stack
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Aug 8, 2024
1 parent a529690 commit 14a292a
Showing 1 changed file with 33 additions and 10 deletions.
43 changes: 33 additions & 10 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1243,9 +1243,13 @@ of shape `size(input)` of values of the output type.
inshape = size(x)
outshape = size(cols[1])
# st : outshape x total inputs
st = Base.stack(cols)
st = @static if VERSION >= v"1.9"
Base.stack(cols)
else
reshape(cat(cols..., dims=length(outshape)), (outshape..., inshape...))
end

st3 = if length(inshape) <= 1
st3 = if length(inshape) <= 1 || VERSION < v"1.9"
st
else
reshape(st, (outshape..., inshape...))
Expand Down Expand Up @@ -1275,9 +1279,13 @@ end
inshape = size(x)
outshape = size(cols[1])
# st : outshape x total inputs
st = Base.stack(cols)
st = @static if VERSION >= v"1.9"
Base.stack(cols)
else
reshape(cat(cols..., dims=length(outshape)), (outshape..., inshape...))
end

st3 = if length(inshape) <= 1
st3 = if length(inshape) <= 1 || VERSION < v"1.9"
st
else
reshape(st, (outshape..., inshape...))
Expand All @@ -1303,9 +1311,13 @@ end
inshape = size(x)
outshape = size(cols[1])
# st : outshape x total inputs
st = Base.stack(cols)
st = @static if VERSION >= v"1.9"
Base.stack(cols)
else
reshape(cat(cols..., dims=length(outshape)), (outshape..., inshape...))
end

st3 = if length(inshape) <= 1
st3 = if length(inshape) <= 1 || VERSION < v"1.9"
st
else
reshape(st, (outshape..., inshape...))
Expand Down Expand Up @@ -1402,8 +1414,14 @@ of shape `size(output)` of values of the input type.
outshape = tmp[1][2]
if x isa AbstractArray
inshape = size(x)
st = Base.stack(rows)
st2 = if length(outshape) == 1

st = @static if VERSION >= v"1.9"
Base.stack(cols)
else
reshape(cat(cols..., dims=length(outshape)), (outshape..., inshape...))
end

st2 = if length(outshape) == 1 || VERSION < v"1.9"
st
else
reshape(st, (inshape..., outshape...))
Expand Down Expand Up @@ -1450,8 +1468,13 @@ end
outshape = tmp[1][2]
if x isa AbstractArray
inshape = size(x)
st = Base.stack(rows)
st2 = if length(outshape) == 1
st = @static if VERSION >= v"1.9"
Base.stack(cols)
else
reshape(cat(cols..., dims=length(outshape)), (outshape..., inshape...))
end

st2 = if length(outshape) == 1 || VERSION < v"1.9"
st
else
reshape(st, (inshape..., outshape...))
Expand Down

0 comments on commit 14a292a

Please sign in to comment.