Skip to content

Commit

Permalink
Rewrite svd_rev to reduce allocations
Browse files Browse the repository at this point in the history
  • Loading branch information
perrutquist committed Dec 5, 2023
1 parent ef235e1 commit 225dd72
Showing 1 changed file with 16 additions and 17 deletions.
33 changes: 16 additions & 17 deletions src/rulesets/LinearAlgebra/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ function rrule(::typeof(getproperty), F::T, x::Symbol) where T <: SVD
return getproperty(F, x), getproperty_svd_pullback
end

# When not `ZeroTangent`s expect `Ū::AbstractMatrix, s̄::AbstractVector, ::AbstractMatrix`
# When not `ZeroTangent`s expect `Ū::AbstractMatrix, s̄::AbstractVector, V̄t::AbstractMatrix`
function svd_rev(USV::SVD, Ū, s̄, V̄t)
# Note: assuming a thin factorization, i.e. svd(A, full=false), which is the default
U = USV.U
Expand All @@ -252,25 +252,24 @@ function svd_rev(USV::SVD, Ū, s̄, V̄t)

k = length(s)
T = eltype(s)
F = T[i == j ? 1 : inv(@inbounds s[j]^2 - s[i]^2) for i = 1:k, j = 1:k]

UtŪ = U'*Ū
V̄tV = V̄t*Vt'

FUᵀŪS = F .* (UtŪ .- UtŪ') .* s'
SFVᵀV̄ = F .* (V̄tV' .- V̄tV) .* s

S = Diagonal(s)
=isa AbstractZero ?: Diagonal(s̄)
if size(Vt,1) == size(Vt,2)
UtŪ = U' * Ū
V̄tV = V̄t * Vt'
M = @inbounds T[
if i == j
s̄[i]
else
(s[j] * (UtŪ[i, j] - UtŪ[j, i]) + s[i] * (V̄tV[j, i] - V̄tV[i, j])) /
(s[j]^2 - s[i]^2)
end for i in 1:k, j in 1:k
]

if size(Vt, 1) == size(Vt, 2)
# V is square, VVᵀ = I and therefore V̄ᵀ - V̄ᵀVVᵀ = 0
Ā = (U * (FUᵀŪS ++ SFVᵀV̄) + ((Ū .- U * UtŪ) / S)) * Vt
else
Ā = (U * M .+ ((Ū .- U * UtŪ) ./ s')) * Vt
else
# If V is not square then U is, so UUᵀ == I and Ū - UUᵀŪ = 0
Ā = U * ((FUᵀŪS ++ SFVᵀV̄) * Vt + (S \ (V̄t .- V̄tV * Vt)))
Ā = U * (M * Vt .+ ((V̄t .- V̄tV * Vt) ./ s))
end

return Ā
end

Expand Down

0 comments on commit 225dd72

Please sign in to comment.