From 2252879d42cccd6e7927446df47f5082809af0b4 Mon Sep 17 00:00:00 2001 From: Jerome Jackson Date: Thu, 18 Jul 2024 17:04:49 +0100 Subject: [PATCH] provide a kpoint distribution function --- src/io.F90 | 2 +- src/library_interface.F90 | 59 +++++++++++++++++++++++++++++++++++++++ src/wannier_prog.F90 | 13 ++------- 3 files changed, 63 insertions(+), 11 deletions(-) diff --git a/src/io.F90 b/src/io.F90 index 713674b5..f445eb4a 100644 --- a/src/io.F90 +++ b/src/io.F90 @@ -437,7 +437,7 @@ subroutine prterr(error, ie, istdout, istderr, comm) write (istderr, *) 'Exiting.......' write (istderr, '(1x,a)') trim(mesg) write (istderr, '(1x,a,i0,a)') '(rank: ', failrank, ')' - write (istderr, '(1x,a)') ' error encountered; check .wout log' + !write (istderr, '(1x,a)') 'error encountered; check .wout log' else ! non 0 ranks je = error%code diff --git a/src/library_interface.F90 b/src/library_interface.F90 index 8dece34c..0c83b561 100644 --- a/src/library_interface.F90 +++ b/src/library_interface.F90 @@ -181,6 +181,8 @@ module w90_library ! this is called by get_nnkp and get_gkpb public :: w90_disentangle !! perform disentanglement + public :: w90_distribute_kpts + !! provides an MPI k-point distribution for codes that don't have one public :: w90_get_centres !! get wannier centers public :: w90_get_fortran_file @@ -1302,4 +1304,61 @@ subroutine w90_set_option_real(common_data, keyword, rval) call expand_settings(common_data%settings) endif endsubroutine w90_set_option_real + + subroutine w90_distribute_kpts(common_data, num_kpts, mpi_size, dist_k, istdout, istderr, ierr) + !! provide a distribution of num_kpts k-points across mpi_size MPI ranks + ! should be called from all ranks in a parallel environment for error propagation + use w90_comms, only: comms_sync_error + use w90_error_base, only: w90_error_type + use w90_error, only: set_error_fatal + + implicit none + + ! arguments + integer, intent(in) :: num_kpts + !! number of k-points + integer, intent(in) :: mpi_size + !! number of ranks in MPI communicator + integer, intent(in) :: istdout, istderr + !! destination for error messages + integer, intent(inout), allocatable :: dist_k(:) + !! already allocated array + !! assigned here such that dist_k(i) = rank handling kpt i + !! size and allocation status are tested + integer, intent(out) :: ierr + !! return code, nonzero in case of error + type(lib_common_type), intent(in) :: common_data + !! library object: only the communicator type is referenced + + ! local variables + type(w90_error_type), allocatable :: error + integer :: ctr, i, nkl + + ierr = 0 + + if (mpi_size < 1) then + call set_error_fatal(error, 'Error: mpi_size < 1 in w90_distribute_kpts call.', common_data%comm) + elseif (num_kpts < 1) then + call set_error_fatal(error, 'Error: num_kpts < 1 in w90_distribute_kpts call.', common_data%comm) + elseif (.not. allocated(dist_k)) then + call set_error_fatal(error, 'Error: dist_k not allocated in w90_distribute_kpts call.', common_data%comm) + elseif (size(dist_k) < num_kpts) then + call set_error_fatal(error, 'Error: size(dist_k) < num_kpts in w90_distribute_kpts call.', common_data%comm) + endif + if (allocated(error)) then + call prterr(error, ierr, istdout, istderr, common_data%comm) + return + endif + + ctr = 0 + do i = 0, mpi_size - 1 + nkl = num_kpts/mpi_size ! number of kpoints per rank + if (mod(num_kpts, mpi_size) > i) nkl = nkl + 1 + if (nkl > 0) then + dist_k(ctr + 1:ctr + nkl) = i + ctr = ctr + nkl + endif + enddo + end subroutine w90_distribute_kpts + end module w90_library diff --git a/src/wannier_prog.F90 b/src/wannier_prog.F90 index 3aabc263..77af85e5 100644 --- a/src/wannier_prog.F90 +++ b/src/wannier_prog.F90 @@ -161,16 +161,9 @@ program wannier write (stderr, *) 'Wannier90: failed to allocate dist_k array!' stop endif - - ctr = 0 - do i = 0, mpisize - 1 - nkl = nk/mpisize ! number of kpoints per rank - if (mod(nk, mpisize) > i) nkl = nkl + 1 - if (nkl > 0) then - dist_k(ctr + 1:ctr + nkl) = i - ctr = ctr + nkl - endif - enddo + ! get a basic k-point/rank distribution + call w90_distribute_kpts(common_data, nk, mpisize, dist_k, stdout, stderr, ierr) + if (ierr /= 0) stop ! copy distribution to library call set_kpoint_distribution(common_data, dist_k, stdout, stderr, ierr)