Skip to content

Commit

Permalink
use matrix names more consistently
Browse files Browse the repository at this point in the history
  • Loading branch information
Jerome Jackson committed Jan 15, 2025
1 parent dcf11bf commit 24bdff4
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 113 deletions.
2 changes: 1 addition & 1 deletion src/io.F90
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ subroutine prterr(error, ie, istdout, istderr, comm)
mesg = 'not set'

if (mpirank(comm) == 0) then
! fixme, report all failing ranks instead of lowest failing rank (current stand)
! currently this printout will list only the lowest failing rank, not all failing ranks
do j = mpisize(comm) - 1, 1, -1
call comms_no_sync_recv(je, 1, j, le, comm)

Expand Down
110 changes: 40 additions & 70 deletions src/library_extra.F90
Original file line number Diff line number Diff line change
Expand Up @@ -118,33 +118,38 @@ subroutine input_reader_special(common_data, seedname, istdout, istderr, ierr)
if (common_data%num_bands > common_data%num_wann) then
allocate (common_data%dis_manifold%ndimwin(common_data%num_kpts), stat=ierr)
if (ierr /= 0) then
call set_error_alloc(error, 'Error allocating ndimwin in input_reader_special() call', common_data%comm)
call set_error_alloc(error, &
'Error allocating ndimwin in input_reader_special() call', common_data%comm)
call prterr(error, ierr, istdout, istderr, common_data%comm)
return
endif
allocate (common_data%dis_manifold%nfirstwin(common_data%num_kpts), stat=ierr)
if (ierr /= 0) then
call set_error_alloc(error, 'Error allocating nfirstwin in input_reader_special() call', common_data%comm)
call set_error_alloc(error, &
'Error allocating nfirstwin in input_reader_special() call', common_data%comm)
call prterr(error, ierr, istdout, istderr, common_data%comm)
return
endif
allocate (common_data%dis_manifold%lwindow(common_data%num_bands, common_data%num_kpts), stat=ierr)
if (ierr /= 0) then
call set_error_alloc(error, 'Error allocating lwindow in input_reader_special() call', common_data%comm)
call set_error_alloc(error, &
'Error allocating lwindow in input_reader_special() call', common_data%comm)
call prterr(error, ierr, istdout, istderr, common_data%comm)
return
endif
endif
allocate (common_data%wannier_data%centres(3, common_data%num_wann), stat=ierr)
if (ierr /= 0) then
call set_error_alloc(error, 'Error allocating wannier_centres in input_reader_special() call', common_data%comm)
call set_error_alloc(error, &
'Error allocating wannier_centres in input_reader_special() call', common_data%comm)
call prterr(error, ierr, istdout, istderr, common_data%comm)
return
endif
common_data%wannier_data%centres = 0.0_dp
allocate (common_data%wannier_data%spreads(common_data%num_wann), stat=ierr)
if (ierr /= 0) then
call set_error_alloc(error, 'Error in allocating wannier_spreads in input_reader_special() call', common_data%comm)
call set_error_alloc(error, &
'Error in allocating wannier_spreads in input_reader_special() call', common_data%comm)
call prterr(error, ierr, istdout, istderr, common_data%comm)
return
endif
Expand Down Expand Up @@ -179,8 +184,10 @@ subroutine write_kmesh(common_data, istdout, istderr, ierr)

ierr = 0

if (mpirank(common_data%comm) == 0) then
if (.not. allocated(common_data%kmesh_info%nnlist)) call w90_create_kmesh(common_data, istdout, istderr, ierr)
if (mpirank(common_data%comm) == 0) then ! root only
if (.not. allocated(common_data%kmesh_info%nnlist)) then
call w90_create_kmesh(common_data, istdout, istderr, ierr)
endif

call kmesh_write(common_data%exclude_bands, common_data%kmesh_info, &
common_data%select_proj%auto_projections, common_data%proj_input, &
Expand Down Expand Up @@ -214,14 +221,13 @@ subroutine overlaps(common_data, istdout, istderr, ierr)
ierr = 0

if (.not. common_data%setup_complete) then
call set_error_fatal(error, 'kmesh is not setup before calling overlaps read routine', common_data%comm)
call set_error_fatal(error, 'Error: kmesh is not setup before reading overlap matrix', common_data%comm)
call prterr(error, ierr, istdout, istderr, common_data%comm)
return
if (ierr > 0) return
endif

! projections are stored in u_opt
call overlap_read(common_data%kmesh_info, common_data%select_proj, common_data%u_opt, &
call overlap_read(common_data%kmesh_info, common_data%select_proj, common_data%u_matrix_opt, &
common_data%m_matrix_local, common_data%num_bands, common_data%num_kpts, &
common_data%num_proj, common_data%num_wann, common_data%print_output, &
common_data%print_output%timing_level, cp_pp, common_data%use_bloch_phases, &
Expand All @@ -231,33 +237,6 @@ subroutine overlaps(common_data, istdout, istderr, ierr)
call prterr(error, ierr, istdout, istderr, common_data%comm)
return
endif

! Check Mmn(k,b) is symmetric in m and n for gamma_only case
! if (gamma_only) call overlap_check_m_symmetry()
!
! If we don't need to disentangle we can now convert from A to U
! And rotate M accordingly
! Jan 2023, Jerome Jackson, moved overlap_project outside of overlap_read
! if ((.not. disentanglement) .and. (.not. cp_pp) .and. (.not. use_bloch_phases)) then
! if (.not. gamma_only) then
! call overlap_project(sitesym, m_matrix_local, au_matrix, kmesh_info%nnlist, &
! kmesh_info%nntot, num_bands, num_kpts, num_wann, timing_level, &
! lsitesymmetry, stdout, timer, dist_k, error, comm)
! else
! call overlap_project_gamma(m_matrix_local, au_matrix, kmesh_info%nntot, num_wann, &
! timing_level, stdout, timer, error, comm)
! endif
! if (allocated(error)) return
! endif
!
!~[aam]
!~ if( gamma_only .and. use_bloch_phases ) then
!~ write(stdout,'(1x,"+",76("-"),"+")')
!~ write(stdout,'(3x,a)') 'WARNING: gamma_only and use_bloch_phases '
!~ write(stdout,'(3x,a)') ' M must be calculated from *real* Bloch functions'
!~ write(stdout,'(1x,"+",76("-"),"+")')
!~ end if
![ysl-e]
end subroutine overlaps

subroutine read_eigvals(common_data, eigval, istdout, istderr, ierr)
Expand All @@ -279,11 +258,13 @@ subroutine read_eigvals(common_data, eigval, istdout, istderr, ierr)
ierr = 0

if (size(eigval, 1) /= common_data%num_bands) then
call set_error_fatal(error, 'eigval not dimensioned correctly (num_bands,num_kpts) in read_eigvals', common_data%comm)
call set_error_fatal(error, &
'Error: eigval not dimensioned correctly (num_bands,num_kpts) in read_eigvals', common_data%comm)
call prterr(error, ierr, istdout, istderr, common_data%comm)
return
elseif (size(eigval, 2) /= common_data%num_kpts) then
call set_error_fatal(error, 'eigval not dimensioned correctly (num_bands,num_kpts) in read_eigvals', common_data%comm)
call set_error_fatal(error, &
'Error: eigval not dimensioned correctly (num_bands,num_kpts) in read_eigvals', common_data%comm)
call prterr(error, ierr, istdout, istderr, common_data%comm)
return
endif
Expand All @@ -294,7 +275,8 @@ subroutine read_eigvals(common_data, eigval, istdout, istderr, ierr)
call prterr(error, ierr, istdout, istderr, common_data%comm)
return
else if (.not. eig_found) then
call set_error_fatal(error, 'failed to read eigenvalues file in read_eigvals call', common_data%comm)
call set_error_fatal(error, &
'Error: failed to read eigenvalues file in read_eigvals', common_data%comm)
call prterr(error, ierr, istdout, istderr, common_data%comm)
return
endif
Expand All @@ -315,71 +297,69 @@ subroutine write_chkpt(common_data, label, istdout, istderr, ierr)
type(lib_common_type), target, intent(in) :: common_data

! local variables
complex(kind=dp), allocatable :: u(:, :, :), uopt(:, :, :), m(:, :, :, :)
complex(kind=dp), allocatable :: m(:, :, :, :)
integer, allocatable :: global_k(:)
integer, pointer :: nw, nb, nk, nn
integer :: rank, nkrank, ikg, ikl, istat
type(w90_error_type), allocatable :: error

ierr = 0
rank = mpirank(common_data%comm)
nkrank = count(common_data%dist_kpoints == rank)

nb => common_data%num_bands
nk => common_data%num_kpts
nn => common_data%kmesh_info%nntot
nw => common_data%num_wann

if (.not. associated(common_data%u_opt)) then
call set_error_fatal(error, 'u_opt not set for write_chkpt call', common_data%comm)
if (.not. associated(common_data%u_matrix_opt)) then
call set_error_fatal(error, &
'Error: u_matrix_opt not associated for write_chkpt call', common_data%comm)
else if (.not. associated(common_data%u_matrix)) then
call set_error_fatal(error, 'u_matrix not set for write_chkpt call', common_data%comm)
call set_error_fatal(error, &
'Error: u_matrix not associated for write_chkpt call', common_data%comm)
else if (.not. associated(common_data%m_matrix_local)) then
call set_error_fatal(error, 'm_matrix_local not set for write_chkpt call', common_data%comm)
call set_error_fatal(error, &
'Error: m_matrix_local not set for write_chkpt call', common_data%comm)
endif
if (allocated(error)) then
call prterr(error, ierr, istdout, istderr, common_data%comm)
return
endif

nkrank = count(common_data%dist_kpoints == rank)
allocate (global_k(nkrank), stat=istat)
if (istat /= 0) then
call set_error_alloc(error, 'Error allocating global_k in write_chkpt', common_data%comm)
call prterr(error, ierr, istdout, istderr, common_data%comm)
return
endif
global_k = huge(1); ikl = 1

global_k = huge(1)
ikl = 1
do ikg = 1, nk
if (rank == common_data%dist_kpoints(ikg)) then
global_k(ikl) = ikg
ikl = ikl + 1
endif
enddo

! reassemble full m matrix by MPI reduction
!
! allocating and partially assigning the full matrix on all ranks and reducing is a terrible idea
! alternatively, allocate on root and use point-to-point
! or, if required only for checkpoint file writing, then use mpi-io (but needs to be ordered io, alas)
! or, even better, use parallel hdf5. JJ Nov 22
allocate (u(nw, nw, nk), stat=istat) ! all kpts
if (istat /= 0) call set_error_alloc(error, 'Error allocating u in write_chkpt', common_data%comm)
allocate (uopt(nb, nw, nk), stat=istat) ! all kpts
if (istat /= 0) call set_error_alloc(error, 'Error allocating uopt in write_chkpt', common_data%comm)
allocate (m(nw, nw, nn, nk), stat=istat) ! all kpts
if (istat /= 0) call set_error_alloc(error, 'Error allocating m in write_chkpt', common_data%comm)
if (allocated(error)) then
call prterr(error, ierr, istdout, istderr, common_data%comm)
return
endif

u(:, :, :) = common_data%u_matrix
uopt(:, :, :) = common_data%u_opt
m(:, :, :, :) = 0.d0

do ikl = 1, nkrank
ikg = global_k(ikl)
m(:, :, :, ikg) = common_data%m_matrix_local(1:nw, 1:nw, :, ikl)
enddo

call comms_reduce(m(1, 1, 1, 1), nw*nw*nn*nk, 'SUM', error, common_data%comm)
if (allocated(error)) then
call prterr(error, ierr, istdout, istderr, common_data%comm)
Expand All @@ -390,27 +370,18 @@ subroutine write_chkpt(common_data, label, istdout, istderr, ierr)
call w90_wannier90_readwrite_write_chkpt(label, common_data%exclude_bands, &
common_data%wannier_data, common_data%kmesh_info, &
common_data%kpt_latt, nk, common_data%dis_manifold, &
nb, nw, u, uopt, m, common_data%mp_grid, &
nb, nw, common_data%u_matrix, &
common_data%u_matrix_opt, m, common_data%mp_grid, &
common_data%real_lattice, &
common_data%omega%invariant, &
common_data%have_disentangled, &
common_data%print_output%iprint, istdout, &
common_data%seedname)
endif

deallocate (u, stat=istat)
if (istat /= 0) then
call set_error_dealloc(error, 'Error deallocating u in write_chkpt', common_data%comm)
endif
deallocate (uopt, stat=istat)
if (istat /= 0) then
call set_error_dealloc(error, 'Error deallocating uopt in write_chkpt', common_data%comm)
endif
deallocate (m, stat=istat)
if (istat /= 0) then
call set_error_dealloc(error, 'Error deallocating m in write_chkpt', common_data%comm)
endif
if (allocated(error)) then
call prterr(error, ierr, istdout, istderr, common_data%comm)
return
endif
Expand Down Expand Up @@ -449,7 +420,6 @@ subroutine read_chkpt(common_data, checkpoint, istdout, istderr, ierr)
! alternatively, allocate on root and use point-to-point
! or, if required only for checkpoint file writing, then use mpi-io (but needs to be ordered io, alas)
! or, even better, use parallel hdf5
! fixme, check allocation status of u, uopt?
allocate (m(nw, nw, nn, nk), stat=istat) ! all kpts
if (istat /= 0) then
call set_error_alloc(error, 'Error allocating m in read_chkpt', common_data%comm)
Expand All @@ -463,7 +433,7 @@ subroutine read_chkpt(common_data, checkpoint, istdout, istderr, ierr)
call w90_readwrite_read_chkpt(common_data%dis_manifold, common_data%exclude_bands, &
common_data%kmesh_info, common_data%kpt_latt, &
common_data%wannier_data, m, common_data%u_matrix, &
common_data%u_opt, common_data%real_lattice, &
common_data%u_matrix_opt, common_data%real_lattice, &
common_data%omega%invariant, common_data%mp_grid, nb, &
nexclude, nk, nw, checkpoint, common_data%have_disentangled, &
ispostw90, common_data%seedname, istdout, error, &
Expand All @@ -476,7 +446,7 @@ subroutine read_chkpt(common_data, checkpoint, istdout, istderr, ierr)

! scatter from m_matrix to m_matrix_local (cf overlap_read)
call w90_readwrite_chkpt_dist(common_data%dis_manifold, common_data%wannier_data, &
common_data%u_matrix, common_data%u_opt, m, &
common_data%u_matrix, common_data%u_matrix_opt, m, &
common_data%m_matrix_local, common_data%omega%invariant, &
nb, nk, nw, nn, checkpoint, common_data%have_disentangled, &
common_data%dist_kpoints, error, common_data%comm)
Expand Down
Loading

0 comments on commit 24bdff4

Please sign in to comment.