From 225dd72d43359cf4b6092975f3f7f2cdf2a8f788 Mon Sep 17 00:00:00 2001 From: Per Rutquist Date: Tue, 5 Dec 2023 22:44:50 +0100 Subject: [PATCH] Rewrite svd_rev to reduce allocations --- src/rulesets/LinearAlgebra/factorization.jl | 33 ++++++++++----------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index cd73654b5..a1bb400a0 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -243,7 +243,7 @@ function rrule(::typeof(getproperty), F::T, x::Symbol) where T <: SVD return getproperty(F, x), getproperty_svd_pullback end -# When not `ZeroTangent`s expect `Ū::AbstractMatrix, s̄::AbstractVector, V̄::AbstractMatrix` +# When not `ZeroTangent`s expect `Ū::AbstractMatrix, s̄::AbstractVector, V̄t::AbstractMatrix` function svd_rev(USV::SVD, Ū, s̄, V̄t) # Note: assuming a thin factorization, i.e. svd(A, full=false), which is the default U = USV.U @@ -252,25 +252,24 @@ function svd_rev(USV::SVD, Ū, s̄, V̄t) k = length(s) T = eltype(s) - F = T[i == j ? 1 : inv(@inbounds s[j]^2 - s[i]^2) for i = 1:k, j = 1:k] - - UtŪ = U'*Ū - V̄tV = V̄t*Vt' - - FUᵀŪS = F .* (UtŪ .- UtŪ') .* s' - SFVᵀV̄ = F .* (V̄tV' .- V̄tV) .* s - - S = Diagonal(s) - S̄ = s̄ isa AbstractZero ? s̄ : Diagonal(s̄) - - if size(Vt,1) == size(Vt,2) + UtŪ = U' * Ū + V̄tV = V̄t * Vt' + M = @inbounds T[ + if i == j + s̄[i] + else + (s[j] * (UtŪ[i, j] - UtŪ[j, i]) + s[i] * (V̄tV[j, i] - V̄tV[i, j])) / + (s[j]^2 - s[i]^2) + end for i in 1:k, j in 1:k + ] + + if size(Vt, 1) == size(Vt, 2) # V is square, VVᵀ = I and therefore V̄ᵀ - V̄ᵀVVᵀ = 0 - Ā = (U * (FUᵀŪS + S̄ + SFVᵀV̄) + ((Ū .- U * UtŪ) / S)) * Vt - else + Ā = (U * M .+ ((Ū .- U * UtŪ) ./ s')) * Vt + else # If V is not square then U is, so UUᵀ == I and Ū - UUᵀŪ = 0 - Ā = U * ((FUᵀŪS + S̄ + SFVᵀV̄) * Vt + (S \ (V̄t .- V̄tV * Vt))) + Ā = U * (M * Vt .+ ((V̄t .- V̄tV * Vt) ./ s)) end - return Ā end