diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index 245335774..4de6b9a5b 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -1,6 +1,5 @@ # Structured matrices using LinearAlgebra: AbstractTriangular -using SparseInverseSubset # Matrix wrapper types that we know are square and are thus potentially invertible. For # these we can use simpler definitions for `/` and `\`. diff --git a/src/rulesets/SparseArrays/sparsematrix.jl b/src/rulesets/SparseArrays/sparsematrix.jl index 06de7a135..0f16fd6db 100644 --- a/src/rulesets/SparseArrays/sparsematrix.jl +++ b/src/rulesets/SparseArrays/sparsematrix.jl @@ -50,94 +50,99 @@ function rrule(::typeof(findnz), v::AbstractSparseVector) return (I, V), findnz_pullback end -if VERSION < v"1.7" - #= - The method below for `logabsdet(F::UmfpackLU)` is required to calculate the (log) - determinants of sparse matrices, but was not defined prior to Julia v1.7. In order - for the rrules for the determinants of sparse matrices below to work, they need to be - able to compute the primals as well, so this import from the future is included. For - more recent versions of Julia, this definition lives in: - julia/stdlib/SuiteSparse/src/umfpack.jl - =# - using SuiteSparse.UMFPACK: UmfpackLU - - # compute the sign/parity of a permutation - function _signperm(p) - n = length(p) - result = 0 - todo = trues(n) - while any(todo) - k = findfirst(todo) - todo[k] = false - result += 1 # increment element count - j = p[k] - while j != k +if Base.USE_GPL_LIBS # Don't define rrules for sparse determinants if we don't have CHOLMOD from SuiteSparse.jl + + if VERSION < v"1.7" + #= + The method below for `logabsdet(F::UmfpackLU)` is required to calculate the (log) + determinants of sparse matrices, but was not defined prior to Julia v1.7. In order + for the rrules for the determinants of sparse matrices below to work, they need to be + able to compute the primals as well, so this import from the future is included. For + more recent versions of Julia, this definition lives in: + julia/stdlib/SuiteSparse/src/umfpack.jl + =# + using SuiteSparse.UMFPACK: UmfpackLU + + # compute the sign/parity of a permutation + function _signperm(p) + n = length(p) + result = 0 + todo = trues(n) + while any(todo) + k = findfirst(todo) + todo[k] = false result += 1 # increment element count - todo[j] = false - j = p[j] + j = p[k] + while j != k + result += 1 # increment element count + todo[j] = false + j = p[j] + end + result += 1 # increment cycle count end - result += 1 # increment cycle count + return ifelse(isodd(result), -1, 1) end - return ifelse(isodd(result), -1, 1) - end - function LinearAlgebra.logabsdet(F::UmfpackLU{T, TI}) where {T<:Union{Float64,ComplexF64},TI<:Union{Int32, Int64}} - n = checksquare(F) - issuccess(F) || return log(zero(real(T))), zero(T) - U = F.U - Rs = F.Rs - p = F.p - q = F.q - s = _signperm(p)*_signperm(q)*one(real(T)) - P = one(T) - abs_det = zero(real(T)) - @inbounds for i in 1:n - dg_ii = U[i, i] / Rs[i] - P *= sign(dg_ii) - abs_det += log(abs(dg_ii)) + using SparseInverseSubset + + function LinearAlgebra.logabsdet(F::UmfpackLU{T, TI}) where {T<:Union{Float64,ComplexF64},TI<:Union{Int32, Int64}} + n = checksquare(F) + issuccess(F) || return log(zero(real(T))), zero(T) + U = F.U + Rs = F.Rs + p = F.p + q = F.q + s = _signperm(p)*_signperm(q)*one(real(T)) + P = one(T) + abs_det = zero(real(T)) + @inbounds for i in 1:n + dg_ii = U[i, i] / Rs[i] + P *= sign(dg_ii) + abs_det += log(abs(dg_ii)) + end + return abs_det, s * P end - return abs_det, s * P end -end - - -function rrule(::typeof(logabsdet), x::SparseMatrixCSC) - F = cholesky(x) - L, D, U, P = SparseInverseSubset.get_ldup(F) - Ω = logabsdet(D) - function logabsdet_pullback(ΔΩ) - (Δy, Δsigny) = ΔΩ - (_, signy) = Ω - f = signy' * Δsigny - imagf = f - real(f) - g = real(Δy) + imagf - Z, P = sparseinv(F, depermute=true) - ∂x = g * Z' - return (NoTangent(), ∂x) + + + function rrule(::typeof(logabsdet), x::SparseMatrixCSC) + F = cholesky(x) + L, D, U, P = SparseInverseSubset.get_ldup(F) + Ω = logabsdet(D) + function logabsdet_pullback(ΔΩ) + (Δy, Δsigny) = ΔΩ + (_, signy) = Ω + f = signy' * Δsigny + imagf = f - real(f) + g = real(Δy) + imagf + Z, P = sparseinv(F, depermute=true) + ∂x = g * Z' + return (NoTangent(), ∂x) + end + return Ω, logabsdet_pullback end - return Ω, logabsdet_pullback -end - -function rrule(::typeof(logdet), x::SparseMatrixCSC) - Ω = logdet(x) - function logdet_pullback(ΔΩ) - Z, p = sparseinv(x, depermute=true) - ∂x = ΔΩ * Z' - return (NoTangent(), ∂x) + + function rrule(::typeof(logdet), x::SparseMatrixCSC) + Ω = logdet(x) + function logdet_pullback(ΔΩ) + Z, p = sparseinv(x, depermute=true) + ∂x = ΔΩ * Z' + return (NoTangent(), ∂x) + end + return Ω, logdet_pullback end - return Ω, logdet_pullback -end - -function rrule(::typeof(det), x::SparseMatrixCSC) - Ω = det(x) - function det_pullback(ΔΩ) - Z, _ = sparseinv(x, depermute=true) - ∂x = Z' * dot(Ω, ΔΩ) - return (NoTangent(), ∂x) + + function rrule(::typeof(det), x::SparseMatrixCSC) + Ω = det(x) + function det_pullback(ΔΩ) + Z, _ = sparseinv(x, depermute=true) + ∂x = Z' * dot(Ω, ΔΩ) + return (NoTangent(), ∂x) + end + return Ω, det_pullback end - return Ω, det_pullback -end - + +end # rrules that depend on CHOLMOD function rrule(::typeof(spdiagm), m::Integer, n::Integer, kv::Pair{<:Integer,<:AbstractVector}...) diff --git a/test/rulesets/SparseArrays/sparsematrix.jl b/test/rulesets/SparseArrays/sparsematrix.jl index 283452a8a..ea0cf5199 100644 --- a/test/rulesets/SparseArrays/sparsematrix.jl +++ b/test/rulesets/SparseArrays/sparsematrix.jl @@ -79,12 +79,14 @@ end test_rrule(findnz, v ⊢ dv, output_tangent=(zeros(length(I)), V̄)) end -@testset "[log[abs[det]]] SparseMatrixCSC" begin - ii = [1:5; 2; 4] - jj = [1:5; 4; 2] - x = [ones(5); 0.1; 0.1] - A = sparse(ii, jj, x) - test_rrule(logabsdet, A) - test_rrule(logdet, A) - test_rrule(det, A) +if Base.USE_GPL_LIBS # these rrules don't work without CHOLMOD from SuiteSparse.jl + @testset "[log[abs[det]]] SparseMatrixCSC" begin + ii = [1:5; 2; 4] + jj = [1:5; 4; 2] + x = [ones(5); 0.1; 0.1] + A = sparse(ii, jj, x) + test_rrule(logabsdet, A) + test_rrule(logdet, A) + test_rrule(det, A) + end end