From d4f6400ad114abf65205176db8326f73cbb5f4bf Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 29 Jan 2024 02:09:19 +0100 Subject: [PATCH] Restrict `ldiv!` rules to `Cholesky` (#1257) * Restrict `ldiv!` rules to `Cholesky` * Apply suggestions from code review * Update src/internal_rules.jl * Apply suggestions from code review --- src/internal_rules.jl | 43 ++++++++++++++----------------------------- 1 file changed, 14 insertions(+), 29 deletions(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 06416d0c69..adf612d6a9 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -621,10 +621,10 @@ end function EnzymeRules.forward( func::Const{typeof(ldiv!)}, RT::Type, - fact::Annotation{C}, - B, + fact::Annotation{<:Cholesky}, + B; kwargs... -) where {C <: Union{Cholesky,Array}} +) if isa(B, Const) @assert (RT <: Const) return func.val(fact.val, B.val; kwargs...) @@ -656,19 +656,11 @@ function EnzymeRules.forward( fact.dval[b] end - if C <: Array - mul!(dB, dfact, retval, -1, 1) - else - tmp = dfact.U * retval - - dB .-= dfact.L * tmp - - # if mul! was implemented for LU, this would be faster - # mul!(dB, dfact.L, tmp, -1, 1) - end + tmp = dfact.U * retval + mul!(dB, dfact.L, tmp, -1, 1) end - ldiv!(fact.val, dB; kwargs...) + func.val(fact.val, dB; kwargs...) end if RT <: Const @@ -750,12 +742,11 @@ function EnzymeRules.augmented_primal( func::Const{typeof(ldiv!)}, RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated}}, - A::Annotation{AType}, + A::Annotation{<:Cholesky}, B::Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated}; kwargs... -) where {AType <: Union{Cholesky, Array}} - - ldiv!(A.val, B.val; kwargs...) +) + func.val(A.val, B.val; kwargs...) cache_Bout = if !isa(A, Const) && !isa(B, Const) if EnzymeRules.overwritten(config)[3] @@ -797,11 +788,10 @@ function EnzymeRules.reverse( func::Const{typeof(ldiv!)}, dret, cache, - A::Annotation{AType}, - B; + A::Annotation{<:Cholesky}, + B::Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated}; kwargs... -) where {AType <: Union{Cholesky,Array}} - +) if !isa(B, Const) (cache_A, cache_Bout) = cache @@ -813,15 +803,10 @@ function EnzymeRules.reverse( # dB = z, where z = inv(A^T) dB # dA −= z B(out)^T - func.val(cache_A, dB, kwargs...) + func.val(cache_A, dB; kwargs...) if !isa(A, Const) dA = EnzymeRules.width(config) == 1 ? A.dval : A.dval[b] - - if AType <: Array - mul!(dA, dB, transpose(cache_Bout), -1, 1) - else - mul!(dA.factors, dB, transpose(cache_Bout), -1, 1) - end + mul!(dA.factors, dB, transpose(cache_Bout), -1, 1) end end end