Skip to content

Commit

Permalink
parallelize write_r2mn; move write_rmn from hamiltonian.F90 to plot.F90
Browse files Browse the repository at this point in the history
  • Loading branch information
Jerome Jackson authored and JeromeCCP9 committed Jan 14, 2025
1 parent 8c40af3 commit bbffae2
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 127 deletions.
102 changes: 0 additions & 102 deletions src/hamiltonian.F90
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ module w90_hamiltonian
public :: hamiltonian_get_hr
public :: hamiltonian_setup
public :: hamiltonian_write_hr
public :: hamiltonian_write_rmn
public :: hamiltonian_write_tb

contains
Expand Down Expand Up @@ -852,107 +851,6 @@ subroutine hamiltonian_wigner_seitz(ws_region, print_output, real_lattice, irvec

end subroutine hamiltonian_wigner_seitz

!================================================!
subroutine hamiltonian_write_rmn(kmesh_info, m_matrix, kpt_latt, irvec, nrpts, num_kpts, &
num_wann, seedname, dist_k, error, comm)
!================================================!
!
!! Write out the matrix elements of r
!
!================================================!

use w90_comms, only: comms_reduce, w90_comm_type, mpisize, mpirank
use w90_constants, only: twopi, cmplx_i
use w90_io, only: io_date
use w90_types, only: kmesh_info_type

implicit none

! arguments
type(kmesh_info_type), intent(in) :: kmesh_info
type(w90_error_type), allocatable, intent(out) :: error
type(w90_comm_type), intent(in) :: comm

integer, intent(inout) :: nrpts
integer, intent(inout) :: irvec(:, :)
integer, intent(in) :: num_wann
integer, intent(in) :: num_kpts
integer, intent(in) :: dist_k(:) ! MPI k-point distribution
real(kind=dp), intent(in) :: kpt_latt(:, :)
complex(kind=dp), intent(in) :: m_matrix(:, :, :, :)
character(len=50), intent(in) :: seedname

! local variables
integer :: loop_rpt, m, n, nkp, ind, nn, file_unit, ierr
integer :: num_nodes, my_node_id, nkp_rank
! nkp_rank is the rank-local kpoint index for m_matrix decomposition
real(kind=dp) :: rdotk
complex(kind=dp) :: fac
complex(kind=dp) :: position(3)
character(len=33) :: header
character(len=9) :: cdate, ctime
logical :: on_root = .false.

num_nodes = mpisize(comm)
my_node_id = mpirank(comm)

if (my_node_id == 0) on_root = .true.

if (on_root) then
open (newunit=file_unit, file=trim(seedname)//'_r.dat', form='formatted', status='unknown', &
iostat=ierr)
if (ierr /= 0) then
call set_error_file(error, 'Error: hamiltonian_write_rmn: problem opening file '//trim(seedname)//'_r', comm)
return
endif

call io_date(cdate, ctime)
header = 'written on '//cdate//' at '//ctime
write (file_unit, *) header ! Date and time
write (file_unit, *) num_wann
write (file_unit, *) nrpts
endif

do loop_rpt = 1, nrpts
do m = 1, num_wann
do n = 1, num_wann

position(:) = 0._dp
nkp_rank = 1
do nkp = 1, num_kpts
if (dist_k(nkp) /= my_node_id) cycle

rdotk = twopi*dot_product(kpt_latt(:, nkp), real(irvec(:, loop_rpt), dp))
fac = exp(-cmplx_i*rdotk)/real(num_kpts, dp)
do ind = 1, 3
do nn = 1, kmesh_info%nntot
if (m .eq. n) then
! For loop_rpt==rpt_origin, this reduces to
! Eq.(32) of Marzari and Vanderbilt PRB 56,
! 12847 (1997). Otherwise, is is Eq.(44)
! Wang, Yates, Souza and Vanderbilt PRB 74,
! 195118 (2006), modified according to
! Eqs.(27,29) of Marzari and Vanderbilt
position(ind) = position(ind) - kmesh_info%wb(nn)*kmesh_info%bk(ind, nn, nkp) &
*aimag(log(m_matrix(n, m, nn, nkp_rank)))*fac
else
! Eq.(44) Wang, Yates, Souza and Vanderbilt PRB 74, 195118 (2006)
position(ind) = position(ind) + cmplx_i*kmesh_info%wb(nn) &
*kmesh_info%bk(ind, nn, nkp)*m_matrix(n, m, nn, nkp_rank)*fac
endif
end do
end do
nkp_rank = nkp_rank + 1
end do ! global k list
call comms_reduce(position(1), 3, 'SUM', error, comm)
if (on_root) write (file_unit, '(5I5,6F12.6)') irvec(:, loop_rpt), n, m, position(:)
end do
end do
end do

if (on_root) close (file_unit)
end subroutine hamiltonian_write_rmn

!================================================!
subroutine hamiltonian_write_tb(ham_logical, kmesh_info, ham_r, m_matrix, kpt_latt, &
real_lattice, irvec, ndegen, nrpts, num_kpts, num_wann, &
Expand Down
164 changes: 139 additions & 25 deletions src/plot.F90
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ subroutine plot_main(atom_data, band_plot, dis_manifold, fermi_energy_list, ferm

use w90_constants, only: eps6, dp
use w90_hamiltonian, only: hamiltonian_get_hr, hamiltonian_write_hr, hamiltonian_setup, &
hamiltonian_write_rmn, hamiltonian_write_tb
hamiltonian_write_tb
use w90_io, only: io_stopwatch_start, io_stopwatch_stop
use w90_types, only: kmesh_info_type, wannier_data_type, atom_data_type, dis_manifold_type, &
kpoint_path_type, print_output_type, ws_region_type, ws_distance_type, timer_list_type, &
Expand Down Expand Up @@ -174,7 +174,7 @@ subroutine plot_main(atom_data, band_plot, dis_manifold, fermi_energy_list, ferm
! Print the header only if there is something to plot
if (w90_calculation%bands_plot .or. w90_calculation%fermi_surface_plot .or. &
output_file%write_hr .or. w90_calculation%wannier_plot .or. output_file%write_u_matrices &
.or. output_file%write_tb .or. output_file%write_rmn) then
.or. output_file%write_tb .or. output_file%write_rmn .or. output_file%write_r2mn) then
write (stdout, '(1x,a)') '*---------------------------------------------------------------------------*'
write (stdout, '(1x,a)') '| PLOTTING |'
write (stdout, '(1x,a)') '*---------------------------------------------------------------------------*'
Expand All @@ -191,12 +191,6 @@ subroutine plot_main(atom_data, band_plot, dis_manifold, fermi_energy_list, ferm
! if (allocated(error)) return
! endif

! write matrix elements <m|r^2|n> to file
if (output_file%write_r2mn) then
call plot_write_r2mn(num_kpts, num_wann, kmesh_info, m_matrix, error, comm, seedname)
if (allocated(error)) return
endif

if (w90_calculation%fermi_surface_plot) then
call plot_fermi_surface(fermi_energy_list, recip_lattice, fermi_surface_plot, num_wann, &
ham_r, irvec, ndegen, nrpts, print_output%timing_level, stdout, &
Expand Down Expand Up @@ -280,8 +274,14 @@ subroutine plot_main(atom_data, band_plot, dis_manifold, fermi_energy_list, ferm

if (output_file%write_rmn) then
! parallel write_rmn
call hamiltonian_write_rmn(kmesh_info, m_matrix, kpt_latt, irvec, nrpts, num_kpts, &
num_wann, seedname, dist_k, error, comm)
call plot_write_rmn(kmesh_info, m_matrix, kpt_latt, irvec, nrpts, num_kpts, num_wann, &
seedname, dist_k, error, comm)
if (allocated(error)) return
endif

if (output_file%write_r2mn) then
! write matrix elements <m|r^2|n> to file
call plot_write_r2mn(num_kpts, num_wann, kmesh_info, m_matrix, seedname, dist_k, error, comm)
if (allocated(error)) return
endif

Expand Down Expand Up @@ -2380,7 +2380,109 @@ subroutine plot_bvec(kmesh_info, num_kpts, seedname, error, comm)
end subroutine plot_bvec

!================================================!
subroutine plot_write_r2mn(num_kpts, num_wann, kmesh_info, m_matrix, error, comm, seedname)
subroutine plot_write_rmn(kmesh_info, m_matrix, kpt_latt, irvec, nrpts, num_kpts, &
num_wann, seedname, dist_k, error, comm)
!================================================!
!
!! Write out the matrix elements of r
!
!================================================!

use w90_comms, only: comms_reduce, w90_comm_type, mpisize, mpirank
use w90_constants, only: twopi, cmplx_i, dp
use w90_error, only: w90_error_type, set_error_file
use w90_io, only: io_date
use w90_types, only: kmesh_info_type

implicit none

! arguments
type(kmesh_info_type), intent(in) :: kmesh_info
type(w90_error_type), allocatable, intent(out) :: error
type(w90_comm_type), intent(in) :: comm

integer, intent(inout) :: nrpts
integer, intent(inout) :: irvec(:, :)
integer, intent(in) :: num_wann
integer, intent(in) :: num_kpts
integer, intent(in) :: dist_k(:) ! MPI k-point distribution
real(kind=dp), intent(in) :: kpt_latt(:, :)
complex(kind=dp), intent(in) :: m_matrix(:, :, :, :)
character(len=50), intent(in) :: seedname

! local variables
integer :: loop_rpt, m, n, nkp, ind, nn, file_unit, ierr
integer :: num_nodes, my_node_id, nkp_rank
! nkp_rank is the rank-local kpoint index for m_matrix decomposition
real(kind=dp) :: rdotk
complex(kind=dp) :: fac
complex(kind=dp) :: position(3)
character(len=33) :: header
character(len=9) :: cdate, ctime
logical :: on_root = .false.

num_nodes = mpisize(comm)
my_node_id = mpirank(comm)

if (my_node_id == 0) on_root = .true.

if (on_root) then
open (newunit=file_unit, file=trim(seedname)//'_r.dat', form='formatted', status='unknown', &
iostat=ierr)
if (ierr /= 0) then
call set_error_file(error, 'Error: hamiltonian_write_rmn: problem opening file '//trim(seedname)//'_r', comm)
return
endif

call io_date(cdate, ctime)
header = 'written on '//cdate//' at '//ctime
write (file_unit, *) header ! Date and time
write (file_unit, *) num_wann
write (file_unit, *) nrpts
endif

do loop_rpt = 1, nrpts
do m = 1, num_wann
do n = 1, num_wann

position(:) = 0._dp
nkp_rank = 1
do nkp = 1, num_kpts
if (dist_k(nkp) /= my_node_id) cycle

rdotk = twopi*dot_product(kpt_latt(:, nkp), real(irvec(:, loop_rpt), dp))
fac = exp(-cmplx_i*rdotk)/real(num_kpts, dp)
do ind = 1, 3
do nn = 1, kmesh_info%nntot
if (m .eq. n) then
! For loop_rpt==rpt_origin, this reduces to
! Eq.(32) of Marzari and Vanderbilt PRB 56,
! 12847 (1997). Otherwise, is is Eq.(44)
! Wang, Yates, Souza and Vanderbilt PRB 74,
! 195118 (2006), modified according to
! Eqs.(27,29) of Marzari and Vanderbilt
position(ind) = position(ind) - kmesh_info%wb(nn)*kmesh_info%bk(ind, nn, nkp) &
*aimag(log(m_matrix(n, m, nn, nkp_rank)))*fac
else
! Eq.(44) Wang, Yates, Souza and Vanderbilt PRB 74, 195118 (2006)
position(ind) = position(ind) + cmplx_i*kmesh_info%wb(nn) &
*kmesh_info%bk(ind, nn, nkp)*m_matrix(n, m, nn, nkp_rank)*fac
endif
end do
end do
nkp_rank = nkp_rank + 1
end do ! global k list
call comms_reduce(position(1), 3, 'SUM', error, comm)
if (on_root) write (file_unit, '(5I5,6F12.6)') irvec(:, loop_rpt), n, m, position(:)
end do
end do
end do

if (on_root) close (file_unit)
end subroutine plot_write_rmn

!================================================!
subroutine plot_write_r2mn(num_kpts, num_wann, kmesh_info, m_matrix, seedname, dist_k, error, comm)
!================================================!
!
! Write seedname.r2mn file
Expand All @@ -2399,41 +2501,53 @@ subroutine plot_write_r2mn(num_kpts, num_wann, kmesh_info, m_matrix, error, comm
type(w90_comm_type), intent(in) :: comm

integer, intent(in) :: num_kpts, num_wann
integer, intent(in) :: dist_k(:) ! MPI k-point distribution
complex(kind=dp), intent(in) :: m_matrix(:, :, :, :)
character(len=50), intent(in) :: seedname

integer :: r2mnunit, nw1, nw2, nkp, nn, ierr
integer :: nkp_rank, my_node_id
real(kind=dp) :: r2ave_mn, delta
logical :: on_root = .false.

! note that here I use formulas analogue to Eq. 23, and not to the
! shift-invariant Eq. 32 .
open (newunit=r2mnunit, file=trim(seedname)//'.r2mn', form='formatted', iostat=ierr)
if (ierr /= 0) then
call set_error_file(error, 'Error opening file '//trim(seedname)//'.r2mn in plot_write_r2mn', comm)
return

my_node_id = mpirank(comm)

if (my_node_id == 0) on_root = .true.

if (on_root) then
open (newunit=r2mnunit, file=trim(seedname)//'.r2mn', form='formatted', iostat=ierr)
if (ierr /= 0) then
call set_error_file(error, 'Error opening file '//trim(seedname)//'.r2mn in plot_write_r2mn', comm)
return
endif
endif

do nw1 = 1, num_wann
do nw2 = 1, num_wann
r2ave_mn = 0.0_dp
delta = 0.0_dp
if (nw1 .eq. nw2) delta = 1.0_dp
nkp_rank = 1
do nkp = 1, num_kpts
if (dist_k(nkp) /= my_node_id) cycle

do nn = 1, kmesh_info%nntot
! [GP-begin, Apr13, 2012: corrected sign inside "real"]
r2ave_mn = r2ave_mn + kmesh_info%wb(nn)* &
! [GP-begin, Apr13, 2012: corrected sign inside "real"]
(2.0_dp*delta - real(m_matrix(nw1, nw2, nn, nkp) + &
conjg(m_matrix(nw2, nw1, nn, nkp)), kind=dp))
! [GP-end]
(2.0_dp*delta - real(m_matrix(nw1, nw2, nn, nkp_rank) + &
conjg(m_matrix(nw2, nw1, nn, nkp_rank)), dp))
enddo
enddo
nkp_rank = nkp_rank + 1
enddo ! global k list
call comms_reduce(r2ave_mn, 1, 'SUM', error, comm)
r2ave_mn = r2ave_mn/real(num_kpts, dp)
write (r2mnunit, '(2i6,f20.12)') nw1, nw2, r2ave_mn
if (on_root) write (r2mnunit, '(2i6,f20.12)') nw1, nw2, r2ave_mn
enddo
enddo
close (r2mnunit)

return

if (on_root) close (r2mnunit)
end subroutine plot_write_r2mn

!================================================!
Expand Down

0 comments on commit bbffae2

Please sign in to comment.