Skip to content

Commit

Permalink
[WIP] First stab at Cholesky (#1220)
Browse files Browse the repository at this point in the history
* First stab at Cholesky

* Fix

* fixup

* Adding files and minor fix

* Reverse test

* Fix

* Remove files

* Moving rules to internal_rules.jl

* Move Cholesky tests to the tests folder

* Reverse tests pass

* Remove development files

* Remove dev deps

* Fixes

* Batched modes

* Forward rule fix

* Fix

* Fix type ambiguity

* Adding preliminary ldiv!

* Inplace ldiv!

* Resolved ambiguities

* Zip objects

* Deduplicate forward

* Fix tests

* Fix

* cleanup tests

---------

Co-authored-by: Billy Moses <[email protected]>
  • Loading branch information
michel2323 and wsmoses authored Jan 26, 2024
1 parent 922bb89 commit 1adc514
Show file tree
Hide file tree
Showing 3 changed files with 490 additions and 6 deletions.
309 changes: 305 additions & 4 deletions src/internal_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,11 @@ end
@inline EnzymeRules.inactive_type(v::Type{T}) where {T<:Module} = true
@inline EnzymeRules.inactive_type(v::Type{T}) where {T<:AbstractString} = true

@inline width(::Duplicated) = 1
@inline width(::BatchDuplicated{T, N}) where {T, N} = N
@inline width(::DuplicatedNoNeed) = 1
@inline width(::BatchDuplicatedNoNeed{T, N}) where {T, N} = N

# Note all of these forward mode definitions do not support runtime activity as
# the do not keep the primal if shadow(x.y) == primal(x.y)
function EnzymeRules.forward(::Const{typeof(Base.deepcopy)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated)
Expand Down Expand Up @@ -306,10 +311,10 @@ end
end

# y=inv(A) B
# dA −= z y^T
# dA −= z y^T
# dB += z, where z = inv(A^T) dy
function EnzymeRules.augmented_primal(config, func::Const{typeof(\)}, ::Type{RT}, A::Annotation{AT}, b::Annotation{BT}) where {RT, AT <: Array, BT <: Array}

cache_A = if EnzymeRules.overwritten(config)[2]
copy(A.val)
else
Expand Down Expand Up @@ -346,7 +351,7 @@ function EnzymeRules.augmented_primal(config, func::Const{typeof(\)}, ::Type{RT}
else
nothing
end

@static if VERSION < v"1.8.0"
UT = Union{
LinearAlgebra.Diagonal{eltype(AT), BT},
Expand Down Expand Up @@ -449,7 +454,7 @@ function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.hvcat_fill
return EnzymeRules.AugmentedReturn(primal, shadow, nothing)
end

function EnzymeRules.reverse(config, func::Const{typeof(Base.hvcat_fill!)}, ::Type{RT}, _, out::Annotation{AT}, inp::Annotation{BT}) where {RT, AT <: Array, BT <: Tuple}
function EnzymeRules.reverse(config, func::Const{typeof(Base.hvcat_fill!)}, ::Type{RT}, _, out::Annotation{AT}, inp::Annotation{BT}) where {RT, AT <: Array, BT <: Tuple}
nr, nc = size(out.val,1), size(out.val,2)
for b in 1:EnzymeRules.width(config)
da = if EnzymeRules.width(config) == 1
Expand Down Expand Up @@ -558,3 +563,299 @@ function EnzymeRules.reverse(
xs.dval .= xs.dval[back_inds]
return (nothing,)
end

function EnzymeRules.forward(
::Const{typeof(cholesky)},
RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}},
A::Union{Const, Duplicated};
kwargs...
)
fact = cholesky(A.val; kwargs...)
if RT <: Const
return fact
end
# TODO: This will be a problem for sparse matrices as invL and dL are dense
invL = inv(fact.L)
# TODO: dL is dense even when L was sparse
dL = Matrix(fact.L * LowerTriangular(invL * A.dval * invL' * 0.5 * I))
# TODO: Stored as Cholesky, although it isn't a Cholesky factorization
dfact = Cholesky(dL, 'L', 0)
if RT <: DuplicatedNoNeed
return dfact
else
return Duplicated(fact, dfact)
end
end

function EnzymeRules.forward(
::Const{typeof(cholesky)},
RT::Type{<:Union{BatchDuplicatedNoNeed, BatchDuplicated}},
A::Union{BatchDuplicatedNoNeed{T,N}, BatchDuplicated{T,N}};
kwargs...
) where {T,N}
fact = cholesky(A.val; kwargs...)
if RT <: Const
return fact
end
invL = inv(fact.L)
dfact = ntuple(
i-> Cholesky(
Matrix(fact.L * LowerTriangular(invL * A.dval[i] * invL' * 0.5 * I)), 'L', 0
), Val(N)
)
if RT <: BatchDuplicatedNoNeed
return dfact
else
return BatchDuplicated(fact, dfact)
end
end

function EnzymeRules.forward(
func::Const{typeof(\)},
RT::Type{<:Union{Const, Duplicated, DuplicatedNoNeed, BatchDuplicatedNoNeed, BatchDuplicated}},
fact::Annotation{C},
B::Union{Const, Duplicated, BatchDuplicated};
kwargs...
) where {C <: Cholesky}
retval = copy(B.val)
ldiv!(fact.val, retval)
N = RT <: BatchDuplicated ? RT.parameters[2] : 1
dfact = if RT <: BatchDuplicated
fact.dval
else
(fact.dval,)
end
dB = if RT <: BatchDuplicated
B.dval
else
(B.dval,)
end
bretdval = ntuple(Val(N)) do b
if isa(fact, Const) && isa(B, Const)
nothing
elseif isa(B, Const)
retdval = zeros(length(retval))
mul!(retdval, A, B.val)
mul!(retdval, -1, retdval)
ldiv!(fact.val, retdval)
elseif isa(fact, Const)
retdval = copy(B.dval)
ldiv!(fact.val, retdval)
else
retdval = zeros(length(retval))
# mul!(retdval, fact.dval[b].U, retval)
# mul!(retdval, fact.dval[b].L, retdval)
mul!(retdval, dfact[b].U, retval)
mul!(retdval, dfact[b].L, retdval)
retdval .= dB[b] .- retdval
ldiv!(fact.val, retdval)
end
retdval
end
if RT <: Const
return retval
elseif (RT <: DuplicatedNoNeed) || (RT <: BatchDuplicatedNoNeed)
return retdval
elseif RT <: Duplicated
return Duplicated(retval, bretdval[1])
else
return BatchDuplicated(retval, ntuple(i -> bretdval[i], Val(N)))
end
end

function EnzymeRules.forward(
func::Const{typeof(ldiv!)},
RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated}},
fact::Annotation{C},
B::Union{Const, Duplicated, BatchDuplicated};
kwargs...
) where {C <: Cholesky}
retval = B.val
retval = ldiv!(fact.val, retval)
dfact = width(fact) == 1 ? (fact.dval,) : fact.dval
dB = width(B) == 1 ? (B.dval,) : B.dval
N = max(width(fact), width(B))
for b in 1:N
if isa(fact, Const) && isa(B, Const)
nothing
elseif isa(B, Const)
retdval = zeros(length(retval))
mul!(retdval, A, B.val)
mul!(retdval, -1, retdval)
ldiv!(fact.val, retdval)
B.val .= retval
dB[b] .= retdval
elseif isa(fact, Const)
retdval = copy(B.dval)
ldiv!(fact.val, retdval)
B.val .= retval
dB[b] .= retdval
else
retdval = zeros(length(retval))
mul!(retdval, dfact[b].U, retval)
mul!(retdval, dfact[b].L, retdval)
retdval .= dB[b] .- retdval
ldiv!(fact.val, retdval)
B.val .= retval
dB[b] .= retdval
end
retdval
end
if RT <: Const
return retval
elseif (RT <: DuplicatedNoNeed) || (RT <: BatchDuplicatedNoNeed)
return retdval
elseif RT <: Duplicated
return Duplicated(retval, dB[1])
else
return BatchDuplicated(retval, ntuple(i -> dB[i], Val(N)))
end
end

function EnzymeRules.augmented_primal(
config,
func::Const{typeof(cholesky)},
RT::Type{<:Union{Const, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated}},
A::Union{Const, Duplicated, BatchDuplicated};
kwargs...
)
fact = if EnzymeRules.needs_primal(config)
cholesky(A.val; kwargs...)
else
nothing
end
# dfact would be a dense matrix, prepare buffer
dfact = if EnzymeRules.width(config) == 1
Cholesky(Matrix(fact), 'L', 0)
else
ntuple(Val(EnzymeRules.width(config))) do i
Base.@_inline_meta
Cholesky(Matrix(fact), 'L', 0)
end
end
return EnzymeRules.AugmentedReturn(fact, dfact, (dfact,))
end

function EnzymeRules.reverse(
config,
::Const{typeof(cholesky)},
dret,
cache,
A;
kwargs...
)
(dfact,) = cache
dAs = EnzymeRules.width(config) == 1 ? (A.dval,) : A.dval
dfacts = EnzymeRules.width(config) == 1 ? (dfact,) : dfact

for (dA, dfact) in zip(dAs, dfacts)
dA .+= dfact.factors
end
return (nothing,)
end

function EnzymeRules.augmented_primal(
config,
func::Const{typeof(\)},
RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated}},

fact::Annotation{C},
B::Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated};
kwargs...
) where {C <: Cholesky}
x = copy(B.val)
primal = if EnzymeRules.needs_primal(config)
x
else
nothing
end
ldiv!(fact.val, x)
nobatched = EnzymeRules.width(config) == 1 ? true : false
shadow = nobatched ? zeros(size(B.val)) :
ntuple(Val(EnzymeRules.width(config))) do i
Base.@_inline_meta
zeros(size(B.val))
end
buffer = nobatched ? zeros(size(B.val)) :
ntuple(Val(EnzymeRules.width(config))) do i
Base.@_inline_meta
zeros(size(B.val))
end
return EnzymeRules.AugmentedReturn(primal, shadow, (x, shadow, buffer))
end

function EnzymeRules.reverse(
config,
func::Union{Const{typeof(\)}},
dret,
cache,
fact::Annotation{C},
B::Annotation;
kwargs...
) where {C <: Cholesky}

(x, dx, buffer) = cache

dxs = EnzymeRules.width(config) == 1 ? (dx,) : dx
buffers = EnzymeRules.width(config) == 1 ? (buffer,) : buffer
dfacts = EnzymeRules.width(config) == 1 ? (fact.dval,) : fact.dval
dBs = EnzymeRules.width(config) == 1 ? (B.dval,) : B.dval

for (dx, buffer, dfact, dB) in zip(dxs, buffers, dfacts, dBs)
buffer .= fact.val\dx
dB .+= buffer
buffer = reshape(buffer, size(buffer,2), size(buffer,1))
dfact.factors .+= -x .* buffer
end
return (nothing, nothing)
end

function EnzymeRules.augmented_primal(
config,
func::Const{typeof(ldiv!)},
RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated}},

fact::Annotation{C},
B::Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated};
kwargs...
) where {C <: Cholesky}
x = B.val
primal = if EnzymeRules.needs_primal(config)
x
else
nothing
end
ldiv!(fact.val, x)
nobatched = EnzymeRules.width(config) == 1 ? true : false
shadow = nothing
buffer = nobatched ? zeros(size(B.val)) :
ntuple(Val(EnzymeRules.width(config))) do i
Base.@_inline_meta
zeros(size(B.val))
end
return EnzymeRules.AugmentedReturn(primal, shadow, (x, shadow, buffer))
end

function EnzymeRules.reverse(
config,
func::Const{typeof(ldiv!)},
dret,
cache,
_fact::Annotation{C},
_B::Annotation;
kwargs...
) where {C <: Cholesky}

(x, _dx, _buffer) = cache
for b in 1:EnzymeRules.width(config)
buffer = EnzymeRules.width(config) == 1 ? _buffer : _buffer[b]
dfact = EnzymeRules.width(config) == 1 ? _fact.dval : _fact.dval[b]
dB = EnzymeRules.width(config) == 1 ? _B.dval : _B.dval[b]

buffer .= _fact.val\dB
dB .+= buffer
buffer = reshape(buffer, size(buffer,2), size(buffer,1))
dfact.factors .+= -x .* buffer
end
return (nothing, nothing)
end
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ LLVM_jll = "86de99a1-58d6-5da7-8064-bd56ce2e322c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
Expand Down
Loading

0 comments on commit 1adc514

Please sign in to comment.