Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bugs with SparseInverseSubset on non-GPL Julia builds #772

Merged
merged 3 commits into from
Jan 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/rulesets/LinearAlgebra/structured.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Structured matrices
using LinearAlgebra: AbstractTriangular
using SparseInverseSubset

# Matrix wrapper types that we know are square and are thus potentially invertible. For
# these we can use simpler definitions for `/` and `\`.
Expand Down
163 changes: 84 additions & 79 deletions src/rulesets/SparseArrays/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,94 +50,99 @@
return (I, V), findnz_pullback
end

if VERSION < v"1.7"
#=
The method below for `logabsdet(F::UmfpackLU)` is required to calculate the (log)
determinants of sparse matrices, but was not defined prior to Julia v1.7. In order
for the rrules for the determinants of sparse matrices below to work, they need to be
able to compute the primals as well, so this import from the future is included. For
more recent versions of Julia, this definition lives in:
julia/stdlib/SuiteSparse/src/umfpack.jl
=#
using SuiteSparse.UMFPACK: UmfpackLU

# compute the sign/parity of a permutation
function _signperm(p)
n = length(p)
result = 0
todo = trues(n)
while any(todo)
k = findfirst(todo)
todo[k] = false
result += 1 # increment element count
j = p[k]
while j != k
if Base.USE_GPL_LIBS # Don't define rrules for sparse determinants if we don't have CHOLMOD from SuiteSparse.jl

Check warning on line 54 in src/rulesets/SparseArrays/sparsematrix.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/rulesets/SparseArrays/sparsematrix.jl:54:-
if VERSION < v"1.7"
#=
The method below for `logabsdet(F::UmfpackLU)` is required to calculate the (log)
determinants of sparse matrices, but was not defined prior to Julia v1.7. In order
for the rrules for the determinants of sparse matrices below to work, they need to be
able to compute the primals as well, so this import from the future is included. For
more recent versions of Julia, this definition lives in:
julia/stdlib/SuiteSparse/src/umfpack.jl
=#
using SuiteSparse.UMFPACK: UmfpackLU

Check warning on line 65 in src/rulesets/SparseArrays/sparsematrix.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/rulesets/SparseArrays/sparsematrix.jl:65:- src/rulesets/SparseArrays/sparsematrix.jl:74:+
# compute the sign/parity of a permutation
function _signperm(p)
n = length(p)
result = 0
todo = trues(n)
while any(todo)
k = findfirst(todo)
todo[k] = false
result += 1 # increment element count
todo[j] = false
j = p[j]
j = p[k]
while j != k
result += 1 # increment element count
todo[j] = false
j = p[j]
end
result += 1 # increment cycle count
end
result += 1 # increment cycle count
return ifelse(isodd(result), -1, 1)
end
return ifelse(isodd(result), -1, 1)
end

function LinearAlgebra.logabsdet(F::UmfpackLU{T, TI}) where {T<:Union{Float64,ComplexF64},TI<:Union{Int32, Int64}}
n = checksquare(F)
issuccess(F) || return log(zero(real(T))), zero(T)
U = F.U
Rs = F.Rs
p = F.p
q = F.q
s = _signperm(p)*_signperm(q)*one(real(T))
P = one(T)
abs_det = zero(real(T))
@inbounds for i in 1:n
dg_ii = U[i, i] / Rs[i]
P *= sign(dg_ii)
abs_det += log(abs(dg_ii))
using SparseInverseSubset

Check warning on line 87 in src/rulesets/SparseArrays/sparsematrix.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/rulesets/SparseArrays/sparsematrix.jl:87:- src/rulesets/SparseArrays/sparsematrix.jl:88:- function LinearAlgebra.logabsdet(F::UmfpackLU{T, TI}) where {T<:Union{Float64,ComplexF64},TI<:Union{Int32, Int64}} src/rulesets/SparseArrays/sparsematrix.jl:96:+ src/rulesets/SparseArrays/sparsematrix.jl:97:+ function LinearAlgebra.logabsdet( src/rulesets/SparseArrays/sparsematrix.jl:98:+ F::UmfpackLU{T,TI} src/rulesets/SparseArrays/sparsematrix.jl:99:+ ) where {T<:Union{Float64,ComplexF64},TI<:Union{Int32,Int64}}
function LinearAlgebra.logabsdet(F::UmfpackLU{T, TI}) where {T<:Union{Float64,ComplexF64},TI<:Union{Int32, Int64}}
n = checksquare(F)
issuccess(F) || return log(zero(real(T))), zero(T)
U = F.U
Rs = F.Rs
p = F.p
q = F.q
s = _signperm(p)*_signperm(q)*one(real(T))

Check warning on line 95 in src/rulesets/SparseArrays/sparsematrix.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/rulesets/SparseArrays/sparsematrix.jl:95:- s = _signperm(p)*_signperm(q)*one(real(T)) src/rulesets/SparseArrays/sparsematrix.jl:106:+ s = _signperm(p) * _signperm(q) * one(real(T))
P = one(T)
abs_det = zero(real(T))
@inbounds for i in 1:n
dg_ii = U[i, i] / Rs[i]
P *= sign(dg_ii)
abs_det += log(abs(dg_ii))
end
return abs_det, s * P
end
return abs_det, s * P
end
end


function rrule(::typeof(logabsdet), x::SparseMatrixCSC)
F = cholesky(x)
L, D, U, P = SparseInverseSubset.get_ldup(F)
Ω = logabsdet(D)
function logabsdet_pullback(ΔΩ)
(Δy, Δsigny) = ΔΩ
(_, signy) = Ω
f = signy' * Δsigny
imagf = f - real(f)
g = real(Δy) + imagf
Z, P = sparseinv(F, depermute=true)
∂x = g * Z'
return (NoTangent(), ∂x)

Check warning on line 106 in src/rulesets/SparseArrays/sparsematrix.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/rulesets/SparseArrays/sparsematrix.jl:106:- src/rulesets/SparseArrays/sparsematrix.jl:107:- src/rulesets/SparseArrays/sparsematrix.jl:117:+

function rrule(::typeof(logabsdet), x::SparseMatrixCSC)
F = cholesky(x)
L, D, U, P = SparseInverseSubset.get_ldup(F)
Ω = logabsdet(D)
function logabsdet_pullback(ΔΩ)
(Δy, Δsigny) = ΔΩ
(_, signy) = Ω
f = signy' * Δsigny
imagf = f - real(f)
g = real(Δy) + imagf
Z, P = sparseinv(F, depermute=true)

Check warning on line 118 in src/rulesets/SparseArrays/sparsematrix.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/rulesets/SparseArrays/sparsematrix.jl:118:- Z, P = sparseinv(F, depermute=true) src/rulesets/SparseArrays/sparsematrix.jl:128:+ Z, P = sparseinv(F; depermute=true)
∂x = g * Z'
return (NoTangent(), ∂x)
end
return Ω, logabsdet_pullback
end
return Ω, logabsdet_pullback
end

function rrule(::typeof(logdet), x::SparseMatrixCSC)
Ω = logdet(x)
function logdet_pullback(ΔΩ)
Z, p = sparseinv(x, depermute=true)
∂x = ΔΩ * Z'
return (NoTangent(), ∂x)

Check warning on line 124 in src/rulesets/SparseArrays/sparsematrix.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/rulesets/SparseArrays/sparsematrix.jl:124:- src/rulesets/SparseArrays/sparsematrix.jl:134:+
function rrule(::typeof(logdet), x::SparseMatrixCSC)
Ω = logdet(x)
function logdet_pullback(ΔΩ)
Z, p = sparseinv(x, depermute=true)

Check warning on line 128 in src/rulesets/SparseArrays/sparsematrix.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/rulesets/SparseArrays/sparsematrix.jl:128:- Z, p = sparseinv(x, depermute=true) src/rulesets/SparseArrays/sparsematrix.jl:138:+ Z, p = sparseinv(x; depermute=true)
∂x = ΔΩ * Z'
return (NoTangent(), ∂x)
end
return Ω, logdet_pullback
end
return Ω, logdet_pullback
end

function rrule(::typeof(det), x::SparseMatrixCSC)
Ω = det(x)
function det_pullback(ΔΩ)
Z, _ = sparseinv(x, depermute=true)
∂x = Z' * dot(Ω, ΔΩ)
return (NoTangent(), ∂x)

Check warning on line 134 in src/rulesets/SparseArrays/sparsematrix.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/rulesets/SparseArrays/sparsematrix.jl:134:- src/rulesets/SparseArrays/sparsematrix.jl:144:+
function rrule(::typeof(det), x::SparseMatrixCSC)
Ω = det(x)
function det_pullback(ΔΩ)
Z, _ = sparseinv(x, depermute=true)

Check warning on line 138 in src/rulesets/SparseArrays/sparsematrix.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/rulesets/SparseArrays/sparsematrix.jl:138:- Z, _ = sparseinv(x, depermute=true) src/rulesets/SparseArrays/sparsematrix.jl:148:+ Z, _ = sparseinv(x; depermute=true)
∂x = Z' * dot(Ω, ΔΩ)
return (NoTangent(), ∂x)
end
return Ω, det_pullback
end
return Ω, det_pullback
end


end # rrules that depend on CHOLMOD

function rrule(::typeof(spdiagm), m::Integer, n::Integer, kv::Pair{<:Integer,<:AbstractVector}...)

Expand Down
18 changes: 10 additions & 8 deletions test/rulesets/SparseArrays/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,14 @@ end
test_rrule(findnz, v ⊢ dv, output_tangent=(zeros(length(I)), V̄))
end

@testset "[log[abs[det]]] SparseMatrixCSC" begin
ii = [1:5; 2; 4]
jj = [1:5; 4; 2]
x = [ones(5); 0.1; 0.1]
A = sparse(ii, jj, x)
test_rrule(logabsdet, A)
test_rrule(logdet, A)
test_rrule(det, A)
if Base.USE_GPL_LIBS # these rrules don't work without CHOLMOD from SuiteSparse.jl
@testset "[log[abs[det]]] SparseMatrixCSC" begin
ii = [1:5; 2; 4]
jj = [1:5; 4; 2]
x = [ones(5); 0.1; 0.1]
A = sparse(ii, jj, x)
test_rrule(logabsdet, A)
test_rrule(logdet, A)
test_rrule(det, A)
end
end
Loading