From 8c07022deaf5b34328792033aef7cb290f2d75c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jinguo=20Liu=20=28=E5=88=98=E9=87=91=E5=9B=BD=29?= Date: Wed, 3 Apr 2024 02:29:58 +0800 Subject: [PATCH] new sliced einsum (inplace) (#164) --- src/slicing.jl | 27 +++++++++++++++++++++++---- test/slicing.jl | 6 +++++- 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/src/slicing.jl b/src/slicing.jl index c2804eb..e44e89c 100644 --- a/src/slicing.jl +++ b/src/slicing.jl @@ -80,6 +80,14 @@ function fill_slice!(x, ix, chunk, slicemap::Dict) end return x end +function view_slice(x, ix, slicemap::Dict) + if ndims(x) == 0 + return x + else + slices = map(l->haskey(slicemap, l) ? slicemap[l] : Colon(), ix) + return view(x, slices...) + end +end function (se::SlicedEinsum{LT,ET})(@nospecialize(xs::AbstractArray...); size_info = nothing, kwargs...) where {LT, ET} # get size @@ -89,16 +97,27 @@ function (se::SlicedEinsum{LT,ET})(@nospecialize(xs::AbstractArray...); size_inf return einsum(se, xs, size_dict; kwargs...) end -function einsum(se::SlicedEinsum, @nospecialize(xs::NTuple{N,AbstractArray} where N), size_dict::Dict; kwargs...) - length(se.slicing) == 0 && return einsum(se.eins, xs, size_dict; kwargs...) +function einsum!(se::SlicedEinsum, @nospecialize(xs::NTuple{N,AbstractArray} where N), y, sx, sy, size_dict::Dict) + length(se.slicing) == 0 && return einsum!(se.eins, xs, y, sx, sy, size_dict) + iszero(sy) ? fill!(y, zero(eltype(y))) : rmul!(y, sy) + it = SliceIterator(se, size_dict) + eins_sliced = drop_slicedim(se.eins, se.slicing) + for slicemap in it + xsi = ntuple(i->take_slice(xs[i], it.ixsv[i], slicemap), length(xs)) + einsum!(eins_sliced, xsi, view_slice(y, it.iyv, slicemap), sx, true, it.size_dict_sliced) + end + return y +end +function einsum(se::SlicedEinsum, @nospecialize(xs::NTuple{N,AbstractArray} where N), size_dict::Dict) + length(se.slicing) == 0 && return einsum(se.eins, xs, size_dict) it = SliceIterator(se, size_dict) res = get_output_array(xs, getindex.(Ref(size_dict), it.iyv)) eins_sliced = drop_slicedim(se.eins, se.slicing) - for (k, slicemap) in enumerate(it) + for slicemap in it # NOTE: @debug will break Zygote # @debug "computing slice $k/$(length(it))" xsi = ntuple(i->take_slice(xs[i], it.ixsv[i], slicemap), length(xs)) - resi = einsum(eins_sliced, xsi, it.size_dict_sliced; kwargs...) + resi = einsum(eins_sliced, xsi, it.size_dict_sliced) res = fill_slice!(res, it.iyv, resi, slicemap) end return res diff --git a/test/slicing.jl b/test/slicing.jl index d2b98f2..6314615 100644 --- a/test/slicing.jl +++ b/test/slicing.jl @@ -20,7 +20,11 @@ end @test getixsv(se) == [['i','j'],['j','k'],['k','l'],['l','m']] @test getiyv(se) == ['i','m'] @test label_elimination_order(se) == ['j','l', 'k'] - @test se(xs...) ≈ se.eins(xs...) + expected = se.eins(xs...) + @test se(xs...) ≈ expected + y = similar(se(xs...)) + @test einsum!(se, xs, y, true, false, size_info) ≈ expected + @test y ≈ expected @test uniquelabels(se) == ['i', 'j', 'k', 'l', 'm'] @test uniformsize(se, 2) == Dict(zip(['i', 'j', 'k', 'l', 'm'], ones(Int, 5).*2)) end \ No newline at end of file