Skip to content

Commit

Permalink
new sliced einsum (inplace) (#164)
Browse files Browse the repository at this point in the history
  • Loading branch information
GiggleLiu authored Apr 2, 2024
1 parent 8950973 commit 8c07022
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 5 deletions.
27 changes: 23 additions & 4 deletions src/slicing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,14 @@ function fill_slice!(x, ix, chunk, slicemap::Dict)
end
return x
end
function view_slice(x, ix, slicemap::Dict)
if ndims(x) == 0
return x
else
slices = map(l->haskey(slicemap, l) ? slicemap[l] : Colon(), ix)
return view(x, slices...)
end
end

function (se::SlicedEinsum{LT,ET})(@nospecialize(xs::AbstractArray...); size_info = nothing, kwargs...) where {LT, ET}
# get size
Expand All @@ -89,16 +97,27 @@ function (se::SlicedEinsum{LT,ET})(@nospecialize(xs::AbstractArray...); size_inf
return einsum(se, xs, size_dict; kwargs...)
end

function einsum(se::SlicedEinsum, @nospecialize(xs::NTuple{N,AbstractArray} where N), size_dict::Dict; kwargs...)
length(se.slicing) == 0 && return einsum(se.eins, xs, size_dict; kwargs...)
function einsum!(se::SlicedEinsum, @nospecialize(xs::NTuple{N,AbstractArray} where N), y, sx, sy, size_dict::Dict)
length(se.slicing) == 0 && return einsum!(se.eins, xs, y, sx, sy, size_dict)
iszero(sy) ? fill!(y, zero(eltype(y))) : rmul!(y, sy)
it = SliceIterator(se, size_dict)
eins_sliced = drop_slicedim(se.eins, se.slicing)
for slicemap in it
xsi = ntuple(i->take_slice(xs[i], it.ixsv[i], slicemap), length(xs))
einsum!(eins_sliced, xsi, view_slice(y, it.iyv, slicemap), sx, true, it.size_dict_sliced)
end
return y
end
function einsum(se::SlicedEinsum, @nospecialize(xs::NTuple{N,AbstractArray} where N), size_dict::Dict)
length(se.slicing) == 0 && return einsum(se.eins, xs, size_dict)
it = SliceIterator(se, size_dict)
res = get_output_array(xs, getindex.(Ref(size_dict), it.iyv))
eins_sliced = drop_slicedim(se.eins, se.slicing)
for (k, slicemap) in enumerate(it)
for slicemap in it
# NOTE: @debug will break Zygote
# @debug "computing slice $k/$(length(it))"
xsi = ntuple(i->take_slice(xs[i], it.ixsv[i], slicemap), length(xs))
resi = einsum(eins_sliced, xsi, it.size_dict_sliced; kwargs...)
resi = einsum(eins_sliced, xsi, it.size_dict_sliced)
res = fill_slice!(res, it.iyv, resi, slicemap)
end
return res
Expand Down
6 changes: 5 additions & 1 deletion test/slicing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@ end
@test getixsv(se) == [['i','j'],['j','k'],['k','l'],['l','m']]
@test getiyv(se) == ['i','m']
@test label_elimination_order(se) == ['j','l', 'k']
@test se(xs...) se.eins(xs...)
expected = se.eins(xs...)
@test se(xs...) expected
y = similar(se(xs...))
@test einsum!(se, xs, y, true, false, size_info) expected
@test y expected
@test uniquelabels(se) == ['i', 'j', 'k', 'l', 'm']
@test uniformsize(se, 2) == Dict(zip(['i', 'j', 'k', 'l', 'm'], ones(Int, 5).*2))
end

0 comments on commit 8c07022

Please sign in to comment.