
Commit 63527db

Migrated RichardsonLucy slicing and message passing to DataIFWithParallel

Three instances (all pertaining to saving results) remain in RichardsonLucy class.
avalluvan committed Dec 15, 2024
1 parent 860b5ce commit 63527db
Showing 2 changed files with 144 additions and 202 deletions.
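For context, here is a minimal sketch of the driving pattern this commit moves toward: the data-interface object owns the MPI communicator and returns fully reduced products, so the deconvolution loop reads as if it were serial. Everything beyond the RichardsonLucy signature visible in this diff (in particular the DataIFWithParallel constructor) is an assumption.

```python
# Sketch only: DataIFWithParallel's constructor and loading API are
# assumptions; only RichardsonLucy's signature appears in this diff.
from mpi4py import MPI

comm = MPI.COMM_WORLD

# Every rank loads its own slice of the data and builds the same initial
# model; rank 0 (MASTER) is the only process that registers and saves results.
# dataset = [DataIFWithParallel(event, bkg_models, response, comm=comm)]
# rl = RichardsonLucy(initial_model, dataset, mask, parameter, comm=comm)
# rl.run_deconvolution()
```

Launched as, e.g., `mpiexec -n 4 python run_deconvolution.py`.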
232 changes: 40 additions & 192 deletions cosipy/image_deconvolution/RichardsonLucy.py
@@ -1,5 +1,4 @@
import os
from pathlib import Path
import copy
import logging

@@ -11,14 +10,10 @@
import pandas as pd
import astropy.units as u
from astropy.io import fits
from mpi4py import MPI
from histpy import Histogram, HealpixAxis, Axes
from histpy import Histogram

from .deconvolution_algorithm_base import DeconvolutionAlgorithmBase

# Define MPI variable
MASTER = 0 # Indicates master process

class RichardsonLucy(DeconvolutionAlgorithmBase):
"""
A class for the RichardsonLucy algorithm.
@@ -75,33 +70,23 @@ def __init__(self, initial_model:Histogram, dataset:list, mask, parameter, comm=
else:
os.makedirs(self.save_results_directory)

# Specific to parallel implementation:
# 1. Assume numproc is known by the process that invoked `run_deconvolution()`
# 2. All processes have loaded event data, background, and created
# initial model (from model properties) independently
self.parallel = False
if comm is not None:
self.comm = comm
if self.comm.Get_size() > 1:
self.parallel = True
logger.info('Image Deconvolution set to run in parallel mode')
if comm is None:
self.MASTER = True
elif comm.Get_rank() == 0:
self.MASTER = True
else:
self.MASTER = False
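The three-way branch above collapses to a single boolean expression; an equivalent sketch:

```python
# Equivalent to the if/elif/else above: a serial run, or rank 0, is MASTER.
self.MASTER = (comm is None) or (comm.Get_rank() == 0)
```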

def initialization(self):
"""
Initialization before performing image deconvolution.
"""

# Parallel
if self.parallel:
self.numtasks = self.comm.Get_size()
self.taskid = self.comm.Get_rank()

# Master
if (not self.parallel) or (self.parallel and (self.taskid == MASTER)):
if self.MASTER:
# Clear results
self.results.clear()

# All
# clear counter
self.iteration_count = 0

@@ -111,20 +96,6 @@ def initialization(self):
# calculate exposure map
self.summed_exposure_map = self.calc_summed_exposure_map()

# Parallel
if self.parallel:
'''
Synchronization Barrier 0 (performed only once)
'''
total_exposure_map = np.empty_like(self.summed_exposure_map, dtype=np.float64)

# Gather all arrays into recvbuf
self.comm.Allreduce(self.summed_exposure_map.contents, total_exposure_map, op=MPI.SUM) # For multiple MPI processes, full = [slice1, ... sliceN]

# Reshape the received buffer back into the original array shape
self.summed_exposure_map[:] = total_exposure_map
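The synchronization barrier removed above is a plain element-wise sum over ranks. As a standalone sketch of that Allreduce pattern (presumably what DataIFWithParallel now performs internally; the helper name is an assumption):

```python
import numpy as np
from mpi4py import MPI

def allreduce_sum(comm, local_array):
    """Sum a float64 array across all ranks; every rank receives the total."""
    total = np.empty_like(local_array, dtype=np.float64)
    comm.Allreduce(local_array, total, op=MPI.SUM)
    return total
```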

# All
# mask setting
if self.mask is None and np.any(self.summed_exposure_map.contents == 0):
self.mask = Histogram(self.model.axes, contents = self.summed_exposure_map.contents > 0)
@@ -153,56 +124,10 @@ def Estep(self):
E-step (but it will be skipped).
"""

# All
# expected count histograms
expectation_list_slice = self.calc_expectation_list(model = self.model, dict_bkg_norm = self.dict_bkg_norm)
self.expectation_list = self.calc_expectation_list(model = self.model, dict_bkg_norm = self.dict_bkg_norm)
logger.info("The expected count histograms were calculated with the initial model map.")

# Serial
if not self.parallel:
self.expectation_list = expectation_list_slice # If single process, then full = slice

# Parallel
elif self.parallel:
'''
Synchronization Barrier 1
'''

self.expectation_list = []
for data, epsilon_slice in zip(self.dataset, expectation_list_slice):
# Gather the sizes of local arrays from all processes
local_size = np.array([epsilon_slice.contents.size], dtype=np.int32)
all_sizes = np.zeros(self.numtasks, dtype=np.int32)
self.comm.Allgather(local_size, all_sizes)

# Calculate displacements
displacements = np.insert(np.cumsum(all_sizes), 0, 0)[0:-1]

# Create a buffer to receive the gathered data
total_size = int(np.sum(all_sizes))
recvbuf = np.empty(total_size, dtype=np.float64) # Receive buffer

# Gather all arrays into recvbuf
self.comm.Allgatherv(epsilon_slice.contents.flatten(), [recvbuf, all_sizes, displacements, MPI.DOUBLE]) # For multiple MPI processes, full = [slice1, ... sliceN]

# Reshape the received buffer back into the original 3D array shape
epsilon = np.concatenate([ recvbuf[displacements[i]:displacements[i] + all_sizes[i]].reshape((-1,) + epsilon_slice.contents.shape[1:]) for i in range(self.numtasks) ], axis=-1)

# Create Histogram that will be appended to self.expectation_list
axes = []
for axis in data.event.axes:
if axis.label == 'PsiChi':
axes.append(HealpixAxis(edges = axis.edges,
label = axis.label,
scale = axis._scale,
coordsys = axis._coordsys,
nside = axis.nside))
else:
axes.append(axis)

# Add to list that manages multiple datasets
self.expectation_list.append(Histogram(Axes(axes), contents=epsilon, unit=data.event.unit)) # TODO: Could maybe be simplified using Histogram.slice[]

# At the end of this function, all processes should have a complete `self.expectation_list`
# to proceed to the Mstep function
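The Estep barrier removed above gathers variable-size slices from every rank onto every rank. Distilled into a standalone helper (a sketch; it assumes float64 contents and that slices are concatenated along a single axis):

```python
import numpy as np
from mpi4py import MPI

def allgatherv_flat(comm, local_array):
    """Gather variable-size float64 slices from all ranks onto every rank."""
    numtasks = comm.Get_size()
    # Exchange slice sizes so each rank can size the receive buffer.
    local_size = np.array([local_array.size], dtype=np.int32)
    all_sizes = np.zeros(numtasks, dtype=np.int32)
    comm.Allgather(local_size, all_sizes)
    # Element offsets of each rank's slice in the flat receive buffer.
    displacements = np.insert(np.cumsum(all_sizes), 0, 0)[:-1]
    recvbuf = np.empty(int(np.sum(all_sizes)), dtype=np.float64)
    comm.Allgatherv(local_array.flatten(),
                    [recvbuf, all_sizes, displacements, MPI.DOUBLE])
    return recvbuf, all_sizes, displacements
```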

@@ -211,65 +136,15 @@ def Mstep(self):
M-step in RL algorithm.
"""

# All
ratio_list = [ data.event / expectation for data, expectation in zip(self.dataset, self.expectation_list) ]

# delta model
C_slice = self.calc_summed_T_product(ratio_list)

# Serial
if not self.parallel:
sum_T_product = C_slice

# Parallel
elif self.parallel:
'''
Synchronization Barrier 2
'''

# Gather the sizes of local arrays from all processes
local_size = np.array([C_slice.contents.size], dtype=np.int32)
all_sizes = np.zeros(self.numtasks, dtype=np.int32)
self.comm.Allgather(local_size, all_sizes)

# Calculate displacements
displacements = np.insert(np.cumsum(all_sizes), 0, 0)[0:-1]

# Create a buffer to receive the gathered data
total_size = int(np.sum(all_sizes))
recvbuf = np.empty(total_size, dtype=np.float64) # Receive buffer

# Gather all arrays into recvbuf
self.comm.Gatherv(C_slice.contents.value.flatten(), [recvbuf, all_sizes, displacements, MPI.DOUBLE]) # For multiple MPI processes, full = [slice1, ... sliceN]

# Master
if self.taskid == MASTER:
# Reshape the received buffer back into the original 2D array shape
C = np.concatenate([ recvbuf[displacements[i]:displacements[i] + all_sizes[i]].reshape((-1,) + C_slice.contents.shape[1:]) for i in range(self.numtasks) ], axis=0)

# Create Histogram object for sum_T_product
axes = []
for axis in self.model.axes:
if axis.label == 'lb':
axes.append(HealpixAxis(edges = axis.edges,
label = axis.label,
scale = axis._scale,
coordsys = axis._coordsys,
nside = axis.nside))
else:
axes.append(axis)

# C_slice (the slice operated on by the current node) --> sum_T_product (all slices)
sum_T_product = Histogram(Axes(axes), contents=C, unit=C_slice.unit) # TODO: Could maybe be simplified using Histogram.slice[]
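The Mstep differs only in that the full array is needed on MASTER alone, hence Gatherv rather than Allgatherv. A sketch of that root-only variant (same assumptions as the helper above):

```python
import numpy as np
from mpi4py import MPI

def gatherv_to_master(comm, local_array, all_sizes, displacements, root=0):
    """Gather variable-size float64 slices onto the root rank only."""
    if comm.Get_rank() == root:
        recvbuf = np.empty(int(np.sum(all_sizes)), dtype=np.float64)
        comm.Gatherv(local_array.flatten(),
                     [recvbuf, all_sizes, displacements, MPI.DOUBLE],
                     root=root)
        return recvbuf
    comm.Gatherv(local_array.flatten(), None, root=root)
    return None  # non-root ranks receive nothing
```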

# Master
if (not self.parallel) or ((self.parallel) and (self.taskid == MASTER)):
self.delta_model = self.model * (sum_T_product/self.summed_exposure_map - 1)
sum_T_product = self.calc_summed_T_product(ratio_list)
self.delta_model = self.model * (sum_T_product/self.summed_exposure_map - 1)

if self.mask is not None:
self.delta_model = self.delta_model.mask_pixels(self.mask)
if self.mask is not None:
self.delta_model = self.delta_model.mask_pixels(self.mask)

# All
# background normalization optimization
if self.do_bkg_norm_optimization:
for key in self.dict_bkg_norm.keys():
@@ -286,11 +161,10 @@ def Mstep(self):

self.dict_bkg_norm[key] = bkg_norm

# Alternatively, let the MASTER node calculate it and broadcast the value
# self.comm.bcast(self.dict_bkg_norm[key], root=MASTER) # This synchronization barrier is not required during the final iteration

# At the end of this function, just the MASTER MPI process needs to have a full
# copy of delta_model
# At the end of this function, all the nodes will have a full
# copy of delta_model. Although this is not strictly necessary, it is
# the simplest approach that avoids further edits to RichardsonLucy.py.
# The same applies to self.dict_bkg_norm

def post_processing(self):
"""
@@ -300,57 +174,33 @@ def post_processing(self):
- acceleration of the RL algorithm: the normalization of the delta map is increased as long as the updated image has no negative components.
"""

# Master
if (not self.parallel) or ((self.parallel) and (self.taskid == MASTER)):
self.processed_delta_model = copy.deepcopy(self.delta_model)
# All
self.processed_delta_model = copy.deepcopy(self.delta_model)

if self.do_response_weighting:
self.processed_delta_model[:] *= self.response_weighting_filter
if self.do_response_weighting:
self.processed_delta_model[:] *= self.response_weighting_filter

if self.do_smoothing:
self.processed_delta_model = self.processed_delta_model.smoothing(fwhm = self.smoothing_fwhm)

if self.do_acceleration:
self.alpha = self.calc_alpha(self.processed_delta_model, self.model)
else:
self.alpha = 1.0

self.model = self.model + self.processed_delta_model * self.alpha
self.model[:] = np.where(self.model.contents < self.minimum_flux, self.minimum_flux, self.model.contents)

if self.mask is not None:
self.model = self.model.mask_pixels(self.mask)

# update loglikelihood_list
self.loglikelihood_list = self.calc_loglikelihood_list(self.expectation_list)
logger.debug("The loglikelihood list was updated with the new expected count histograms.")

# Parallel
if self.parallel:
'''
Synchronization Barrier 3
'''
# Initialize new variable as MPI only sends values and not units
if self.taskid == MASTER:
buffer = self.model.contents.value
else:
buffer = np.empty(self.model.contents.shape, dtype=np.float64)
if self.do_smoothing:
self.processed_delta_model = self.processed_delta_model.smoothing(fwhm = self.smoothing_fwhm)

if self.do_acceleration:
self.alpha = self.calc_alpha(self.processed_delta_model, self.model)
else:
self.alpha = 1.0

self.comm.Bcast([buffer, MPI.DOUBLE], root=MASTER) # This synchronization barrier is not required during the final iteration
self.model = self.model + self.processed_delta_model * self.alpha
self.model[:] = np.where(self.model.contents < self.minimum_flux, self.minimum_flux, self.model.contents)

if self.taskid > MASTER:
# Reconstruct ModelBase object for self.model
new_model = self.model.__class__(nside = self.model.axes['lb'].nside, # self.model.__class__ returns the class of which `model` is an instance
energy_edges = self.model.axes['Ei'].edges,
scheme = self.model.axes['lb']._scheme,
coordsys = self.model.axes['lb'].coordsys,
unit = self.model.unit)

new_model[:] = buffer * self.model.unit
self.model = new_model
if self.mask is not None:
self.model = self.model.mask_pixels(self.mask)

# update loglikelihood_list
self.loglikelihood_list = self.calc_loglikelihood_list(self.expectation_list)
logger.debug("The loglikelihood list was updated with the new expected count histograms.")

# At the end of this function, all MPI processes need to have a full
# copy of the updated model. Each process computes it independently from
# delta_model (which is distributed via MPI.Bcast).

def register_result(self):
"""
@@ -365,7 +215,7 @@ def register_result(self):
"""

# Master
if (not self.parallel) or ((self.parallel) and (self.taskid == MASTER)):
if self.MASTER:
this_result = {"iteration": self.iteration_count,
"model": copy.deepcopy(self.model),
"delta_model": copy.deepcopy(self.delta_model),
@@ -391,7 +241,6 @@ def check_stopping_criteria(self):
bool
"""

# All
if self.iteration_count < self.iteration_max:
return False
return True
@@ -402,7 +251,7 @@ def finalization(self):
"""

# Master
if (not self.parallel) or ((self.parallel) and (self.taskid == MASTER)):
if self.MASTER:
if self.save_results == True:
logger.info(f'Saving results in {self.save_results_directory}')

@@ -446,7 +295,6 @@ def calc_alpha(self, delta_model, model):
Acceleration parameter
"""

# Master: Only invoked by master process
diff = -1 * (model / delta_model).contents

diff[(diff <= 0) | (delta_model.contents == 0)] = np.inf
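The hunk is truncated here; the remaining lines presumably select the largest step that keeps every model component positive. A sketch of that final step (the cap attribute, named alpha_max here, is an assumption):

```python
# diff now holds, per component, the step size at which that component
# would reach zero; np.inf marks components that can never go negative.
alpha = min(np.min(diff), self.alpha_max)  # largest safe step, capped
alpha = max(alpha, 1.0)                    # never damp below plain RL
return alpha
```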
