From 4f0f333f899593e1828a9ba232f597e8268dcb43 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Wed, 20 Nov 2024 23:42:38 -0500 Subject: [PATCH] Add crappy rule for dA in Sparse matmul (#2109) * Add crappy rule for dA * Make rule better and add tests --- src/internal_rules.jl | 40 +++++++++++++++++++++++++++------------- test/internal_rules.jl | 26 ++++++++++++++------------ 2 files changed, 41 insertions(+), 25 deletions(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 71c700e73a..04aca1a66a 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -733,7 +733,7 @@ function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, func::Const{typeof(LinearAlgebra.mul!)}, ::Type{RT}, C::Annotation{<:StridedVecOrMat}, - A::Const{<:SparseArrays.SparseMatrixCSCUnion}, + A::Annotation{<:SparseArrays.SparseMatrixCSCUnion}, B::Annotation{<:StridedVecOrMat}, α::Annotation{<:Number}, β::Annotation{<:Number} @@ -761,7 +761,10 @@ function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, && !(typeof(C) <: Const) ) ? copy(A.val) : nothing - # cache_B = ( EnzymeRules.overwritten(config)[6]) ? copy(B.val) : nothing + cache_B = ( EnzymeRules.overwritten(config)[6] + && !(typeof(A) <: Const) + && !(typeof(C) <: Const) + ) ? 
copy(B.val) : nothing if !isa(α, Const) cache_α = A.val*B.val @@ -769,7 +772,7 @@ function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, cache_α = nothing end - cache = (cache_C, cache_A, cache_α) + cache = (cache_C, cache_A, cache_B, cache_α) return EnzymeRules.AugmentedReturn(primal, shadow, cache) end @@ -778,16 +781,16 @@ function EnzymeRules.reverse(config::EnzymeRules.RevConfig, func::Const{typeof(LinearAlgebra.mul!)}, ::Type{RT}, cache, C::Annotation{<:StridedVecOrMat}, - A::Const{<:SparseArrays.SparseMatrixCSCUnion}, + A::Annotation{<:SparseArrays.SparseMatrixCSCUnion}, B::Annotation{<:StridedVecOrMat}, α::Annotation{<:Number}, β::Annotation{<:Number} ) where {RT} - cache_C, cache_A, cache_α = cache + cache_C, cache_A, cache_B, cache_α = cache Cval = !isnothing(cache_C) ? cache_C : C.val Aval = !isnothing(cache_A) ? cache_A : A.val - # Bval = !isnothing(cache_B) ? cache_B : B.val + Bval = !isnothing(cache_B) ? cache_B : B.val N = EnzymeRules.width(config) if !isa(C, Const) @@ -821,13 +824,24 @@ function EnzymeRules.reverse(config::EnzymeRules.RevConfig, end for i in 1:N - # This rule is incorrect since you need to project dA to have the same - # sparsity pattern as A. - # if !isa(A, Const) - # dA = EnzymeRules.width(config) == 1 ? A.dval : A.dval[b] - # #dA .+= α*dC*B' - # mul!(dA, dC, Bval', α.val, true) - # end + if !isa(A, Const) + # dA .+= α*dC*B' + # Be careful that the sparsity pattern of dA does not change; otherwise + # the gradients will be incorrect. So for now we accumulate the slow way. + dA = EnzymeRules.width(config) == 1 ? A.dval : A.dval[i] + dC = EnzymeRules.width(config) == 1 ?
C.dval : C.dval[i] + # Now accumulate to preserve the correct sparsity pattern + I, J, _ = SparseArrays.findnz(dA) + for k in eachindex(I, J) + Ik, Jk = I[k], J[k] + tmp = zero(eltype(dA)) + for ti in axes(dC,2) + tmp += dC[Ik, ti]*Bval[Jk, ti] + end + dA[Ik, Jk] += α.val*tmp + end + # mul!(dA, dCs, Bval', α.val, true) + end if !isa(B, Const) #dB .+= α*A'*dC diff --git a/test/internal_rules.jl b/test/internal_rules.jl index a91ddaa620..ad10e88e79 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -721,17 +721,18 @@ end α = 2.0 β = 1.0 - for Tret in (Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated), + for Tret in (Duplicated, BatchDuplicated), TM in (Const, Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated), Tα in (Const, Active), Tβ in (Const, Active) - are_activities_compatible(Tret, Tret, Tv, Tα, Tβ) || continue - test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (α, Tα), (β, Tβ)) + are_activities_compatible(Tret, Tret, TM, Tv, Tα, Tβ) || continue + test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, TM), (v, Tv), (α, Tα), (β, Tβ)) end - for Tret in (Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated), bα in (true, false), bβ in (true, false) - are_activities_compatible(Tret, Tret, Tv) || continue + for Tret in (Duplicated, BatchDuplicated), TM in (Const, Duplicated, BatchDuplicated), + Tv in (Const, Duplicated, BatchDuplicated), bα in (true, false), bβ in (true, false) + are_activities_compatible(Tret, Tret, TM, Tv) || continue test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (bα, Const), (bβ, Const)) end @@ -740,8 +741,6 @@ end @test dα ≈ 0 @test dβ ≈ 0 - - end @testset "SparseArrays spmatmat reverse rule" begin @@ -751,15 +750,18 @@ end α = 2.0 β = 1.0 - for Tret in (Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated), + for Tret in (Duplicated, BatchDuplicated), TM in (Const, Duplicated, BatchDuplicated), Tv in 
(Const, Duplicated, BatchDuplicated), Tα in (Const, Active), Tβ in (Const, Active) - are_activities_compatible(Tret, Tv, Tα, Tβ) || continue - test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (α, Tα), (β, Tβ)) + are_activities_compatible(Tret, Tret, TM, Tv, Tα, Tβ) || continue + test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, TM), (v, Tv), (α, Tα), (β, Tβ)) + end - for Tret in (Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated), bα in (true, false), bβ in (true, false) - are_activities_compatible(Tret, Tv) || continue + + for Tret in (Duplicated, BatchDuplicated), TM in (Const, Duplicated, BatchDuplicated), + Tv in (Const, Duplicated, BatchDuplicated), bα in (true, false), bβ in (true, false) + are_activities_compatible(Tret, Tret, TM, Tv) || continue test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (bα, Const), (bβ, Const)) end