Skip to content

Commit

Permalink
Implement adaptive adjustment of cutoff in truncate
Browse files Browse the repository at this point in the history
  • Loading branch information
shinaoka committed Dec 4, 2024
1 parent 75a626d commit e024840
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 13 deletions.
3 changes: 3 additions & 0 deletions src/PartitionedMPSs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ import ITensors.TagSets: hastag, hastags

import FastMPOContractions as FMPOC

default_cutoff() = 1e-25
default_maxdim() = typemax(Int)

include("util.jl")
include("projector.jl")
include("subdomainmps.jl")
Expand Down
4 changes: 2 additions & 2 deletions src/adaptivemul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ function adaptivecontract(
b::PartitionedMPS,
pordering::AbstractVector{Index}=Index[];
alg="fit",
cutoff=1e-25,
maxdim=typemax(Int),
cutoff=default_cutoff(),
maxdim=default_maxdim(),
kwargs...,
)
patches = Dict{Projector,Vector{Union{SubDomainMPS,LazyContraction}}}()
Expand Down
12 changes: 6 additions & 6 deletions src/contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ function projcontract(
M2::SubDomainMPS,
proj::Projector;
alg="fit",
cutoff=1e-25,
maxdim=typemax(Int),
cutoff=default_cutoff(),
maxdim=default_maxdim(),
verbosity=0,
kwargs...,
)::Union{Nothing,SubDomainMPS}
Expand Down Expand Up @@ -92,8 +92,8 @@ function projcontract(
proj::Projector;
alg="fit",
alg_sum="fit",
cutoff=1e-25,
maxdim=typemax(Int),
cutoff=default_cutoff(),
maxdim=default_maxdim(),
patchorder=Index[],
kwargs...,
)::Union{Nothing,Vector{SubDomainMPS}}
Expand Down Expand Up @@ -140,8 +140,8 @@ function contract(
M1::PartitionedMPS,
M2::PartitionedMPS;
alg="fit",
cutoff=1e-25,
maxdim=typemax(Int),
cutoff=default_cutoff(),
maxdim=default_maxdim(),
patchorder=Index[],
kwargs...,
)::Union{PartitionedMPS}
Expand Down
48 changes: 44 additions & 4 deletions src/partitionedmps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,18 +131,54 @@ function Base.:-(obj::PartitionedMPS)::PartitionedMPS
return -1 * obj
end

function truncate(obj::PartitionedMPS; kwargs...)::PartitionedMPS
return PartitionedMPS([truncate(v; kwargs...) for v in values(obj)])
"""
Truncate a PartitionedMPS object piecewise.
Each SubDomainMPS in the PartitionedMPS is truncated independently,
but the cutoff is adjusted according to the norm of each SubDomainMPS.
The total error is the sum of the errors in each SubDomainMPS.
"""
function truncate(
obj::PartitionedMPS;
cutoff=default_cutoff(),
maxdim=default_maxdim(),
use_adaptive_weight=true,
kwargs...,
)::PartitionedMPS
norm2 = [LinearAlgebra.norm(v)^2 for v in values(obj)]
total_norm2 = sum(norm2)
weights = [total_norm2 / norm2_v for norm2_v in norm2] # Initial weights (FIXME: better choice?)

compressed = obj

while true
compressed = PartitionedMPS([
truncate(v; cutoff=cutoff * w, maxdim, kwargs...) for
(v, w) in zip(values(obj), weights)
])
actual_error = dist(obj, compressed)^2 / sum(norm2)
if actual_error < cutoff || !use_adaptive_weight
break
end

weights .*= cutoff / actual_error # Adjust weights
end

return compressed
end

# Only for debug
function ITensorMPS.MPS(obj::PartitionedMPS; cutoff=1e-25, maxdim=typemax(Int))::MPS
function ITensorMPS.MPS(
obj::PartitionedMPS; cutoff=default_cutoff(), maxdim=default_maxdim()
)::MPS
return reduce(
(x, y) -> truncate(+(x, y; alg="directsum"); cutoff, maxdim), values(obj.data)
).data # direct sum
end

function ITensorMPS.MPO(obj::PartitionedMPS; cutoff=1e-25, maxdim=typemax(Int))::MPO
function ITensorMPS.MPO(
obj::PartitionedMPS; cutoff=default_cutoff(), maxdim=default_maxdim()
)::MPO
return MPO(collect(MPS(obj; cutoff=cutoff, maxdim=maxdim, kwargs...)))
end

Expand All @@ -168,3 +204,7 @@ where `s` must have a prime level of 0.
function extractdiagonal(obj::PartitionedMPS, site)
return PartitionedMPS([extractdiagonal(prjmps, site) for prjmps in values(obj)])
end

function dist(a::PartitionedMPS, b::PartitionedMPS)
return sqrt(sum(ITensorMPS.dist(MPS(a[k]), MPS(b[k]))^2 for k in keys(a)))
end
6 changes: 6 additions & 0 deletions src/subdomainmps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@ function project(
return project(projΨ, Projector(projector))
end

function project(
Ψ::AbstractMPS, projector::Dict{InsT,Int}
)::Union{Nothing,SubDomainMPS} where {InsT}
return project(SubDomainMPS(Ψ), Projector(projector))
end

function _iscompatible(projector::Projector, tensor::ITensor)
# Lazy implementation
return ITensors.norm(project(tensor, projector) - tensor) == 0.0
Expand Down
27 changes: 26 additions & 1 deletion test/partitionedmps_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@ using Test

using ITensors
using ITensorMPS
using Random

import PartitionedMPSs: Projector, project, SubDomainMPS, PartitionedMPS
import PartitionedMPSs: PartitionedMPSs, Projector, project, SubDomainMPS, PartitionedMPS

@testset "partitionedmps.jl" begin
@testset "two blocks" begin
Random.seed!(1234)
N = 3
sitesx = [Index(2, "x=$n") for n in 1:N]
sitesy = [Index(2, "y=$n") for n in 1:N]
Expand Down Expand Up @@ -35,6 +37,7 @@ import PartitionedMPSs: Projector, project, SubDomainMPS, PartitionedMPS
end

@testset "two blocks (general key)" begin
Random.seed!(1234)
N = 3
sitesx = [Index(2, "x=$n") for n in 1:N]
sitesy = [Index(2, "y=$n") for n in 1:N]
Expand All @@ -56,4 +59,26 @@ import PartitionedMPSs: Projector, project, SubDomainMPS, PartitionedMPS
@test MPS((a + b) + 2 * (b + a)) 3 * Ψ rtol = 1e-13
@test MPS((a + b) + 2 * (b + a)) 3 * Ψ rtol = 1e-13
end

@testset "truncate" begin
for seed in [1, 2, 3, 4, 5]
Random.seed!(seed)
N = 10
D = 10 # Bond dimension
d = 10 # local dimension
cutoff_global = 1e-4

sites = [[Index(d, "n=$n")] for n in 1:N]

Ψ = 100 * MPS(collect(_random_mpo(sites; linkdims=D)))

partmps = PartitionedMPS([project(Ψ, Dict(sites[1][1] => d_)) for d_ in 1:d])
partmps_truncated = PartitionedMPSs.truncate(partmps; cutoff=cutoff_global)

diff =
ITensorMPS.dist(MPS(partmps_truncated), MPS(partmps))^2 /
ITensorMPS.norm(MPS(partmps))^2
@test diff < cutoff_global
end
end
end

0 comments on commit e024840

Please sign in to comment.