diff --git a/Project.toml b/Project.toml index 715ebe800..79ea1b764 100644 --- a/Project.toml +++ b/Project.toml @@ -13,8 +13,10 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RealDot = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +SparseInverseSubset = "dc90abb0-5640-4711-901d-7e5b23a2fada" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" +SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" [compat] Adapt = "3.4.0" diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index 1a5a4bcd0..da153d14e 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -1,5 +1,6 @@ # Structured matrices using LinearAlgebra: AbstractTriangular +using SparseInverseSubset # Matrix wrapper types that we know are square and are thus potentially invertible. For # these we can use simpler definitions for `/` and `\`. diff --git a/src/rulesets/SparseArrays/sparsematrix.jl b/src/rulesets/SparseArrays/sparsematrix.jl index 51b41421c..8ee8f8cd0 100644 --- a/src/rulesets/SparseArrays/sparsematrix.jl +++ b/src/rulesets/SparseArrays/sparsematrix.jl @@ -49,3 +49,91 @@ function rrule(::typeof(findnz), v::AbstractSparseVector) return (I, V), findnz_pullback end + +if VERSION < v"1.7" + #= + The method below for `logabsdet(F::UmfpackLU)` is required to calculate the (log) + determinants of sparse matrices, but was not defined prior to Julia v1.7. In order + for the rrules for the determinants of sparse matrices below to work, they need to be + able to compute the primals as well, so this import from the future is included. For + more recent versions of Julia, this definition lives in: + julia/stdlib/SuiteSparse/src/umfpack.jl + =# + using SuiteSparse.UMFPACK: UmfpackLU + + # compute the sign/parity of a permutation + function _signperm(p) + n = length(p) + result = 0 + todo = trues(n) + while any(todo) + k = findfirst(todo) + todo[k] = false + result += 1 # increment element count + j = p[k] + while j != k + result += 1 # increment element count + todo[j] = false + j = p[j] + end + result += 1 # increment cycle count + end + return ifelse(isodd(result), -1, 1) + end + + function LinearAlgebra.logabsdet(F::UmfpackLU{T, TI}) where {T<:Union{Float64,ComplexF64},TI<:Union{Int32, Int64}} + n = checksquare(F) + issuccess(F) || return log(zero(real(T))), zero(T) + U = F.U + Rs = F.Rs + p = F.p + q = F.q + s = _signperm(p)*_signperm(q)*one(real(T)) + P = one(T) + abs_det = zero(real(T)) + @inbounds for i in 1:n + dg_ii = U[i, i] / Rs[i] + P *= sign(dg_ii) + abs_det += log(abs(dg_ii)) + end + return abs_det, s * P + end +end + + +function rrule(::typeof(logabsdet), x::SparseMatrixCSC) + F = cholesky(x) + L, D, U, P = SparseInverseSubset.get_ldup(F) + Ω = logabsdet(D) + function logabsdet_pullback(ΔΩ) + (Δy, Δsigny) = ΔΩ + (_, signy) = Ω + f = signy' * Δsigny + imagf = f - real(f) + g = real(Δy) + imagf + Z, P = sparseinv(F, depermute=true) + ∂x = g * Z' + return (NoTangent(), ∂x) + end + return Ω, logabsdet_pullback +end + +function rrule(::typeof(logdet), x::SparseMatrixCSC) + Ω = logdet(x) + function logdet_pullback(ΔΩ) + Z, p = sparseinv(x, depermute=true) + ∂x = ΔΩ * Z' + return (NoTangent(), ∂x) + end + return Ω, logdet_pullback +end + +function rrule(::typeof(det), x::SparseMatrixCSC) + Ω = det(x) + function det_pullback(ΔΩ) + Z, _ = sparseinv(x, depermute=true) + ∂x = Z' * dot(Ω, ΔΩ) + return (NoTangent(), ∂x) + end + return Ω, det_pullback +end diff --git a/test/rulesets/SparseArrays/sparsematrix.jl b/test/rulesets/SparseArrays/sparsematrix.jl index a11a1e963..03f1052c2 100644 --- a/test/rulesets/SparseArrays/sparsematrix.jl +++ b/test/rulesets/SparseArrays/sparsematrix.jl @@ -33,3 +33,13 @@ end V̄ = rand!(similar(V)) test_rrule(findnz, v ⊢ dv, output_tangent=(zeros(length(I)), V̄)) end + +@testset "[log[abs[det]]] SparseMatrixCSC" begin + ii = [1:5; 2; 4] + jj = [1:5; 4; 2] + x = [ones(5); 0.1; 0.1] + A = sparse(ii, jj, x) + test_rrule(logabsdet, A) + test_rrule(logdet, A) + test_rrule(det, A) +end \ No newline at end of file