Skip to content

Commit

Permalink
Merge pull request #772 from ElOceanografo/nogpl
Browse files Browse the repository at this point in the history
Fix bugs with SparseInverseSubset on non-GPL Julia builds
  • Loading branch information
oxinabox authored Jan 24, 2024
2 parents c2fd16f + cfc7060 commit 7a44f20
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 88 deletions.
1 change: 0 additions & 1 deletion src/rulesets/LinearAlgebra/structured.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Structured matrices
using LinearAlgebra: AbstractTriangular
using SparseInverseSubset

# Matrix wrapper types that we know are square and are thus potentially invertible. For
# these we can use simpler definitions for `/` and `\`.
Expand Down
163 changes: 84 additions & 79 deletions src/rulesets/SparseArrays/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,94 +50,99 @@ function rrule(::typeof(findnz), v::AbstractSparseVector)
return (I, V), findnz_pullback
end

if VERSION < v"1.7"
#=
The method below for `logabsdet(F::UmfpackLU)` is required to calculate the (log)
determinants of sparse matrices, but was not defined prior to Julia v1.7. In order
for the rrules for the determinants of sparse matrices below to work, they need to be
able to compute the primals as well, so this import from the future is included. For
more recent versions of Julia, this definition lives in:
julia/stdlib/SuiteSparse/src/umfpack.jl
=#
using SuiteSparse.UMFPACK: UmfpackLU

# compute the sign/parity of a permutation
function _signperm(p)
n = length(p)
result = 0
todo = trues(n)
while any(todo)
k = findfirst(todo)
todo[k] = false
result += 1 # increment element count
j = p[k]
while j != k
if Base.USE_GPL_LIBS # Don't define rrules for sparse determinants if we don't have CHOLMOD from SuiteSparse.jl

if VERSION < v"1.7"
#=
The method below for `logabsdet(F::UmfpackLU)` is required to calculate the (log)
determinants of sparse matrices, but was not defined prior to Julia v1.7. In order
for the rrules for the determinants of sparse matrices below to work, they need to be
able to compute the primals as well, so this import from the future is included. For
more recent versions of Julia, this definition lives in:
julia/stdlib/SuiteSparse/src/umfpack.jl
=#
using SuiteSparse.UMFPACK: UmfpackLU

# compute the sign/parity of a permutation
function _signperm(p)
n = length(p)
result = 0
todo = trues(n)
while any(todo)
k = findfirst(todo)
todo[k] = false
result += 1 # increment element count
todo[j] = false
j = p[j]
j = p[k]
while j != k
result += 1 # increment element count
todo[j] = false
j = p[j]
end
result += 1 # increment cycle count
end
result += 1 # increment cycle count
return ifelse(isodd(result), -1, 1)
end
return ifelse(isodd(result), -1, 1)
end

function LinearAlgebra.logabsdet(F::UmfpackLU{T, TI}) where {T<:Union{Float64,ComplexF64},TI<:Union{Int32, Int64}}
n = checksquare(F)
issuccess(F) || return log(zero(real(T))), zero(T)
U = F.U
Rs = F.Rs
p = F.p
q = F.q
s = _signperm(p)*_signperm(q)*one(real(T))
P = one(T)
abs_det = zero(real(T))
@inbounds for i in 1:n
dg_ii = U[i, i] / Rs[i]
P *= sign(dg_ii)
abs_det += log(abs(dg_ii))
using SparseInverseSubset

function LinearAlgebra.logabsdet(F::UmfpackLU{T, TI}) where {T<:Union{Float64,ComplexF64},TI<:Union{Int32, Int64}}
n = checksquare(F)
issuccess(F) || return log(zero(real(T))), zero(T)
U = F.U
Rs = F.Rs
p = F.p
q = F.q
s = _signperm(p)*_signperm(q)*one(real(T))
P = one(T)
abs_det = zero(real(T))
@inbounds for i in 1:n
dg_ii = U[i, i] / Rs[i]
P *= sign(dg_ii)
abs_det += log(abs(dg_ii))
end
return abs_det, s * P
end
return abs_det, s * P
end
end


function rrule(::typeof(logabsdet), x::SparseMatrixCSC)
F = cholesky(x)
L, D, U, P = SparseInverseSubset.get_ldup(F)
Ω = logabsdet(D)
function logabsdet_pullback(ΔΩ)
(Δy, Δsigny) = ΔΩ
(_, signy) = Ω
f = signy' * Δsigny
imagf = f - real(f)
g = real(Δy) + imagf
Z, P = sparseinv(F, depermute=true)
∂x = g * Z'
return (NoTangent(), ∂x)


function rrule(::typeof(logabsdet), x::SparseMatrixCSC)
F = cholesky(x)
L, D, U, P = SparseInverseSubset.get_ldup(F)
Ω = logabsdet(D)
function logabsdet_pullback(ΔΩ)
(Δy, Δsigny) = ΔΩ
(_, signy) = Ω
f = signy' * Δsigny
imagf = f - real(f)
g = real(Δy) + imagf
Z, P = sparseinv(F, depermute=true)
∂x = g * Z'
return (NoTangent(), ∂x)
end
return Ω, logabsdet_pullback
end
return Ω, logabsdet_pullback
end

function rrule(::typeof(logdet), x::SparseMatrixCSC)
Ω = logdet(x)
function logdet_pullback(ΔΩ)
Z, p = sparseinv(x, depermute=true)
∂x = ΔΩ * Z'
return (NoTangent(), ∂x)

function rrule(::typeof(logdet), x::SparseMatrixCSC)
Ω = logdet(x)
function logdet_pullback(ΔΩ)
Z, p = sparseinv(x, depermute=true)
∂x = ΔΩ * Z'
return (NoTangent(), ∂x)
end
return Ω, logdet_pullback
end
return Ω, logdet_pullback
end

function rrule(::typeof(det), x::SparseMatrixCSC)
Ω = det(x)
function det_pullback(ΔΩ)
Z, _ = sparseinv(x, depermute=true)
∂x = Z' * dot(Ω, ΔΩ)
return (NoTangent(), ∂x)

function rrule(::typeof(det), x::SparseMatrixCSC)
Ω = det(x)
function det_pullback(ΔΩ)
Z, _ = sparseinv(x, depermute=true)
∂x = Z' * dot(Ω, ΔΩ)
return (NoTangent(), ∂x)
end
return Ω, det_pullback
end
return Ω, det_pullback
end


end # rrules that depend on CHOLMOD

function rrule(::typeof(spdiagm), m::Integer, n::Integer, kv::Pair{<:Integer,<:AbstractVector}...)

Expand Down
18 changes: 10 additions & 8 deletions test/rulesets/SparseArrays/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,14 @@ end
test_rrule(findnz, v dv, output_tangent=(zeros(length(I)), V̄))
end

@testset "[log[abs[det]]] SparseMatrixCSC" begin
ii = [1:5; 2; 4]
jj = [1:5; 4; 2]
x = [ones(5); 0.1; 0.1]
A = sparse(ii, jj, x)
test_rrule(logabsdet, A)
test_rrule(logdet, A)
test_rrule(det, A)
if Base.USE_GPL_LIBS # these rrules don't work without CHOLMOD from SuiteSparse.jl
@testset "[log[abs[det]]] SparseMatrixCSC" begin
ii = [1:5; 2; 4]
jj = [1:5; 4; 2]
x = [ones(5); 0.1; 0.1]
A = sparse(ii, jj, x)
test_rrule(logabsdet, A)
test_rrule(logdet, A)
test_rrule(det, A)
end
end

0 comments on commit 7a44f20

Please sign in to comment.