Skip to content

Commit

Permalink
Merge pull request #730 from ElOceanografo/sparsedet
Browse files Browse the repository at this point in the history
det, logdet, and logabsdet rrules for SparseMatrixCSC
  • Loading branch information
oxinabox authored Aug 16, 2023
2 parents df672c3 + 85a83be commit 7a9feab
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 0 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RealDot = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseInverseSubset = "dc90abb0-5640-4711-901d-7e5b23a2fada"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"

[compat]
Adapt = "3.4.0"
Expand Down
1 change: 1 addition & 0 deletions src/rulesets/LinearAlgebra/structured.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Structured matrices
using LinearAlgebra: AbstractTriangular
using SparseInverseSubset

# Matrix wrapper types that we know are square and are thus potentially invertible. For
# these we can use simpler definitions for `/` and `\`.
Expand Down
88 changes: 88 additions & 0 deletions src/rulesets/SparseArrays/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,91 @@ function rrule(::typeof(findnz), v::AbstractSparseVector)

return (I, V), findnz_pullback
end

if VERSION < v"1.7"
#=
The method below for `logabsdet(F::UmfpackLU)` is required to calculate the (log)
determinants of sparse matrices, but was not defined prior to Julia v1.7. In order
for the rrules for the determinants of sparse matrices below to work, they need to be
able to compute the primals as well, so this import from the future is included. For
more recent versions of Julia, this definition lives in:
julia/stdlib/SuiteSparse/src/umfpack.jl
=#
using SuiteSparse.UMFPACK: UmfpackLU

# compute the sign/parity of a permutation
function _signperm(p)
n = length(p)
result = 0
todo = trues(n)
while any(todo)
k = findfirst(todo)
todo[k] = false
result += 1 # increment element count
j = p[k]
while j != k
result += 1 # increment element count
todo[j] = false
j = p[j]
end
result += 1 # increment cycle count
end
return ifelse(isodd(result), -1, 1)
end

function LinearAlgebra.logabsdet(F::UmfpackLU{T, TI}) where {T<:Union{Float64,ComplexF64},TI<:Union{Int32, Int64}}
n = checksquare(F)
issuccess(F) || return log(zero(real(T))), zero(T)
U = F.U
Rs = F.Rs
p = F.p
q = F.q
s = _signperm(p)*_signperm(q)*one(real(T))
P = one(T)
abs_det = zero(real(T))
@inbounds for i in 1:n
dg_ii = U[i, i] / Rs[i]
P *= sign(dg_ii)
abs_det += log(abs(dg_ii))
end
return abs_det, s * P
end
end


function rrule(::typeof(logabsdet), x::SparseMatrixCSC)
F = cholesky(x)
L, D, U, P = SparseInverseSubset.get_ldup(F)
Ω = logabsdet(D)
function logabsdet_pullback(ΔΩ)
(Δy, Δsigny) = ΔΩ
(_, signy) = Ω
f = signy' * Δsigny
imagf = f - real(f)
g = real(Δy) + imagf
Z, P = sparseinv(F, depermute=true)
∂x = g * Z'
return (NoTangent(), ∂x)
end
return Ω, logabsdet_pullback
end

function rrule(::typeof(logdet), x::SparseMatrixCSC)
Ω = logdet(x)
function logdet_pullback(ΔΩ)
Z, p = sparseinv(x, depermute=true)
∂x = ΔΩ * Z'
return (NoTangent(), ∂x)
end
return Ω, logdet_pullback
end

function rrule(::typeof(det), x::SparseMatrixCSC)
Ω = det(x)
function det_pullback(ΔΩ)
Z, _ = sparseinv(x, depermute=true)
∂x = Z' * dot(Ω, ΔΩ)
return (NoTangent(), ∂x)
end
return Ω, det_pullback
end
10 changes: 10 additions & 0 deletions test/rulesets/SparseArrays/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,13 @@ end
= rand!(similar(V))
test_rrule(findnz, v dv, output_tangent=(zeros(length(I)), V̄))
end

@testset "[log[abs[det]]] SparseMatrixCSC" begin
ii = [1:5; 2; 4]
jj = [1:5; 4; 2]
x = [ones(5); 0.1; 0.1]
A = sparse(ii, jj, x)
test_rrule(logabsdet, A)
test_rrule(logdet, A)
test_rrule(det, A)
end

0 comments on commit 7a9feab

Please sign in to comment.