Skip to content

Commit

Permalink
add @Ein! (#163)
Browse files Browse the repository at this point in the history
  • Loading branch information
ArrogantGao authored Feb 27, 2024
1 parent 6968324 commit 8950973
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/OMEinsum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using OMEinsumContractionOrders
using AbstractTrees
import LinearAlgebra: BlasFloat

export @ein_str, @ein, ein
export @ein_str, @ein, @ein!, ein
export einsum!, einsum, dynamic_einsum
export EinCode, EinIndexer, EinArray, DynamicEinCode, StaticEinCode, AbstractEinsum, NestedEinsum, SlicedEinsum, DynamicNestedEinsum, StaticNestedEinsum
export getiyv, getixsv, uniquelabels, labeltype
Expand Down
2 changes: 1 addition & 1 deletion src/einsum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ function einsum(code::AbstractEinsum, @nospecialize(xs::Tuple), size_dict::Dict
end

# inplace einsum, EinCode as the input
function einsum!(code::EinCode, @nospecialize(xs::Tuple), @nospecialize(y), sx, sy, size_dict::Dict)
function einsum!(code::EinCode, @nospecialize(xs::Tuple), @nospecialize(y), sx, sy, size_dict::Dict = get_size_dict(getixs(code), xs))
einsum!(getixs(code), getiy(code), xs, y, sx, sy, size_dict)
end
# inplace einsum, the fallback
Expand Down
74 changes: 71 additions & 3 deletions src/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,11 @@ macro ein(exs...)
_ein_macro(exs...)
end


primefix!(ind) = map!(i -> @capture(i, (j_)') ? Symbol(j, '') : i, ind, ind)

function _ein_macro(ex; einsum=:einsum)
@capture(ex, (left_ := right_)) || throw(ArgumentError("expected A[] := B[]... "))
function _ein_macro(ex; einsum = :einsum)
@capture(ex, (left_ := right_)) || throw(ArgumentError("expected @ein A[] := B[]..."))

@capture(left, Z_[leftind__] | [leftind__] ) || throw(
ArgumentError("can't understand LHS, expected A[i,j] etc."))
if Z===nothing
Expand All @@ -161,4 +161,72 @@ function _ein_macro(ex; einsum=:einsum)
rightnames = [ esc(A) for (A, ind) in rightpairs ]

return :( $(esc(Z)) = $einsum( EinCode(($(righttuples...),), $lefttuple), ($(rightnames...),)) )
end

"""
@ein! A[i,k] := B[i,j] * C[j,k] # A = B * C
@ein! A[i,k] += B[i,j] * C[j,k] # A += B * C
Macro interface similar to that of other packages.
Inplace version of `@ein`.
# example
```jldoctest; setup = :(using OMEinsum)
julia> a, b, c, d = rand(2,2), rand(2,2), rand(2,2), zeros(2,2);
julia> cc = copy(c);
julia> @ein! d[i,k] := a[i,j] * b[j,k];
julia> d ≈ a * b
true
julia> d ≈ ein"ij,jk -> ik"(a,b)
true
julia> @ein! c[i,k] += a[i,j] * b[j,k];
julia> c ≈ cc + a * b
true
```
"""
macro ein!(exs...)
_ein_macro!(exs...)
end

function _ein_macro!(ex; einsum = :einsum!)
if @capture(ex, (left_ := right_))
flag = false
elseif @capture(ex, (left_ += right_))
flag = true
else
throw(ArgumentError("expected @ein! A[] := B[]... or @ein! A[] += B[]..."))
end

@capture(left, Z_[leftind__] | [leftind__] ) || throw(
ArgumentError("can't understand LHS, expected A[i,j] etc."))
if Z===nothing
throw(ArgumentError("LHS is needed for inplace einsum, expected A[i,j] etc."))
end
primefix!(leftind)

rightind, rightpairs = [], []
@capture(right, *(factors__)) || (factors = Any[right])
for fact in factors
@capture(fact, A_[Aind__]) || return _nested_ein_macro(ex)
primefix!(Aind)
append!(rightind, Aind)
push!(rightpairs, (A, Aind) )
end
unique!(rightind)
isempty(setdiff(leftind, rightind)) || throw(
ArgumentError("some indices appear only on the left"))

lefttuple = Tuple(indexin(leftind, rightind))
righttuples = [ Tuple(indexin(ind, rightind)) for (A, ind) in rightpairs ]
rightnames = [ esc(A) for (A, ind) in rightpairs ]

return :( $(esc(Z)) = $einsum( EinCode(($(righttuples...),), $lefttuple), ($(rightnames...),), $(esc(Z)), true, $flag) )
end
14 changes: 13 additions & 1 deletion test/einsum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ end
#@test_throws ArgumentError einsum(ein"ij,123 -> k", (a,a))
end

@testset "macro input" begin
@testset "non inplace macro input" begin
a = randn(2,2)
@test a * a @ein [i,k] := a[i,j] * a[j,k]
@test sum(a[i,i] for i in 1:2) (@ein [] := a[i,i])[]
Expand All @@ -252,6 +252,18 @@ end
@test permutedims(a) @ein [α,1] := a[1,α]
end

@testset "inplace macro input" begin
a = randn(2,2)
b = randn(2,2)
c = randn(2,2)
t = randn(2,2)
cc = copy(c)
@ein! t[i,k] := a[i,j] * b[j,k]
@ein! c[i,k] += a[i,j] * b[j,k]
@test a * b t
@test cc + a * b c
end

@testset "argument checks" begin
@test_throws ArgumentError einsum(ein"ij,jk -> ik", (rand(2,2), rand(2,2), rand(2,2)))
@test_throws ArgumentError einsum(ein"ij,jk,k -> ik", (rand(2,2), rand(2,2)))
Expand Down

0 comments on commit 8950973

Please sign in to comment.