Skip to content

Commit

Permalink
Reimplement automul
Browse files Browse the repository at this point in the history
  • Loading branch information
shinaoka committed Nov 25, 2024
1 parent 76ac027 commit 7a3b9c6
Show file tree
Hide file tree
Showing 8 changed files with 327 additions and 8 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ julia = "1.6"
[extras]
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Quantics = "87f76fb3-a40a-40c9-a63c-29fcfe7b7547"

[targets]
test = ["Test", "Random"]
test = ["Test", "Random", "Quantics"]
11 changes: 7 additions & 4 deletions src/ProjMPSs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@ module ProjMPSs

import OrderedCollections: OrderedSet, OrderedDict
using EllipsisNotation
import LinearAlgebra
using LinearAlgebra: LinearAlgebra

import ITensors: ITensors, Index, ITensor, dim, inds, qr, commoninds
import ITensorMPS: ITensorMPS, AbstractMPS, MPS, MPO, siteinds
import ITensors: ITensors, Index, ITensor, dim, inds, qr, commoninds, uniqueinds
import ITensorMPS: ITensorMPS, AbstractMPS, MPS, MPO, siteinds, findsites
import ITensors.TagSets: hastag, hastags

import FastMPOContractions as FMPOC


include("util.jl")
include("projector.jl")
include("projmps.jl")
Expand All @@ -18,4 +18,7 @@ include("patching.jl")
include("contract.jl")
include("adaptivemul.jl")

# Only for backward compatibility
include("automul.jl")

end
124 changes: 124 additions & 0 deletions src/automul.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
"""
By default, elementwise multiplication will be performed.
This function is kind of deprecated and will be removed in the future.
"""
function automul(
M1::BlockedMPS,
M2::BlockedMPS;
tag_row::String="",
tag_shared::String="",
tag_col::String="",
alg="naive",
maxdim=typemax(Int),
cutoff=1e-25,
kwargs...,
)
all(length.(siteinds(M1)) .== 1) || error("M1 should have only 1 site index per site")
all(length.(siteinds(M2)) .== 1) || error("M2 should have only 1 site index per site")

sites_row = _findallsiteinds_by_tag(M1; tag=tag_row)
sites_shared = _findallsiteinds_by_tag(M1; tag=tag_shared)
sites_col = _findallsiteinds_by_tag(M2; tag=tag_col)
sites_matmul = Set(Iterators.flatten([sites_row, sites_shared, sites_col]))

sites1 = only.(siteinds(M1))
sites1_ewmul = setdiff(only.(siteinds(M1)), sites_matmul)
sites2_ewmul = setdiff(only.(siteinds(M2)), sites_matmul)
sites2_ewmul == sites1_ewmul || error("Invalid sites for elementwise multiplication")

M1 = _makesitediagonal(M1, sites1_ewmul; baseplev=1)
M2 = _makesitediagonal(M2, sites2_ewmul; baseplev=0)

sites_M1_diag = [collect(x) for x in siteinds(M1)]
sites_M2_diag = [collect(x) for x in siteinds(M2)]

M1 = rearrange_siteinds(M1, combinesites(sites_M1_diag, sites_row, sites_shared))

M2 = rearrange_siteinds(M2, combinesites(sites_M2_diag, sites_shared, sites_col))

M = contract(M1, M2; alg=alg, kwargs...)

M = extractdiagonal(M, sites1_ewmul)

ressites = Vector{eltype(siteinds(M1)[1])}[]
for s in siteinds(M)
s_ = unique(ITensors.noprime.(s))
if length(s_) == 1
push!(ressites, s_)
else
if s_[1] sites1
push!(ressites, [s_[1]])
push!(ressites, [s_[2]])
else
push!(ressites, [s_[2]])
push!(ressites, [s_[1]])
end
end
end
return truncate(rearrange_siteinds(M, ressites); cutoff=cutoff, maxdim=maxdim)
end

function combinesites(
sites::Vector{Vector{Index{IndsT}}},
site1::AbstractVector{Index{IndsT}},
site2::AbstractVector{Index{IndsT}},
) where {IndsT}
length(site1) == length(site2) || error("Length mismatch")
for (s1, s2) in zip(site1, site2)
sites = combinesites(sites, s1, s2)
end
return sites
end

function combinesites(
sites::Vector{Vector{Index{IndsT}}}, site1::Index, site2::Index
) where {IndsT}
sites = deepcopy(sites)
p1 = findfirst(x -> x[1] == site1, sites)
p2 = findfirst(x -> x[1] == site2, sites)
if p1 === nothing || p2 === nothing
error("Site not found")
end
if abs(p1 - p2) != 1
error("Sites are not adjacent")
end
deleteat!(sites, min(p1, p2))
deleteat!(sites, min(p1, p2))
insert!(sites, min(p1, p2), [site1, site2])
return sites
end

function _findallsiteinds_by_tag(M::BlockedMPS; tag=tag)
return findallsiteinds_by_tag(only.(siteinds(M)); tag=tag)
end

# The following code is copied from Quantics.jl

function findallsiteinds_by_tag(
sites::AbstractVector{Index{T}}; tag::String="x", maxnsites::Int=1000
) where {T}
_valid_tag(tag) || error("Invalid tag: $tag")
positions = findallsites_by_tag(sites; tag=tag, maxnsites=maxnsites)
return [sites[p] for p in positions]
end

function findallsites_by_tag(
sites::Vector{Index{T}}; tag::String="x", maxnsites::Int=1000
)::Vector{Int} where {T}
_valid_tag(tag) || error("Invalid tag: $tag")
result = Int[]
for n in 1:maxnsites
tag_ = tag * "=$n"
idx = findall(hastags(tag_), sites)
if length(idx) == 0
break
elseif length(idx) > 1
error("Found more than one site indices with $(tag_)!")
end
push!(result, idx[1])
end
return result
end

_valid_tag(tag::String)::Bool = !occursin("=", tag)
24 changes: 23 additions & 1 deletion src/blockedmps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ function Base.values(obj::BlockedMPS)
return values(obj.data)
end


"""
Rearrange the site indices of the BlockedMPS according to the given order.
If nessecary, tensors are fused or split to match the new order.
Expand Down Expand Up @@ -144,3 +143,26 @@ end
function ITensorMPS.MPO(obj::BlockedMPS; cutoff=1e-25, maxdim=typemax(Int))::MPO
return MPO(collect(MPS(obj; cutoff=cutoff, maxdim=maxdim, kwargs...)))
end

"""
Make the BlockedMPS diagonal for a given site index `s` by introducing a dummy index `s'`.
"""
function makesitediagonal(obj::BlockedMPS, site)
return BlockedMPS([
_makesitediagonal(prjmps, site; baseplev=baseplev) for prjmps in values(obj)
])
end

function _makesitediagonal(obj::BlockedMPS, site; baseplev=0)
return BlockedMPS([
_makesitediagonal(prjmps, site; baseplev=baseplev) for prjmps in values(obj)
])
end

"""
Extract diagonal of the BlockedMPS for `s`, `s'`, ... for a given site index `s`,
where `s` must have a prime level of 0.
"""
function extractdiagonal(obj::BlockedMPS, site)
return BlockedMPS([extractdiagonal(prjmps, site) for prjmps in values(obj)])
end
85 changes: 85 additions & 0 deletions src/projmps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,3 +189,88 @@ end
function LinearAlgebra.norm(M::ProjMPS)
return _norm(MPS(M))
end

function _makesitediagonal(
projmps::ProjMPS, sites::AbstractVector{Index{IndsT}}; baseplev=0
) where {IndsT}
M_ = deepcopy(MPO(collect(MPS(projmps))))
for site in sites
target_site::Int = only(findsites(M_, site))
M_[target_site] = _asdiagonal(M_[target_site], site; baseplev=baseplev)
end
return project(M_, projmps.projector)
end

function _makesitediagonal(projmps::ProjMPS, site::Index; baseplev=0)
return _makesitediagonal(projmps, [site]; baseplev=baseplev)
end

function makesitediagonal(projmps::ProjMPS, site::Index)
return _makesitediagonal(projmps, site; baseplev=0)
end

function makesitediagonal(projmps::ProjMPS, sites::AbstractVector{Index})
return _makesitediagonal(projmps, sites; baseplev=0)
end

function makesitediagonal(projmps::ProjMPS, tag::String)
mps_diagonal = Quantics.makesitediagonal(MPS(projmps), tag)
projmps_diagonal = ProjMPS(mps_diagonal)

target_sites = Quantics.findallsiteinds_by_tag(
unique(ITensors.noprime.(Iterators.flatten(siteinds(projmps)))); tag=tag
)

newproj = deepcopy(projmps.projector)
for s in target_sites
if isprojectedat(projmps.projector, s)
newproj[ITensors.prime(s)] = newproj[s]
end
end

return project(projmps_diagonal, newproj)
end

# FIXME: may be type unstable
function _find_site_allplevs(tensor::ITensor, site::Index; maxplev=10)
ITensors.plev(site) == 0 || error("Site index must be unprimed.")
return [
ITensors.prime(site, plev) for
plev in 0:maxplev if ITensors.prime(site, plev) ITensors.inds(tensor)
]
end

function extractdiagonal(
projmps::ProjMPS, sites::AbstractVector{Index{IndsT}}
) where {IndsT}
tensors = collect(projmps.data)
for i in eachindex(tensors)
for site in intersect(sites, ITensors.inds(tensors[i]))
sitewithallplevs = _find_site_allplevs(tensors[i], site)
tensors[i] = if length(sitewithallplevs) > 1
tensors[i] = _extract_diagonal(tensors[i], sitewithallplevs...)
else
tensors[i]
end
end
end

projector = deepcopy(projmps.projector)
for site in sites
if site' in keys(projector.data)
delete!(projector.data, site')
end
end
return ProjMPS(MPS(tensors), projector)
end

function extractdiagonal(projmps::ProjMPS, site::Index{IndsT}) where {IndsT}
return Quantics.extractdiagonal(projmps, [site])
end

function extractdiagonal(projmps::ProjMPS, tag::String)::ProjMPS
targetsites = Quantics.findallsiteinds_by_tag(
unique(ITensors.noprime.(ProjMPSs._allsites(projmps))); tag=tag
)
return extractdiagonal(projmps, targetsites)
end
14 changes: 12 additions & 2 deletions src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ function _asdiagonal(t, site::Index{T}; baseplev=0)::ITensor where {T<:Number}
)
end


function rearrange_siteinds(M::AbstractMPS, sites::Vector{Vector{Index{T}}})::MPS where {T}
sitesold = siteinds(MPO(collect(M)))

Expand Down Expand Up @@ -104,5 +103,16 @@ function rearrange_siteinds(M::AbstractMPS, sites::Vector{Vector{Index{T}}})::MP
tensors[i], t, _ = qr(t, linds)
end
tensors[end] *= t
MPS(tensors)
return MPS(tensors)
end

function _extract_diagonal(t, site::Index{T}, site2::Index{T}) where {T<:Number}
dim(site) == dim(site2) || error("Dimension mismatch")
restinds = uniqueinds(inds(t), site, site2)
newdata = zeros(eltype(t), dim.(restinds)..., dim(site))
olddata = Array(t, restinds..., site, site2)
for i in 1:dim(site)
newdata[.., i] = olddata[.., i, i]
end
return ITensor(newdata, restinds..., site)
end
72 changes: 72 additions & 0 deletions test/automul_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
using Test

using ITensors
import ProjMPSs: ProjMPSs, Projector, project, ProjMPS, projcontract, BlockedMPS
import FastMPOContractions as FMPOC
using Quantics: Quantics

@testset "mul.jl" begin
"""
Reconstruct 3D matrix
"""
function _tomat3(a::MPS)
sites = siteinds(a)
N = length(sites)
Nreduced = N ÷ 3
sites_ = [sites[1:3:N]..., sites[2:3:N]..., sites[3:3:N]...]
return reshape(Array(reduce(*, a), sites_), 2^Nreduced, 2^Nreduced, 2^Nreduced)
end

@testset "batchedmatmul" for T in [Float64]
"""
C(x, z, k) = sum_y A(x, y, k) * B(y, z, k)
"""
nbit = 2
D = 2
cutoff = 1e-25
sx = [Index(2, "Qubit,x=$n") for n in 1:nbit]
sy = [Index(2, "Qubit,y=$n") for n in 1:nbit]
sz = [Index(2, "Qubit,z=$n") for n in 1:nbit]
sk = [Index(2, "Qubit,k=$n") for n in 1:nbit]

sites_a = collect(Iterators.flatten(zip(sx, sy, sk)))
sites_b = collect(Iterators.flatten(zip(sy, sz, sk)))

a = random_mps(T, sites_a; linkdims=D)
b = random_mps(T, sites_b; linkdims=D)

# Reference data
a_arr = _tomat3(a)
b_arr = _tomat3(b)
ab_arr = zeros(T, 2^nbit, 2^nbit, 2^nbit)
for k in 1:(2^nbit)
ab_arr[:, :, k] .= a_arr[:, :, k] * b_arr[:, :, k]
end

a_ = BlockedMPS([
project(a, Projector(Dict(sx[1] => 1, sy[1] => 1))),
project(a, Projector(Dict(sx[1] => 1, sy[1] => 2))),
project(a, Projector(Dict(sx[1] => 2, sy[1] => 1))),
project(a, Projector(Dict(sx[1] => 2, sy[1] => 2))),
])

b_ = BlockedMPS([
project(b, Projector(Dict(sy[1] => 1, sz[1] => 1))),
project(b, Projector(Dict(sy[1] => 1, sz[1] => 2))),
project(b, Projector(Dict(sy[1] => 2, sz[1] => 1))),
project(b, Projector(Dict(sy[1] => 2, sz[1] => 2))),
])

@test a MPS(a_)
@test b MPS(b_)

ab = ProjMPSs.automul(
a_, b_; tag_row="x", tag_shared="y", tag_col="z", alg="fit", cutoff
)
ab_ref = Quantics.automul(
a, b; tag_row="x", tag_shared="y", tag_col="z", alg="fit", cutoff
)

@test MPS(ab) ab_ref rtol = 10 * sqrt(cutoff)
end
end
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,5 @@ include("projmps_tests.jl")
include("blockedmps_tests.jl")
include("contract_tests.jl")
include("patching_tests.jl")

include("automul_tests.jl")

0 comments on commit 7a3b9c6

Please sign in to comment.