From 3577abd973714c22473b2d1d0ec1f2051afea4df Mon Sep 17 00:00:00 2001 From: Sam Urmy Date: Wed, 9 Aug 2023 10:25:45 -0700 Subject: [PATCH 1/7] det, logdet, logabsdet rrules for SparseMatrixCSC --- Project.toml | 1 + src/rulesets/LinearAlgebra/structured.jl | 38 +++++++++++++++++++++++ test/rulesets/LinearAlgebra/structured.jl | 10 ++++++ 3 files changed, 49 insertions(+) diff --git a/Project.toml b/Project.toml index d8f874245..ccd3be041 100644 --- a/Project.toml +++ b/Project.toml @@ -13,6 +13,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RealDot = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +SparseInverseSubset = "dc90abb0-5640-4711-901d-7e5b23a2fada" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index 1a5a4bcd0..d7afa6a5d 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -1,5 +1,6 @@ # Structured matrices using LinearAlgebra: AbstractTriangular +using SparseInverseSubset # Matrix wrapper types that we know are square and are thus potentially invertible. For # these we can use simpler definitions for `/` and `\`. @@ -267,3 +268,40 @@ function rrule(::typeof(logdet), X::Union{Diagonal, AbstractTriangular}) end return y, logdet_pullback end + +function rrule(::typeof(logabsdet), x::SparseMatrixCSC) + F = cholesky(x) + L, D, U, P = SparseInverseSubset.get_ldup(F) + Ω = logabsdet(D) + function logabsdet_pullback(ΔΩ) + (Δy, Δsigny) = ΔΩ + (_, signy) = Ω + f = signy' * Δsigny + imagf = f - real(f) + g = real(Δy) + imagf + Z, P = sparseinv(F, depermute=true) + ∂x = g * Z' + return (NoTangent(), ∂x) + end + return Ω, logabsdet_pullback +end + +function rrule(::typeof(logdet), x::SparseMatrixCSC) + Ω = logdet(x) + function logdet_pullback(ΔΩ) + Z, p = sparseinv(x, depermute=true) + ∂x = x isa Number ? ΔΩ / x' : ΔΩ * Z' + return (NoTangent(), ∂x) + end + return Ω, logdet_pullback +end + +function rrule(::typeof(det), x::SparseMatrixCSC) + Ω = det(x) + function det_pullback(ΔΩ) + Z, _ = sparseinv(x, depermute=true) + ∂x = x isa Number ? ΔΩ : Z' * dot(Ω, ΔΩ) + return (NoTangent(), ∂x) + end + return Ω, det_pullback +end diff --git a/test/rulesets/LinearAlgebra/structured.jl b/test/rulesets/LinearAlgebra/structured.jl index 1b83cc394..530b4b7ad 100644 --- a/test/rulesets/LinearAlgebra/structured.jl +++ b/test/rulesets/LinearAlgebra/structured.jl @@ -161,4 +161,14 @@ end end end + + @testset "[log[abs[det]]] SparseMatrixCSC" begin + ii = 1:5 + jj = 1:5 + x = ones(5) + A = sparse(ii, jj, x) + test_rrule(logabsdet, A) + test_rrule(logdet, A) + test_rrule(det, A) + end end From 378aea2c351cdfc69d0d1be044deab633d154739 Mon Sep 17 00:00:00 2001 From: Sam Urmy Date: Wed, 9 Aug 2023 11:33:22 -0700 Subject: [PATCH 2/7] remove unneccesary Number checks --- src/rulesets/LinearAlgebra/structured.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index d7afa6a5d..6fa6ad36b 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -290,7 +290,7 @@ function rrule(::typeof(logdet), x::SparseMatrixCSC) Ω = logdet(x) function logdet_pullback(ΔΩ) Z, p = sparseinv(x, depermute=true) - ∂x = x isa Number ? ΔΩ / x' : ΔΩ * Z' + ∂x = ΔΩ * Z' return (NoTangent(), ∂x) end return Ω, logdet_pullback @@ -300,7 +300,7 @@ function rrule(::typeof(det), x::SparseMatrixCSC) Ω = det(x) function det_pullback(ΔΩ) Z, _ = sparseinv(x, depermute=true) - ∂x = x isa Number ? ΔΩ : Z' * dot(Ω, ΔΩ) + ∂x = Z' * dot(Ω, ΔΩ) return (NoTangent(), ∂x) end return Ω, det_pullback From d5f015b90e45c9a92fe60897716aabf3b49bdd07 Mon Sep 17 00:00:00 2001 From: Sam Urmy Date: Tue, 15 Aug 2023 11:02:52 -0700 Subject: [PATCH 3/7] Move sparse logabsdet from structured to sparsematrix --- src/rulesets/LinearAlgebra/structured.jl | 37 --------------------- src/rulesets/SparseArrays/sparsematrix.jl | 38 ++++++++++++++++++++++ test/rulesets/LinearAlgebra/structured.jl | 10 ------ test/rulesets/SparseArrays/sparsematrix.jl | 10 ++++++ 4 files changed, 48 insertions(+), 47 deletions(-) diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index 6fa6ad36b..da153d14e 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -268,40 +268,3 @@ function rrule(::typeof(logdet), X::Union{Diagonal, AbstractTriangular}) end return y, logdet_pullback end - -function rrule(::typeof(logabsdet), x::SparseMatrixCSC) - F = cholesky(x) - L, D, U, P = SparseInverseSubset.get_ldup(F) - Ω = logabsdet(D) - function logabsdet_pullback(ΔΩ) - (Δy, Δsigny) = ΔΩ - (_, signy) = Ω - f = signy' * Δsigny - imagf = f - real(f) - g = real(Δy) + imagf - Z, P = sparseinv(F, depermute=true) - ∂x = g * Z' - return (NoTangent(), ∂x) - end - return Ω, logabsdet_pullback -end - -function rrule(::typeof(logdet), x::SparseMatrixCSC) - Ω = logdet(x) - function logdet_pullback(ΔΩ) - Z, p = sparseinv(x, depermute=true) - ∂x = ΔΩ * Z' - return (NoTangent(), ∂x) - end - return Ω, logdet_pullback -end - -function rrule(::typeof(det), x::SparseMatrixCSC) - Ω = det(x) - function det_pullback(ΔΩ) - Z, _ = sparseinv(x, depermute=true) - ∂x = Z' * dot(Ω, ΔΩ) - return (NoTangent(), ∂x) - end - return Ω, det_pullback -end diff --git a/src/rulesets/SparseArrays/sparsematrix.jl b/src/rulesets/SparseArrays/sparsematrix.jl index 51b41421c..03515156e 100644 --- a/src/rulesets/SparseArrays/sparsematrix.jl +++ b/src/rulesets/SparseArrays/sparsematrix.jl @@ -49,3 +49,41 @@ function rrule(::typeof(findnz), v::AbstractSparseVector) return (I, V), findnz_pullback end + + +function rrule(::typeof(logabsdet), x::SparseMatrixCSC) + F = cholesky(x) + L, D, U, P = SparseInverseSubset.get_ldup(F) + Ω = logabsdet(D) + function logabsdet_pullback(ΔΩ) + (Δy, Δsigny) = ΔΩ + (_, signy) = Ω + f = signy' * Δsigny + imagf = f - real(f) + g = real(Δy) + imagf + Z, P = sparseinv(F, depermute=true) + ∂x = g * Z' + return (NoTangent(), ∂x) + end + return Ω, logabsdet_pullback +end + +function rrule(::typeof(logdet), x::SparseMatrixCSC) + Ω = logdet(x) + function logdet_pullback(ΔΩ) + Z, p = sparseinv(x, depermute=true) + ∂x = ΔΩ * Z' + return (NoTangent(), ∂x) + end + return Ω, logdet_pullback +end + +function rrule(::typeof(det), x::SparseMatrixCSC) + Ω = det(x) + function det_pullback(ΔΩ) + Z, _ = sparseinv(x, depermute=true) + ∂x = Z' * dot(Ω, ΔΩ) + return (NoTangent(), ∂x) + end + return Ω, det_pullback +end diff --git a/test/rulesets/LinearAlgebra/structured.jl b/test/rulesets/LinearAlgebra/structured.jl index 530b4b7ad..1b83cc394 100644 --- a/test/rulesets/LinearAlgebra/structured.jl +++ b/test/rulesets/LinearAlgebra/structured.jl @@ -161,14 +161,4 @@ end end end - - @testset "[log[abs[det]]] SparseMatrixCSC" begin - ii = 1:5 - jj = 1:5 - x = ones(5) - A = sparse(ii, jj, x) - test_rrule(logabsdet, A) - test_rrule(logdet, A) - test_rrule(det, A) - end end diff --git a/test/rulesets/SparseArrays/sparsematrix.jl b/test/rulesets/SparseArrays/sparsematrix.jl index a11a1e963..39ff1a360 100644 --- a/test/rulesets/SparseArrays/sparsematrix.jl +++ b/test/rulesets/SparseArrays/sparsematrix.jl @@ -33,3 +33,13 @@ end V̄ = rand!(similar(V)) test_rrule(findnz, v ⊢ dv, output_tangent=(zeros(length(I)), V̄)) end + +@testset "[log[abs[det]]] SparseMatrixCSC" begin + ii = 1:5 + jj = 1:5 + x = ones(5) + A = sparse(ii, jj, x) + test_rrule(logabsdet, A) + test_rrule(logdet, A) + test_rrule(det, A) +end \ No newline at end of file From 72e6a83eaa9febc6f036386079e97ff85ce970cc Mon Sep 17 00:00:00 2001 From: Sam Urmy Date: Tue, 15 Aug 2023 11:08:56 -0700 Subject: [PATCH 4/7] make test matrix not just diagonal --- test/rulesets/SparseArrays/sparsematrix.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/rulesets/SparseArrays/sparsematrix.jl b/test/rulesets/SparseArrays/sparsematrix.jl index 39ff1a360..03f1052c2 100644 --- a/test/rulesets/SparseArrays/sparsematrix.jl +++ b/test/rulesets/SparseArrays/sparsematrix.jl @@ -35,9 +35,9 @@ end end @testset "[log[abs[det]]] SparseMatrixCSC" begin - ii = 1:5 - jj = 1:5 - x = ones(5) + ii = [1:5; 2; 4] + jj = [1:5; 4; 2] + x = [ones(5); 0.1; 0.1] A = sparse(ii, jj, x) test_rrule(logabsdet, A) test_rrule(logdet, A) From 7df39de22f377e08efb4efd7e57681f73fa0fe58 Mon Sep 17 00:00:00 2001 From: Sam Urmy Date: Tue, 15 Aug 2023 15:34:13 -0700 Subject: [PATCH 5/7] import logabsdetogabsdet(F::UmfpackLU) from future This method is required for the sparse logabsdet rrules, but was not included in Julia prior to v1.7. --- Project.toml | 1 + src/rulesets/SparseArrays/sparsematrix.jl | 34 +++++++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/Project.toml b/Project.toml index ada4203d2..79ea1b764 100644 --- a/Project.toml +++ b/Project.toml @@ -16,6 +16,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SparseInverseSubset = "dc90abb0-5640-4711-901d-7e5b23a2fada" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" +SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" [compat] Adapt = "3.4.0" diff --git a/src/rulesets/SparseArrays/sparsematrix.jl b/src/rulesets/SparseArrays/sparsematrix.jl index 03515156e..7af68c540 100644 --- a/src/rulesets/SparseArrays/sparsematrix.jl +++ b/src/rulesets/SparseArrays/sparsematrix.jl @@ -50,6 +50,40 @@ function rrule(::typeof(findnz), v::AbstractSparseVector) return (I, V), findnz_pullback end +if VERSION < v"1.7" + #= + This method for `logabsdet(F::UmfpackLU)` is required to calculate the (log)determinants + of sparse matrices, but was not defined prior to Julia v1.7. In order fo the rrules + for the determinants of sparse matrices below to work, they need to be able to + compute the primals as well, so this import from the future is included. For more + recent versions of Julia, this definition lives in: + julia/stdlib/SuiteSparse/src/umfpack.jl + =# + using SuiteSparse.UMFPACK: _signperm, UmfpackLU + + for itype in (:Int32, :Int64) + @eval begin + function LinearAlgebra.logabsdet(F::UmfpackLU{T, $itype}) where {T<:Union{Float64,ComplexF64}} + n = checksquare(F) + issuccess(F) || return log(zero(real(T))), zero(T) + U = F.U + Rs = F.Rs + p = F.p + q = F.q + s = _signperm(p)*_signperm(q)*one(real(T)) + P = one(T) + abs_det = zero(real(T)) + @inbounds for i in 1:n + dg_ii = U[i, i] / Rs[i] + P *= sign(dg_ii) + abs_det += log(abs(dg_ii)) + end + return abs_det, s * P + end + end + end +end + function rrule(::typeof(logabsdet), x::SparseMatrixCSC) F = cholesky(x) From affad438caeb7a680e49c397b08fb25912338d7a Mon Sep 17 00:00:00 2001 From: Sam Urmy Date: Tue, 15 Aug 2023 16:08:36 -0700 Subject: [PATCH 6/7] add missing _signperm --- src/rulesets/SparseArrays/sparsematrix.jl | 32 ++++++++++++++++++----- 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/src/rulesets/SparseArrays/sparsematrix.jl b/src/rulesets/SparseArrays/sparsematrix.jl index 7af68c540..4e599c517 100644 --- a/src/rulesets/SparseArrays/sparsematrix.jl +++ b/src/rulesets/SparseArrays/sparsematrix.jl @@ -52,15 +52,35 @@ end if VERSION < v"1.7" #= - This method for `logabsdet(F::UmfpackLU)` is required to calculate the (log)determinants - of sparse matrices, but was not defined prior to Julia v1.7. In order fo the rrules - for the determinants of sparse matrices below to work, they need to be able to - compute the primals as well, so this import from the future is included. For more - recent versions of Julia, this definition lives in: + The method below for `logabsdet(F::UmfpackLU)` is required to calculate the (log) + determinants of sparse matrices, but was not defined prior to Julia v1.7. In order + for the rrules for the determinants of sparse matrices below to work, they need to be + able to compute the primals as well, so this import from the future is included. For + more recent versions of Julia, this definition lives in: julia/stdlib/SuiteSparse/src/umfpack.jl =# - using SuiteSparse.UMFPACK: _signperm, UmfpackLU + using SuiteSparse.UMFPACK: UmfpackLU + # compute the sign/parity of a permutation + function _signperm(p) + n = length(p) + result = 0 + todo = trues(n) + while any(todo) + k = findfirst(todo) + todo[k] = false + result += 1 # increment element count + j = p[k] + while j != k + result += 1 # increment element count + todo[j] = false + j = p[j] + end + result += 1 # increment cycle count + end + return ifelse(isodd(result), -1, 1) + end + for itype in (:Int32, :Int64) @eval begin function LinearAlgebra.logabsdet(F::UmfpackLU{T, $itype}) where {T<:Union{Float64,ComplexF64}} From 85a83bea7d535a8cba8d8053a37a45d7fac259ee Mon Sep 17 00:00:00 2001 From: Sam Urmy Date: Tue, 15 Aug 2023 16:17:54 -0700 Subject: [PATCH 7/7] get rid of function def macro --- src/rulesets/SparseArrays/sparsematrix.jl | 36 ++++++++++------------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/src/rulesets/SparseArrays/sparsematrix.jl b/src/rulesets/SparseArrays/sparsematrix.jl index 4e599c517..8ee8f8cd0 100644 --- a/src/rulesets/SparseArrays/sparsematrix.jl +++ b/src/rulesets/SparseArrays/sparsematrix.jl @@ -80,27 +80,23 @@ if VERSION < v"1.7" end return ifelse(isodd(result), -1, 1) end - - for itype in (:Int32, :Int64) - @eval begin - function LinearAlgebra.logabsdet(F::UmfpackLU{T, $itype}) where {T<:Union{Float64,ComplexF64}} - n = checksquare(F) - issuccess(F) || return log(zero(real(T))), zero(T) - U = F.U - Rs = F.Rs - p = F.p - q = F.q - s = _signperm(p)*_signperm(q)*one(real(T)) - P = one(T) - abs_det = zero(real(T)) - @inbounds for i in 1:n - dg_ii = U[i, i] / Rs[i] - P *= sign(dg_ii) - abs_det += log(abs(dg_ii)) - end - return abs_det, s * P - end + + function LinearAlgebra.logabsdet(F::UmfpackLU{T, TI}) where {T<:Union{Float64,ComplexF64},TI<:Union{Int32, Int64}} + n = checksquare(F) + issuccess(F) || return log(zero(real(T))), zero(T) + U = F.U + Rs = F.Rs + p = F.p + q = F.q + s = _signperm(p)*_signperm(q)*one(real(T)) + P = one(T) + abs_det = zero(real(T)) + @inbounds for i in 1:n + dg_ii = U[i, i] / Rs[i] + P *= sign(dg_ii) + abs_det += log(abs(dg_ii)) end + return abs_det, s * P end end