Skip to content

Commit

Permalink
Add crappy rule for dA in Sparse matmul (#2109)
Browse files Browse the repository at this point in the history
* Add crappy rule for dA

* Make rule better and add tests
  • Loading branch information
ptiede authored Nov 21, 2024
1 parent 665cebd commit 4f0f333
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 25 deletions.
40 changes: 27 additions & 13 deletions src/internal_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,7 @@ function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig,
func::Const{typeof(LinearAlgebra.mul!)},
::Type{RT},
C::Annotation{<:StridedVecOrMat},
A::Const{<:SparseArrays.SparseMatrixCSCUnion},
A::Annotation{<:SparseArrays.SparseMatrixCSCUnion},
B::Annotation{<:StridedVecOrMat},
α::Annotation{<:Number},
β::Annotation{<:Number}
Expand Down Expand Up @@ -761,15 +761,18 @@ function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig,
&& !(typeof(C) <: Const)
) ? copy(A.val) : nothing

# cache_B = ( EnzymeRules.overwritten(config)[6]) ? copy(B.val) : nothing
cache_B = ( EnzymeRules.overwritten(config)[6]
&& !(typeof(A) <: Const)
&& !(typeof(C) <: Const)
) ? copy(B.val) : nothing

if !isa(α, Const)
cache_α = A.val*B.val
else
cache_α = nothing
end

cache = (cache_C, cache_A, cache_α)
cache = (cache_C, cache_A, cache_B, cache_α)

return EnzymeRules.AugmentedReturn(primal, shadow, cache)
end
Expand All @@ -778,16 +781,16 @@ function EnzymeRules.reverse(config::EnzymeRules.RevConfig,
func::Const{typeof(LinearAlgebra.mul!)},
::Type{RT}, cache,
C::Annotation{<:StridedVecOrMat},
A::Const{<:SparseArrays.SparseMatrixCSCUnion},
A::Annotation{<:SparseArrays.SparseMatrixCSCUnion},
B::Annotation{<:StridedVecOrMat},
α::Annotation{<:Number},
β::Annotation{<:Number}
) where {RT}

cache_C, cache_A, cache_α = cache
cache_C, cache_A, cache_B, cache_α = cache
Cval = !isnothing(cache_C) ? cache_C : C.val
Aval = !isnothing(cache_A) ? cache_A : A.val
# Bval = !isnothing(cache_B) ? cache_B : B.val
Bval = !isnothing(cache_B) ? cache_B : B.val

N = EnzymeRules.width(config)
if !isa(C, Const)
Expand Down Expand Up @@ -821,13 +824,24 @@ function EnzymeRules.reverse(config::EnzymeRules.RevConfig,
end

for i in 1:N
# This rule is incorrect since you need to project dA to have the same
# sparsity pattern as A.
# if !isa(A, Const)
# dA = EnzymeRules.width(config) == 1 ? A.dval : A.dval[b]
# #dA .+= α*dC*B'
# mul!(dA, dC, Bval', α.val, true)
# end
if !isa(A, Const)
# dA .+= αdC*B'
# You need to be careful so that dA sparsity pattern does not change. Otherwise
# you will get incorrect gradients. So for now we do the slow and bad way of accumulating
dA = EnzymeRules.width(config) == 1 ? A.dval : A.dval[i]
dC = EnzymeRules.width(config) == 1 ? C.dval : C.dval[i]
# Now accumulate to preserve the correct sparsity pattern
I, J, _ = SparseArrays.findnz(dA)
for k in eachindex(I, J)
Ik, Jk = I[k], J[k]
tmp = zero(eltype(dA))
for ti in axes(dC,2)
tmp += dC[Ik, ti]*Bval[Jk, ti]
end
dA[Ik, Jk] += α.val*tmp
end
# mul!(dA, dCs, Bval', α.val, true)
end

if !isa(B, Const)
#dB .+= α*A'*dC
Expand Down
26 changes: 14 additions & 12 deletions test/internal_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -721,17 +721,18 @@ end
α = 2.0
β = 1.0

for Tret in (Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated),
for Tret in (Duplicated, BatchDuplicated), TM in (Const, Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated),
in (Const, Active), Tβ in (Const, Active)

are_activities_compatible(Tret, Tret, Tv, Tα, Tβ) || continue
test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (α, Tα), (β, Tβ))
are_activities_compatible(Tret, Tret, TM, Tv, Tα, Tβ) || continue
test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, TM), (v, Tv), (α, Tα), (β, Tβ))

end


for Tret in (Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated), bα in (true, false), bβ in (true, false)
are_activities_compatible(Tret, Tret, Tv) || continue
for Tret in (Duplicated, BatchDuplicated), TM in (Const, Duplicated, BatchDuplicated),
Tv in (Const, Duplicated, BatchDuplicated), bα in (true, false), bβ in (true, false)
are_activities_compatible(Tret, Tret, TM, Tv) || continue
test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (bα, Const), (bβ, Const))
end

Expand All @@ -740,8 +741,6 @@ end
@test 0
@test 0



end

@testset "SparseArrays spmatmat reverse rule" begin
Expand All @@ -751,15 +750,18 @@ end
α = 2.0
β = 1.0

for Tret in (Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated),
for Tret in (Duplicated, BatchDuplicated), TM in (Const, Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated),
in (Const, Active), Tβ in (Const, Active)

are_activities_compatible(Tret, Tv, Tα, Tβ) || continue
test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (α, Tα), (β, Tβ))
are_activities_compatible(Tret, Tret, TM, Tv, Tα, Tβ) || continue
test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, TM), (v, Tv), (α, Tα), (β, Tβ))

end

for Tret in (Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated), bα in (true, false), bβ in (true, false)
are_activities_compatible(Tret, Tv) || continue

for Tret in (Duplicated, BatchDuplicated), TM in (Const, Duplicated, BatchDuplicated),
Tv in (Const, Duplicated, BatchDuplicated), bα in (true, false), bβ in (true, false)
are_activities_compatible(Tret, Tret, TM, Tv) || continue
test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (bα, Const), (bβ, Const))
end

Expand Down

0 comments on commit 4f0f333

Please sign in to comment.