diff --git a/ext/EnzymeStaticArraysExt.jl b/ext/EnzymeStaticArraysExt.jl index 7e1d131116c..e60ca4afc08 100644 --- a/ext/EnzymeStaticArraysExt.jl +++ b/ext/EnzymeStaticArraysExt.jl @@ -32,6 +32,17 @@ end end end +@inline function Enzyme.EnzymeCore.make_zero( + prev::FT +) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:SArray{S,T}} + return Base.zero(prev)::FT +end +@inline function Enzyme.EnzymeCore.make_zero( + prev::FT +) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}} + return Base.zero(prev)::FT +end + @inline function Enzyme.EnzymeCore.make_zero( ::Type{FT}, seen::IdDict, prev::FT, ::Val{copy_if_inactive} ) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:SArray{S,T},copy_if_inactive} @@ -47,6 +58,7 @@ end seen[prev] = new return new end + @inline function Enzyme.EnzymeCore.make_zero!( prev::FT, seen ) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}} @@ -62,7 +74,8 @@ end @inline function Enzyme.EnzymeCore.make_zero!( prev::FT ) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}} - return Enzyme.EnzymeCore.make_zero!(prev, nothing) + Enzyme.EnzymeCore.make_zero!(prev, nothing) + return nothing end end