diff --git a/src/rulesets/LinearAlgebra/norm.jl b/src/rulesets/LinearAlgebra/norm.jl index 89d280134..439aef762 100644 --- a/src/rulesets/LinearAlgebra/norm.jl +++ b/src/rulesets/LinearAlgebra/norm.jl @@ -40,7 +40,10 @@ function rrule(::typeof(norm), x::AbstractArray{<:Number}, p::Real) end, # out-of-place versions @thunk(if isempty(x) || p == 0 - zero.(x) .* (zero(y) * zero(real(Δy))) + # Note: post-julia-1.11 the zero.(Diagonal(Float64[;])) .* 0.0) + # only infers down to Union(Diagonal{Float64}, Matrix{Float64}) + # rather than Diagonal{Float64}, so we cast it back. + maybe_withsomezeros_rewrap(x, zero.(x) .* (zero(y) * zero(real(Δy)))) elseif p == 2 _norm2_back(x, y, Δy) elseif p == 1 @@ -72,7 +75,10 @@ function rrule(::typeof(norm), x::AbstractArray{<:Number}) end , @thunk(if isempty(x) - zero.(x) .* (zero(y) * zero(real(Δy))) + # Note: post-julia-1.11 the zero.(Diagonal(Float64[;])) .* 0.0) + # only infers down to Union(Diagonal{Float64}, Matrix{Float64}) + # rather than Diagonal{Float64}, so we cast it back. + maybe_withsomezeros_rewrap(x, zero.(x) .* (zero(y) * zero(real(Δy)))) else _norm2_back(x, y, Δy) end) diff --git a/src/rulesets/LinearAlgebra/utils.jl b/src/rulesets/LinearAlgebra/utils.jl index 3d8ad923f..8d6ac72fa 100644 --- a/src/rulesets/LinearAlgebra/utils.jl +++ b/src/rulesets/LinearAlgebra/utils.jl @@ -58,7 +58,9 @@ for S in [ :UnitLowerTriangular, ] @eval withsomezeros_rewrap(::$S, x) = $S(x) + @eval maybe_withsomezeros_rewrap(::$S, x) = $S(x) end +maybe_withsomezeros_rewrap(::AbstractArray, x) = x # Bidiagonal, Tridiagonal have more complicated storage. # AdjOrTransUpperOrUnitUpperTriangular would need adjoint(parent(parent()))