From 14a292abdab7e36eb7c0fb6cdacc444b565987f6 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Thu, 8 Aug 2024 12:07:21 -0400 Subject: [PATCH] work around stack --- src/Enzyme.jl | 43 +++++++++++++++++++++++++++++++++---------- 1 file changed, 33 insertions(+), 10 deletions(-) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index fd6510b18e..c33b3a16d1 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -1243,9 +1243,13 @@ of shape `size(input)` of values of the output type. inshape = size(x) outshape = size(cols[1]) # st : outshape x total inputs - st = Base.stack(cols) + st = @static if VERSION >= v"1.9" + Base.stack(cols) + else + reshape(cat(cols..., dims=length(outshape)), (outshape..., inshape...)) + end - st3 = if length(inshape) <= 1 + st3 = if length(inshape) <= 1 || VERSION < v"1.9" st else reshape(st, (outshape..., inshape...)) @@ -1275,9 +1279,13 @@ end inshape = size(x) outshape = size(cols[1]) # st : outshape x total inputs - st = Base.stack(cols) + st = @static if VERSION >= v"1.9" + Base.stack(cols) + else + reshape(cat(cols..., dims=length(outshape)), (outshape..., inshape...)) + end - st3 = if length(inshape) <= 1 + st3 = if length(inshape) <= 1 || VERSION < v"1.9" st else reshape(st, (outshape..., inshape...)) @@ -1303,9 +1311,13 @@ end inshape = size(x) outshape = size(cols[1]) # st : outshape x total inputs - st = Base.stack(cols) + st = @static if VERSION >= v"1.9" + Base.stack(cols) + else + reshape(cat(cols..., dims=length(outshape)), (outshape..., inshape...)) + end - st3 = if length(inshape) <= 1 + st3 = if length(inshape) <= 1 || VERSION < v"1.9" st else reshape(st, (outshape..., inshape...)) @@ -1402,8 +1414,14 @@ of shape `size(output)` of values of the input type. outshape = tmp[1][2] if x isa AbstractArray inshape = size(x) - st = Base.stack(rows) - st2 = if length(outshape) == 1 + + st = @static if VERSION >= v"1.9" + Base.stack(cols) + else + reshape(cat(cols..., dims=length(outshape)), (outshape..., inshape...)) + end + + st2 = if length(outshape) == 1 || VERSION < v"1.9" st else reshape(st, (inshape..., outshape...)) @@ -1450,8 +1468,13 @@ end outshape = tmp[1][2] if x isa AbstractArray inshape = size(x) - st = Base.stack(rows) - st2 = if length(outshape) == 1 + st = @static if VERSION >= v"1.9" + Base.stack(cols) + else + reshape(cat(cols..., dims=length(outshape)), (outshape..., inshape...)) + end + + st2 = if length(outshape) == 1 || VERSION < v"1.9" st else reshape(st, (inshape..., outshape...))