
Commit
Merge pull request #41 from numericalEFT/bugfix
fix MPI reweight bug and add its unit test
kunyuan authored Apr 10, 2023
2 parents 08b9c56 + 4c6fcba commit 1accc98
Showing 3 changed files with 54 additions and 23 deletions.
Project.toml: 2 changes (1 addition, 1 deletion)
@@ -1,7 +1,7 @@
name = "MCIntegration"
uuid = "ea1e2de9-7db7-4b42-91ee-0cd1bf6df167"
authors = ["Kun Chen", "Xiansheng Cai", "Pengcheng Hou"]
version = "0.3.4"
version = "0.3.5"

[deps]
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
src/main.jl: 24 changes (16 additions, 8 deletions)
@@ -37,7 +37,7 @@
- `block`: Number of blocks. Each block will be evaluated about neval/block times. Each block is assumed to be statistically independent and will be used to estimate the error.
In MPI mode, the blocks are distributed among the workers. If the number of workers N is larger than block, then block will be set to N.
- `print`: -2 to not print anything; -1 to print minimal information; 0 to print the iteration history at the end; >0 to print the MC configuration every `print` seconds and print the iteration history at the end.
- `gamma`: Learning rate of the reweight factor after each iteration. Note that alpha <=1, where alpha = 0 means no reweighting.
- `gamma`: Learning rate of the reweight factor after each iteration. Note that gamma <=1, where gamma = 0 means no reweighting.
- `adapt`: Whether to adapt the grid and the reweight factor.
- `debug`: Whether to print debug information (type instability, float overflow, etc.)
- `reweight_goal`: The expected distribution of visited times for each integrand after reweighting. If not set, all factors will be initialized to one. Only useful for the :mcmc solver.
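A minimal usage sketch (not part of this commit) of how the keywords documented above plug into `integrate`; it assumes the v0.3.x interface, the default solver, and an integrand of the form (x, c) -> ...:

using MCIntegration
# Integrate x^2 + y^2 over the unit square; the exact answer is 2/3.
X = Continuous(0.0, 1.0)
result = integrate((x, c) -> x[1]^2 + x[2]^2;
    var = X, dof = 2,   # two continuous variables drawn from X
    neval = 10^5,       # number of integrand evaluations
    block = 16,         # independent blocks used for the error estimate
    print = -1,         # minimal screen output
    gamma = 1.0,        # learning rate of the reweight factor
    adapt = true)       # adapt the grid and the reweight factor
println(result)         # prints the estimate and its standard error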
@@ -158,11 +158,8 @@ function integrate(integrand::Function;
# collect all statistics to summedConfig of the root worker
MPIreduceConfig!(summedConfig[1])


if MCUtility.mpi_master() # only the master process will output results, no matter parallel = :mpi or :thread or :serial
################### self-learning ##########################################
(solver == :mcmc || solver == :vegasmc) && doReweight!(summedConfig[1], gamma, reweight_goal)
end
######################## self-learning #########################################
(solver == :mcmc || solver == :vegasmc) && doReweightMPI!(summedConfig[1], gamma, reweight_goal, comm)

######################## synchronize between workers ##############################

@@ -304,8 +301,9 @@ function doReweight!(config, gamma, reweight_goal)
end
# println(config.visited)
# println(config.reweight)
if isnothing(reweight_goal) == false
config.reweight .*= reweight_goal
if !isnothing(reweight_goal) # Apply reweight_goal if provided
# config.reweight .*= reweight_goal
config.reweight .*= reweight_goal ./ sum(reweight_goal)
end
# renormalize all reweight to be (0.0, 1.0)
config.reweight ./= sum(config.reweight)
@@ -315,4 +313,14 @@ function doReweight!(config, gamma, reweight_goal)
# Check Eq. (19) of https://arxiv.org/pdf/2009.05112.pdf for more detail
# config.reweight = @. ((1 - config.reweight) / log(1 / config.reweight))^beta
# config.reweight ./= sum(config.reweight)
end

function doReweightMPI!(config::Configuration, gamma, reweight_goal::Union{Vector{Float64},Nothing}, comm::MPI.Comm)
if MCUtility.mpi_master()
# only the master process updates the reweight factors, no matter parallel = :mpi, :thread, or :serial
doReweight!(config, gamma, reweight_goal)
end
reweight_array = Vector{Float64}(config.reweight)
MPI.Bcast!(reweight_array, 0, comm)
config.reweight .= reweight_array
end
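The fix follows a master-computes-then-broadcast pattern: doReweight! runs only on the master rank, and MPI.Bcast! then copies the updated reweight factors to every rank, so the workers no longer keep stale values (presumably the MPI reweight bug named in the commit message). A standalone toy sketch of the same pattern, with made-up weights and only the MPI.jl calls that already appear in this diff, is:

using MPI
MPI.Init()
comm = MPI.COMM_WORLD
root = 0
weights = zeros(4)
if MPI.Comm_rank(comm) == root
    weights .= [1.0, 2.0, 3.0, 4.0]  # only the root computes the update
    weights ./= sum(weights)         # normalize on the root
end
MPI.Bcast!(weights, root, comm)      # every rank now holds the root's weights
println("rank $(MPI.Comm_rank(comm)): ", weights)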
test/mpi_test.jl: 51 changes (37 additions, 14 deletions)
@@ -10,7 +10,7 @@ const MCUtility = MCIntegration.MCUtility
rank = MPI.Comm_rank(comm) # rank of current MPI worker
root = 0 # rank of the root worker

a = [1, 2, 3]
a = [1, 2, 3]
aa = MCUtility.MPIreduce(a)
if rank == root
@test aa == [Nworker, 2Nworker, 3Nworker]
@@ -29,7 +29,7 @@
end

# inplace
a = [1, 2, 3]
a = [1, 2, 3]
MCUtility.MPIreduce!(a)
if rank == root
@test a == [Nworker, 2Nworker, 3Nworker]
@@ -43,7 +43,7 @@
rank = MPI.Comm_rank(comm) # rank of current MPI worker
root = 0 # rank of the root worker

a = [1, 2, 3] .* rank
a = [1, 2, 3] .* rank
aa = MCUtility.MPIbcast(a)
if rank != root
@test aa == [0, 0, 0]
@@ -62,7 +62,7 @@
end

# inplace
a = [1, 2, 3] .* rank
a = [1, 2, 3] .* rank
MCUtility.MPIbcast!(a)
if rank != root
@test a == [0, 0, 0]
@@ -85,7 +85,7 @@
Z.histogram[1] = 1.3
cvar = CompositeVar(Y, Z)
obs = [1.0,]
config = Configuration(var = (X, cvar), dof=[[1, 1], ], obs=obs)
config = Configuration(var=(X, cvar), dof=[[1, 1],], obs=obs)
config.neval = 101
config.normalization = 1.1
config.visited[1] = 1.2
@@ -95,16 +95,16 @@
MCIntegration.MPIreduceConfig!(config)
if rank == root
@test config.observable[1] == Nworker
@test config.neval == Nworker*101
@test config.normalization ≈ Nworker*1.1
@test config.visited[1] ≈ Nworker*1.2
@test config.propose[1, 1, 1] ≈ Nworker*1.3
@test config.accept[1, 1, 1] ≈ Nworker*1.4
@test config.neval == Nworker * 101
@test config.normalization ≈ Nworker * 1.1
@test config.visited[1] ≈ Nworker * 1.2
@test config.propose[1, 1, 1] ≈ Nworker * 1.3
@test config.accept[1, 1, 1] ≈ Nworker * 1.4

@test config.var[1].histogram[1] ≈ Nworker*1.1 # X
@test config.var[1].histogram[1] ≈ Nworker * 1.1 # X
cvar = config.var[2] #compositevar
@test cvar[1].histogram[1] ≈ Nworker *1.2 #Y
@test cvar[2].histogram[1] ≈ Nworker*1.3 #Z
@test cvar[1].histogram[1] ≈ Nworker * 1.2 #Y
@test cvar[2].histogram[1] ≈ Nworker * 1.3 #Z
end
end

@@ -130,7 +130,7 @@
Z.histogram[1] = rank
end

config = Configuration(var = (X, cvar), dof=[[1, 1], ])
config = Configuration(var=(X, cvar), dof=[[1, 1],])
config.reweight = [1.1, 1.2]

MCIntegration.MPIbcastConfig!(config)
@@ -143,4 +143,27 @@
@test cvar[1].histogram[1] ≈ 1.2 #Y
@test cvar[2].histogram[1] ≈ 1.3 #Z
end
end

@testset "MPI doReweight!" begin
(MPI.Initialized() == false) && MPI.Init()
comm = MPI.COMM_WORLD
Nworker = MPI.Comm_size(comm) # number of MPI workers
rank = MPI.Comm_rank(comm)
root = 0

X = Continuous(0.0, 1.0)
config = Configuration(var=(X,), dof=[[1], [1], [1]])
config.visited = [1, 2, 3, 4]
@test config.reweight == [0.25, 0.25, 0.25, 0.25]

gamma = 1.0
reweight_goal = [1.0, 2.0, 3.0, 4.0]
n_iterations = 5
expected_reweight = [0.25, 0.25, 0.25, 0.25]

for _ in 1:n_iterations
MCIntegration.doReweightMPI!(config, gamma, reweight_goal, comm)
end
@test all(isapprox.(config.reweight, expected_reweight, rtol=1e-3))
end

2 comments on commit 1accc98

@kunyuan (Member Author)

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/81326

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.3.5 -m "<description of version>" 1accc989ffa91e425f46e998918b92f37f958e5f
git push origin v0.3.5
