Skip to content

Commit

Permalink
final results weights scaling
Browse files Browse the repository at this point in the history
  • Loading branch information
lmiq committed Dec 1, 2023
1 parent 5cc5a7d commit 9819d19
Showing 1 changed file with 54 additions and 30 deletions.
84 changes: 54 additions & 30 deletions src/results.jl
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,24 @@ function sphereradiusfromshellvolume(volume, step)
return (0.5 * (volume / fourthirdsofpi + 2 * rmin^3))^(1 / 3)
end

#
# Function to compute the normalization of the weights, given the optional
# frame weights and the number of frames read, which is used in the case
# that frame weights were not provided
#
function sum_frame_weights(R::Result)
i = R.options.firstframe
s = R.options.stride
l = R.lastframe_read
Q = if !isempty(R.options.frame_weights)
sum(R.options.frame_weights[i] for i in i:s:l)
else
# number of frames that were red from the file
round(Int, (l - i + 1) / s)
end
return Q
end

"""
finalresults!(R::Result, options::Options, trajectory::Trajectory)
Expand All @@ -345,16 +363,19 @@ function finalresults!(R::Result, options::Options, trajectory::Trajectory)
R.d[i] = shellradius(i, options.binstep)
end

# Normalization of number of frames: sum of weights for all frames read
Q = sum_frame_weights(R)

# Scale counters by number of samples and frames
@. R.md_count = R.md_count / (R.solute.nmols * R.nframes_read)
@. R.solute_atom = R.solute_atom / (R.solute.nmols * R.nframes_read)
@. R.solvent_atom = R.solvent_atom / (R.solute.nmols * R.nframes_read)
@. R.md_count_random = R.md_count_random / (samples.random * R.nframes_read)
@. R.rdf_count = R.rdf_count / (R.solute.nmols * R.nframes_read)
@. R.rdf_count_random = R.rdf_count_random / (samples.random * R.nframes_read)
@. R.md_count = R.md_count / (R.solute.nmols * Q)
@. R.solute_atom = R.solute_atom / (R.solute.nmols * Q)
@. R.solvent_atom = R.solvent_atom / (R.solute.nmols * Q)
@. R.md_count_random = R.md_count_random / (samples.random * Q)
@. R.rdf_count = R.rdf_count / (R.solute.nmols * Q)
@. R.rdf_count_random = R.rdf_count_random / (samples.random * Q)

# Volume of each bin shell and of the solute domain
R.volume.total = R.volume.total / R.nframes_read
R.volume.total = R.volume.total / Q
@. R.volume.shell = R.volume.total * (R.rdf_count_random / samples.solvent_nmols)

# Solute domain volume
Expand Down Expand Up @@ -444,7 +465,6 @@ of the set provided weighted by the number of frames read in each Result set.
"""
function Base.merge(r::Vector{<:Result})

nr = length(r)
nframes_read = r[1].nframes_read
error = false
Expand All @@ -455,7 +475,7 @@ function Base.merge(r::Vector{<:Result})
"ERROR: To merge Results, the number of bins of the histograms of both sets must be the same.",
)
end
if (r[ir].cutoff - r[1].cutoff) > 1.e-8
if !(r[ir].cutoff r[1].cutoff)
println(
"ERROR: To merge Results, cutoff distance of the of the histograms of both sets must be the same.",
)
Expand All @@ -466,16 +486,16 @@ function Base.merge(r::Vector{<:Result})
end

# List of files and weights
nfiles = 0
for ir = 1:nr
nfiles += length(r[ir].files)
end
nfiles = sum(length(R.files) for R in r)
files = Vector{String}(undef, nfiles)
weights = Vector{Float64}(undef, nfiles)

# Final resuls
# First, merge the options
options = merge(getfield.(r, :options))

# Structure for merged results
R = Result(
options = r[1].options,
options = options,
nbins = r[1].nbins,
dbulk = r[1].dbulk,
cutoff = r[1].cutoff,
Expand All @@ -489,51 +509,42 @@ function Base.merge(r::Vector{<:Result})
weights = weights,
)

# Average results weighting the data considering the number of frames of each data set
# Total normalization factor: sum of the number of frame reads,
# or the sum of frame weights
Q = sum_frame_weights(R)

# Average results weighting the data considering the weights of the frames of each data set
@. R.d = r[1].d
ifile = 0
for ir = 1:nr

w = r[ir].nframes_read / nframes_read

w = sum_frame_weights(r[ir]) / Q
@. R.mddf += w * r[ir].mddf
@. R.kb += w * r[ir].kb

@. R.rdf += w * r[ir].rdf
@. R.kb_rdf += w * r[ir].kb_rdf

@. R.md_count += w * r[ir].md_count
@. R.md_count_random += w * r[ir].md_count_random

@. R.coordination_number += w * r[ir].coordination_number
@. R.coordination_number_random += w * r[ir].coordination_number_random

@. R.solute_atom += w * r[ir].solute_atom
@. R.solvent_atom += w * r[ir].solvent_atom

@. R.rdf_count += w * r[ir].rdf_count
@. R.rdf_count_random += w * r[ir].rdf_count_random

@. R.sum_rdf_count += w * r[ir].sum_rdf_count
@. R.sum_rdf_count_random += w * r[ir].sum_rdf_count_random

R.density.solute += w * r[ir].density.solute
R.density.solvent += w * r[ir].density.solvent
R.density.solvent_bulk += w * r[ir].density.solvent_bulk

R.volume.total += w * r[ir].volume.total
R.volume.bulk += w * r[ir].volume.bulk
R.volume.domain += w * r[ir].volume.domain
R.volume.shell += w * r[ir].volume.shell

for j = 1:length(r[ir].files)
ifile += 1
R.files[ifile] = normpath(r[ir].files[j])
R.weights[ifile] = w * r[ir].weights[j]
end

end

return R
end

Expand Down Expand Up @@ -609,6 +620,19 @@ end

R = merge([R1, R2])
@test isapprox(R, R_save, debug = true)

# Test merging files for which weights are provided for the frames
traj = Trajectory("$dir/trajectory.dcd", tmao, water)
options = Options(firstframe = 1, lastframe = 2, seed = 321, StableRNG = true, nthreads = 1, silent = true, frame_weights = fill(0.3, 20))
R1 = mddf(traj, options)

# First lets test the error message, in case the frame_weights are not provided for all frames
options = Options(firstframe = 1, lastframe = 2, seed = 321, StableRNG = true, nthreads = 1, silent = true)
R2 = mddf(traj, options)
@test_throws ArgumentError merge([R1, R2])



end

@testitem "Result - empty" begin
Expand Down

0 comments on commit 9819d19

Please sign in to comment.