diff --git a/Project.toml b/Project.toml index e118dd5..7c5a708 100644 --- a/Project.toml +++ b/Project.toml @@ -28,6 +28,7 @@ julia = "1.6" [extras] Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Quantics = "87f76fb3-a40a-40c9-a63c-29fcfe7b7547" [targets] -test = ["Test", "Random"] +test = ["Test", "Random", "Quantics"] diff --git a/src/ProjMPSs.jl b/src/ProjMPSs.jl index 109585d..2a3963a 100644 --- a/src/ProjMPSs.jl +++ b/src/ProjMPSs.jl @@ -2,14 +2,14 @@ module ProjMPSs import OrderedCollections: OrderedSet, OrderedDict using EllipsisNotation -import LinearAlgebra +using LinearAlgebra: LinearAlgebra -import ITensors: ITensors, Index, ITensor, dim, inds, qr, commoninds -import ITensorMPS: ITensorMPS, AbstractMPS, MPS, MPO, siteinds +import ITensors: ITensors, Index, ITensor, dim, inds, qr, commoninds, uniqueinds +import ITensorMPS: ITensorMPS, AbstractMPS, MPS, MPO, siteinds, findsites +import ITensors.TagSets: hastag, hastags import FastMPOContractions as FMPOC - include("util.jl") include("projector.jl") include("projmps.jl") @@ -18,4 +18,7 @@ include("patching.jl") include("contract.jl") include("adaptivemul.jl") +# Only for backward compatibility +include("automul.jl") + end diff --git a/src/automul.jl b/src/automul.jl new file mode 100644 index 0000000..aece578 --- /dev/null +++ b/src/automul.jl @@ -0,0 +1,124 @@ +""" +By default, elementwise multiplication will be performed. + +This function is kind of deprecated and will be removed in the future. +""" +function automul( + M1::BlockedMPS, + M2::BlockedMPS; + tag_row::String="", + tag_shared::String="", + tag_col::String="", + alg="naive", + maxdim=typemax(Int), + cutoff=1e-25, + kwargs..., +) + all(length.(siteinds(M1)) .== 1) || error("M1 should have only 1 site index per site") + all(length.(siteinds(M2)) .== 1) || error("M2 should have only 1 site index per site") + + sites_row = _findallsiteinds_by_tag(M1; tag=tag_row) + sites_shared = _findallsiteinds_by_tag(M1; tag=tag_shared) + sites_col = _findallsiteinds_by_tag(M2; tag=tag_col) + sites_matmul = Set(Iterators.flatten([sites_row, sites_shared, sites_col])) + + sites1 = only.(siteinds(M1)) + sites1_ewmul = setdiff(only.(siteinds(M1)), sites_matmul) + sites2_ewmul = setdiff(only.(siteinds(M2)), sites_matmul) + sites2_ewmul == sites1_ewmul || error("Invalid sites for elementwise multiplication") + + M1 = _makesitediagonal(M1, sites1_ewmul; baseplev=1) + M2 = _makesitediagonal(M2, sites2_ewmul; baseplev=0) + + sites_M1_diag = [collect(x) for x in siteinds(M1)] + sites_M2_diag = [collect(x) for x in siteinds(M2)] + + M1 = rearrange_siteinds(M1, combinesites(sites_M1_diag, sites_row, sites_shared)) + + M2 = rearrange_siteinds(M2, combinesites(sites_M2_diag, sites_shared, sites_col)) + + M = contract(M1, M2; alg=alg, kwargs...) + + M = extractdiagonal(M, sites1_ewmul) + + ressites = Vector{eltype(siteinds(M1)[1])}[] + for s in siteinds(M) + s_ = unique(ITensors.noprime.(s)) + if length(s_) == 1 + push!(ressites, s_) + else + if s_[1] ∈ sites1 + push!(ressites, [s_[1]]) + push!(ressites, [s_[2]]) + else + push!(ressites, [s_[2]]) + push!(ressites, [s_[1]]) + end + end + end + return truncate(rearrange_siteinds(M, ressites); cutoff=cutoff, maxdim=maxdim) +end + +function combinesites( + sites::Vector{Vector{Index{IndsT}}}, + site1::AbstractVector{Index{IndsT}}, + site2::AbstractVector{Index{IndsT}}, +) where {IndsT} + length(site1) == length(site2) || error("Length mismatch") + for (s1, s2) in zip(site1, site2) + sites = combinesites(sites, s1, s2) + end + return sites +end + +function combinesites( + sites::Vector{Vector{Index{IndsT}}}, site1::Index, site2::Index +) where {IndsT} + sites = deepcopy(sites) + p1 = findfirst(x -> x[1] == site1, sites) + p2 = findfirst(x -> x[1] == site2, sites) + if p1 === nothing || p2 === nothing + error("Site not found") + end + if abs(p1 - p2) != 1 + error("Sites are not adjacent") + end + deleteat!(sites, min(p1, p2)) + deleteat!(sites, min(p1, p2)) + insert!(sites, min(p1, p2), [site1, site2]) + return sites +end + +function _findallsiteinds_by_tag(M::BlockedMPS; tag=tag) + return findallsiteinds_by_tag(only.(siteinds(M)); tag=tag) +end + +# The following code is copied from Quantics.jl + +function findallsiteinds_by_tag( + sites::AbstractVector{Index{T}}; tag::String="x", maxnsites::Int=1000 +) where {T} + _valid_tag(tag) || error("Invalid tag: $tag") + positions = findallsites_by_tag(sites; tag=tag, maxnsites=maxnsites) + return [sites[p] for p in positions] +end + +function findallsites_by_tag( + sites::Vector{Index{T}}; tag::String="x", maxnsites::Int=1000 +)::Vector{Int} where {T} + _valid_tag(tag) || error("Invalid tag: $tag") + result = Int[] + for n in 1:maxnsites + tag_ = tag * "=$n" + idx = findall(hastags(tag_), sites) + if length(idx) == 0 + break + elseif length(idx) > 1 + error("Found more than one site indices with $(tag_)!") + end + push!(result, idx[1]) + end + return result +end + +_valid_tag(tag::String)::Bool = !occursin("=", tag) diff --git a/src/blockedmps.jl b/src/blockedmps.jl index 3153796..5424515 100644 --- a/src/blockedmps.jl +++ b/src/blockedmps.jl @@ -68,7 +68,6 @@ function Base.values(obj::BlockedMPS) return values(obj.data) end - """ Rearrange the site indices of the BlockedMPS according to the given order. If nessecary, tensors are fused or split to match the new order. @@ -144,3 +143,26 @@ end function ITensorMPS.MPO(obj::BlockedMPS; cutoff=1e-25, maxdim=typemax(Int))::MPO return MPO(collect(MPS(obj; cutoff=cutoff, maxdim=maxdim, kwargs...))) end + +""" +Make the BlockedMPS diagonal for a given site index `s` by introducing a dummy index `s'`. +""" +function makesitediagonal(obj::BlockedMPS, site) + return BlockedMPS([ + _makesitediagonal(prjmps, site; baseplev=baseplev) for prjmps in values(obj) + ]) +end + +function _makesitediagonal(obj::BlockedMPS, site; baseplev=0) + return BlockedMPS([ + _makesitediagonal(prjmps, site; baseplev=baseplev) for prjmps in values(obj) + ]) +end + +""" +Extract diagonal of the BlockedMPS for `s`, `s'`, ... for a given site index `s`, +where `s` must have a prime level of 0. +""" +function extractdiagonal(obj::BlockedMPS, site) + return BlockedMPS([extractdiagonal(prjmps, site) for prjmps in values(obj)]) +end diff --git a/src/projmps.jl b/src/projmps.jl index 3303b6e..ce3dec9 100644 --- a/src/projmps.jl +++ b/src/projmps.jl @@ -189,3 +189,88 @@ end function LinearAlgebra.norm(M::ProjMPS) return _norm(MPS(M)) end + +function _makesitediagonal( + projmps::ProjMPS, sites::AbstractVector{Index{IndsT}}; baseplev=0 +) where {IndsT} + M_ = deepcopy(MPO(collect(MPS(projmps)))) + for site in sites + target_site::Int = only(findsites(M_, site)) + M_[target_site] = _asdiagonal(M_[target_site], site; baseplev=baseplev) + end + return project(M_, projmps.projector) +end + +function _makesitediagonal(projmps::ProjMPS, site::Index; baseplev=0) + return _makesitediagonal(projmps, [site]; baseplev=baseplev) +end + +function makesitediagonal(projmps::ProjMPS, site::Index) + return _makesitediagonal(projmps, site; baseplev=0) +end + +function makesitediagonal(projmps::ProjMPS, sites::AbstractVector{Index}) + return _makesitediagonal(projmps, sites; baseplev=0) +end + +function makesitediagonal(projmps::ProjMPS, tag::String) + mps_diagonal = Quantics.makesitediagonal(MPS(projmps), tag) + projmps_diagonal = ProjMPS(mps_diagonal) + + target_sites = Quantics.findallsiteinds_by_tag( + unique(ITensors.noprime.(Iterators.flatten(siteinds(projmps)))); tag=tag + ) + + newproj = deepcopy(projmps.projector) + for s in target_sites + if isprojectedat(projmps.projector, s) + newproj[ITensors.prime(s)] = newproj[s] + end + end + + return project(projmps_diagonal, newproj) +end + +# FIXME: may be type unstable +function _find_site_allplevs(tensor::ITensor, site::Index; maxplev=10) + ITensors.plev(site) == 0 || error("Site index must be unprimed.") + return [ + ITensors.prime(site, plev) for + plev in 0:maxplev if ITensors.prime(site, plev) ∈ ITensors.inds(tensor) + ] +end + +function extractdiagonal( + projmps::ProjMPS, sites::AbstractVector{Index{IndsT}} +) where {IndsT} + tensors = collect(projmps.data) + for i in eachindex(tensors) + for site in intersect(sites, ITensors.inds(tensors[i])) + sitewithallplevs = _find_site_allplevs(tensors[i], site) + tensors[i] = if length(sitewithallplevs) > 1 + tensors[i] = _extract_diagonal(tensors[i], sitewithallplevs...) + else + tensors[i] + end + end + end + + projector = deepcopy(projmps.projector) + for site in sites + if site' in keys(projector.data) + delete!(projector.data, site') + end + end + return ProjMPS(MPS(tensors), projector) +end + +function extractdiagonal(projmps::ProjMPS, site::Index{IndsT}) where {IndsT} + return Quantics.extractdiagonal(projmps, [site]) +end + +function extractdiagonal(projmps::ProjMPS, tag::String)::ProjMPS + targetsites = Quantics.findallsiteinds_by_tag( + unique(ITensors.noprime.(ProjMPSs._allsites(projmps))); tag=tag + ) + return extractdiagonal(projmps, targetsites) +end diff --git a/src/util.jl b/src/util.jl index 4d58efa..1c03577 100644 --- a/src/util.jl +++ b/src/util.jl @@ -71,7 +71,6 @@ function _asdiagonal(t, site::Index{T}; baseplev=0)::ITensor where {T<:Number} ) end - function rearrange_siteinds(M::AbstractMPS, sites::Vector{Vector{Index{T}}})::MPS where {T} sitesold = siteinds(MPO(collect(M))) @@ -104,5 +103,16 @@ function rearrange_siteinds(M::AbstractMPS, sites::Vector{Vector{Index{T}}})::MP tensors[i], t, _ = qr(t, linds) end tensors[end] *= t - MPS(tensors) + return MPS(tensors) +end + +function _extract_diagonal(t, site::Index{T}, site2::Index{T}) where {T<:Number} + dim(site) == dim(site2) || error("Dimension mismatch") + restinds = uniqueinds(inds(t), site, site2) + newdata = zeros(eltype(t), dim.(restinds)..., dim(site)) + olddata = Array(t, restinds..., site, site2) + for i in 1:dim(site) + newdata[.., i] = olddata[.., i, i] + end + return ITensor(newdata, restinds..., site) end diff --git a/test/automul_tests.jl b/test/automul_tests.jl new file mode 100644 index 0000000..0207b56 --- /dev/null +++ b/test/automul_tests.jl @@ -0,0 +1,72 @@ +using Test + +using ITensors +import ProjMPSs: ProjMPSs, Projector, project, ProjMPS, projcontract, BlockedMPS +import FastMPOContractions as FMPOC +using Quantics: Quantics + +@testset "mul.jl" begin + """ + Reconstruct 3D matrix + """ + function _tomat3(a::MPS) + sites = siteinds(a) + N = length(sites) + Nreduced = N ÷ 3 + sites_ = [sites[1:3:N]..., sites[2:3:N]..., sites[3:3:N]...] + return reshape(Array(reduce(*, a), sites_), 2^Nreduced, 2^Nreduced, 2^Nreduced) + end + + @testset "batchedmatmul" for T in [Float64] + """ + C(x, z, k) = sum_y A(x, y, k) * B(y, z, k) + """ + nbit = 2 + D = 2 + cutoff = 1e-25 + sx = [Index(2, "Qubit,x=$n") for n in 1:nbit] + sy = [Index(2, "Qubit,y=$n") for n in 1:nbit] + sz = [Index(2, "Qubit,z=$n") for n in 1:nbit] + sk = [Index(2, "Qubit,k=$n") for n in 1:nbit] + + sites_a = collect(Iterators.flatten(zip(sx, sy, sk))) + sites_b = collect(Iterators.flatten(zip(sy, sz, sk))) + + a = random_mps(T, sites_a; linkdims=D) + b = random_mps(T, sites_b; linkdims=D) + + # Reference data + a_arr = _tomat3(a) + b_arr = _tomat3(b) + ab_arr = zeros(T, 2^nbit, 2^nbit, 2^nbit) + for k in 1:(2^nbit) + ab_arr[:, :, k] .= a_arr[:, :, k] * b_arr[:, :, k] + end + + a_ = BlockedMPS([ + project(a, Projector(Dict(sx[1] => 1, sy[1] => 1))), + project(a, Projector(Dict(sx[1] => 1, sy[1] => 2))), + project(a, Projector(Dict(sx[1] => 2, sy[1] => 1))), + project(a, Projector(Dict(sx[1] => 2, sy[1] => 2))), + ]) + + b_ = BlockedMPS([ + project(b, Projector(Dict(sy[1] => 1, sz[1] => 1))), + project(b, Projector(Dict(sy[1] => 1, sz[1] => 2))), + project(b, Projector(Dict(sy[1] => 2, sz[1] => 1))), + project(b, Projector(Dict(sy[1] => 2, sz[1] => 2))), + ]) + + @test a ≈ MPS(a_) + @test b ≈ MPS(b_) + + ab = ProjMPSs.automul( + a_, b_; tag_row="x", tag_shared="y", tag_col="z", alg="fit", cutoff + ) + ab_ref = Quantics.automul( + a, b; tag_row="x", tag_shared="y", tag_col="z", alg="fit", cutoff + ) + + @test MPS(ab) ≈ ab_ref rtol = 10 * sqrt(cutoff) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 49d343c..5eede2e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -11,3 +11,5 @@ include("projmps_tests.jl") include("blockedmps_tests.jl") include("contract_tests.jl") include("patching_tests.jl") + +include("automul_tests.jl")