Skip to content

Commit

Permalink
wip checkpoint reading
Browse files Browse the repository at this point in the history
  • Loading branch information
Jerome Jackson committed Nov 18, 2022
1 parent 5a3de35 commit c38e419
Show file tree
Hide file tree
Showing 4 changed files with 177 additions and 93 deletions.
128 changes: 95 additions & 33 deletions src/library_interface.F90
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,9 @@ module w90_helper_types
type(transport_type) :: tran
end type lib_w90_type

public:: checkpoint, create_kmesh, get_fortran_stdout, get_fortran_stderr, input_reader, &
overlaps, plot_files, print_times, transport, wannierise, write_kmesh
public:: create_kmesh, get_fortran_stdout, get_fortran_stderr, input_reader, &
overlaps, plot_files, print_times, transport, wannierise, write_kmesh, &
write_chkpt, read_chkpt

public :: set_option
interface set_option
Expand Down Expand Up @@ -209,6 +210,98 @@ subroutine set_option_text(string, text)
call update_settings(string, .false., text, 0.d0, 0)
endsubroutine set_option_text

subroutine write_chkpt(helper, wan90, label, seedname, output, outerr, status, comm)
use w90_wannier90_readwrite, only: w90_wannier90_readwrite_write_chkpt
use w90_comms, only: w90_comm_type, mpirank
implicit none

! arguments
character(len=*), intent(in) :: seedname
character(len=*), intent(in) :: label ! e.g. 'postdis' or 'postwann' after disentanglement, wannierisation
integer, intent(in) :: output, outerr
integer, intent(inout) :: status
type(lib_global_type), intent(inout) :: helper
type(lib_w90_type), intent(in) :: wan90
type(w90_comm_type), intent(in) :: comm

status = 0
if (.not. associated(helper%u_matrix)) then
write (*, *) 'u_matrix not set for write_chkpt call'
write (outerr, *) 'u_matrix not set for write_chkpt call'
status = 1
return
else if (.not. associated(helper%u_opt)) then
write (*, *) 'u_opt not set for write_chkpt call'
write (outerr, *) 'u_opt not set for write_chkpt call'
status = 1
return
else if (.not. associated(wan90%m_matrix)) then
write (*, *) 'm_matrix not set for write_chkpt call'
write (outerr, *) 'm_matrix not set for write_chkpt call'
status = 1
return
endif

if (mpirank(comm) == 0) then
call w90_wannier90_readwrite_write_chkpt(label, helper%exclude_bands, helper%wannier_data, &
helper%kmesh_info, helper%kpt_latt, &
helper%num_kpts, helper%dis_manifold, &
helper%num_bands, helper%num_wann, helper%u_matrix, &
helper%u_opt, wan90%m_matrix, helper%mp_grid, &
helper%real_lattice, wan90%omega%invariant, &
helper%have_disentangled, output, seedname)
endif
end subroutine write_chkpt

subroutine read_chkpt(helper, wan90, checkpoint, seedname, output, outerr, status, comm)
use w90_comms, only: w90_comm_type
use w90_error_base, only: w90_error_type
use w90_readwrite, only: w90_readwrite_read_chkpt_header, w90_readwrite_read_chkpt_matrices
implicit none

! arguments
character(len=*), intent(in) :: seedname
character(len=*), intent(out) :: checkpoint
integer, intent(in) :: output, outerr
integer, intent(out) :: status
type(lib_global_type), intent(inout) :: helper
type(lib_w90_type), intent(inout) :: wan90
type(w90_comm_type), intent(in) :: comm

! local variables
integer :: chk_unit
logical :: have_disentangled, ispostw90 = .false.
type(w90_error_type), allocatable :: error
real(dp) :: omega_invariant

status = 0

call w90_readwrite_read_chkpt_header(helper%exclude_bands, helper%kmesh_info, helper%kpt_latt, &
helper%real_lattice, helper%mp_grid, helper%num_bands, &
size(helper%exclude_bands), helper%num_kpts, &
helper%num_wann, checkpoint, have_disentangled, &
ispostw90, seedname, chk_unit, output, error, comm)
if (allocated(error)) call prterr(error, output, outerr, comm)

call w90_readwrite_read_chkpt_matrices(helper%dis_manifold, helper%kmesh_info, &
helper%wannier_data, wan90%m_matrix, helper%u_matrix, &
helper%u_opt, omega_invariant, helper%num_bands, &
helper%num_kpts, helper%num_wann, have_disentangled, &
seedname, chk_unit, output, error, comm)
if (allocated(error)) call prterr(error, output, outerr, comm)

! scatter from m_matrix_orig to m_matrix_orig_local
! normally achieved in overlap_read
! scatter from m_matrix to m_matrix_local
! w = num_wann*num_wann*kmesh_info%nntot
! call comms_scatterv(m_matrix_local, w*counts(my_node_id), m_matrix, w*counts, w*displs, error, comm)
! if (allocated(error)) call prterr(error, stdout, stderr, comm)

!call w90_readwrite_chkpt_dist(dis_manifold, wannier_data, u_matrix, u_matrix_opt, &
! omega%invariant, num_bands, num_kpts, num_wann, checkpoint, &
! have_disentangled, error, comm)
end subroutine read_chkpt

subroutine input_reader(helper, wan90, seedname, output, outerr, status, comm)
use w90_readwrite, only: w90_readwrite_in_file, w90_readwrite_uppercase, &
w90_readwrite_clean_infile, w90_readwrite_read_final_alloc
Expand Down Expand Up @@ -308,37 +401,6 @@ subroutine input_reader(helper, wan90, seedname, output, outerr, status, comm)

if (mpirank(comm) /= 0) helper%print_output%iprint = 0 ! supress printing non-rank-0
end subroutine input_reader

subroutine checkpoint(helper, wan90, label, output, comm)
use w90_wannier90_readwrite, only: w90_wannier90_readwrite_write_chkpt
use w90_comms, only: w90_comm_type, mpirank

! write_chkpt never fails? remarkable.
! either before or at start of write_chkpt, a reduction on m_matrix is necessary
! fixme JJ check

! fixme, seedname might rather prefer to be an argument

implicit none

character(len=*), intent(in) :: label
integer, intent(in) :: output
type(lib_global_type), intent(inout) :: helper
type(lib_w90_type), intent(in) :: wan90
type(w90_comm_type), intent(in) :: comm

if (mpirank(comm) == 0) then
! e.g. label = 'postwann' after wannierisation
call w90_wannier90_readwrite_write_chkpt(label, helper%exclude_bands, helper%wannier_data, &
helper%kmesh_info, helper%kpt_latt, &
helper%num_kpts, helper%dis_manifold, &
helper%num_bands, helper%num_wann, helper%u_matrix, &
helper%u_opt, wan90%m_matrix, helper%mp_grid, &
helper%real_lattice, wan90%omega%invariant, &
helper%have_disentangled, output, helper%seedname)
endif
end subroutine checkpoint

subroutine create_kmesh(helper, output, outerr, status, comm)
use w90_kmesh, only: kmesh_get
use w90_error_base, only: w90_error_type
Expand Down
63 changes: 16 additions & 47 deletions src/readwrite.F90
Original file line number Diff line number Diff line change
Expand Up @@ -1861,48 +1861,48 @@ subroutine w90_readwrite_read_chkpt(dis_manifold, exclude_bands, kmesh_info, kpt
!================================================!
!! Read checkpoint file
!! This is used to allocate the matrices.
!! If you have them already allocated, e.g. by numpy then call _core directly
!!
!! Note on parallelization: this function should be called
!! from the root node only!
!!
!================================================!

!use w90_constants, only: eps6
use w90_io, only: io_file_unit
use w90_error, only: w90_error_type, set_error_file, set_error_file, set_error_alloc
use w90_utility, only: utility_recip_lattice

implicit none

integer, allocatable, intent(inout) :: exclude_bands(:)
type(wannier_data_type), intent(inout) :: wannier_data
type(kmesh_info_type), intent(in) :: kmesh_info
real(kind=dp), intent(in) :: kpt_latt(:, :)
! arguments
type(dis_manifold_type), intent(inout) :: dis_manifold
type(w90_error_type), allocatable, intent(out) :: error
type(kmesh_info_type), intent(in) :: kmesh_info
type(w90_comm_type), intent(in) :: comm
type(w90_error_type), allocatable, intent(out) :: error
type(wannier_data_type), intent(inout) :: wannier_data

integer, intent(in) :: num_kpts
integer, allocatable, intent(inout) :: exclude_bands(:)
integer, intent(in) :: mp_grid(3)
integer, intent(in) :: num_bands
integer, intent(in) :: num_exclude_bands
integer, intent(in) :: num_kpts
integer, intent(in) :: num_wann
integer, intent(in) :: stdout
integer, intent(in) :: mp_grid(3)
integer, intent(in) :: num_exclude_bands

complex(kind=dp), allocatable, intent(inout) :: u_matrix(:, :, :)
complex(kind=dp), allocatable, intent(inout) :: u_matrix_opt(:, :, :)
complex(kind=dp), intent(inout) :: m_matrix(:, :, :, :)

real(kind=dp), intent(in) :: real_lattice(3, 3)
real(kind=dp), intent(in) :: kpt_latt(:, :)
real(kind=dp), intent(inout) :: omega_invariant
real(kind=dp), intent(in) :: real_lattice(3, 3)

character(len=*), intent(in) :: seedname
character(len=*), intent(inout) :: checkpoint
character(len=*), intent(in) :: seedname

logical, intent(in) :: ispostw90 ! Are we running postw90?
logical, intent(out) :: have_disentangled

! local variables
integer :: chk_unit

call w90_readwrite_read_chkpt_header(exclude_bands, kmesh_info, kpt_latt, real_lattice, &
Expand All @@ -1911,35 +1911,6 @@ subroutine w90_readwrite_read_chkpt(dis_manifold, exclude_bands, kmesh_info, kpt
seedname, chk_unit, stdout, error, comm)
if (allocated(error)) return

! if (have_disentangled) then
! ! U_matrix_opt
! if (.not. allocated(u_matrix_opt)) then
! allocate (u_matrix_opt(num_bands, num_wann, num_kpts), stat=ierr)
! if (ierr /= 0) then
! call set_error_alloc(error, 'Error allocating u_matrix_opt in w90_readwrite_read_chkpt', comm)
! return
! endif
! endif
! endif
!
! ! U_matrix
! if (.not. allocated(u_matrix)) then
! allocate (u_matrix(num_wann, num_wann, num_kpts), stat=ierr)
! if (ierr /= 0) then
! call set_error_alloc(error, 'Error allocating u_matrix in w90_readwrite_read_chkpt', comm)
! return
! endif
! endif
!
! M_matrix
! if (.not. allocated(m_matrix)) then
! allocate (m_matrix(num_wann, num_wann, kmesh_info%nntot, num_kpts), stat=ierr)
! if (ierr /= 0) then
! call set_error_alloc(error, 'Error allocating m_matrix in w90_readwrite_read_chkpt', comm)
! return
! endif
! endif

call w90_readwrite_read_chkpt_matrices(dis_manifold, kmesh_info, wannier_data, m_matrix, &
u_matrix, u_matrix_opt, omega_invariant, num_bands, &
num_kpts, num_wann, have_disentangled, seedname, &
Expand Down Expand Up @@ -2000,8 +1971,7 @@ subroutine w90_readwrite_read_chkpt_header(exclude_bands, kmesh_info, kpt_latt,

write (stdout, '(1x,3a)') 'Reading restart information from file ', trim(seedname), '.chk :'

chk_unit = io_file_unit()
open (unit=chk_unit, file=trim(seedname)//'.chk', status='old', form='unformatted', err=121)
open (newunit=chk_unit, file=trim(seedname)//'.chk', status='old', form='unformatted', err=121)
io_unit = chk_unit

! Read comment line
Expand Down Expand Up @@ -2110,18 +2080,17 @@ subroutine w90_readwrite_read_chkpt_matrices(dis_manifold, kmesh_info, wannier_d
!!
!================================================!

!use w90_constants, only: eps6
use w90_io, only: io_file_unit
use w90_error, only: w90_error_type, set_error_file, set_error_file, set_error_alloc
use w90_utility, only: utility_recip_lattice

implicit none

type(wannier_data_type), intent(inout) :: wannier_data
type(kmesh_info_type), intent(in) :: kmesh_info
type(dis_manifold_type), intent(inout) :: dis_manifold
type(w90_error_type), allocatable, intent(out) :: error
type(kmesh_info_type), intent(in) :: kmesh_info
type(w90_comm_type), intent(in) :: comm
type(w90_error_type), allocatable, intent(out) :: error
type(wannier_data_type), intent(inout) :: wannier_data

integer, intent(in) :: num_kpts
integer, intent(in) :: num_bands
Expand Down
72 changes: 63 additions & 9 deletions src/tiny-lib2-demo.F90
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ program libv2
implicit none

character(len=100) :: seedname
character(len=:), allocatable :: fn
character(len=:), allocatable :: fn, cpstatus
character(len=:), pointer :: restart
complex(kind=dp), allocatable :: a(:, :, :)
!complex(kind=dp), allocatable :: m(:,:,:,:)
complex(kind=dp), allocatable :: mloc(:, :, :, :)
Expand All @@ -30,12 +31,14 @@ program libv2
integer, pointer :: nb, nk, nw, nn
integer :: stdout, stderr
logical, pointer :: pp
logical :: lovlp, ldsnt, lwann, lplot, ltran
type(lib_global_type), target :: w90main
type(lib_w90_type), target :: w90dat
type(w90_comm_type) :: comm
type(w90_error_type), allocatable :: error

pp => w90dat%w90_calculation%postproc_setup
restart => w90dat%w90_calculation%restart
nw => w90main%num_wann
nb => w90main%num_bands
nk => w90main%num_kpts
Expand Down Expand Up @@ -136,18 +139,69 @@ program libv2
call set_u_matrix(w90main, u)
call set_u_opt(w90main, uopt)

call overlaps(w90main, w90dat, stdout, stderr, ierr, comm)
if (ierr /= 0) error stop
! restart system
lovlp = .true.
ldsnt = .true.
lwann = .true.
lplot = .true.
ltran = .true.

if (restart == '') then
if (rank == 0) write (stdout, '(1x,a/)') 'Starting a new Wannier90 calculation ...'
else
call read_chkpt(w90main, w90dat, cpstatus, fn, stdout, stderr, ierr, comm)
if (restart == 'wannierise' .or. (restart == 'default' .and. cpstatus == 'postdis')) then
if (rank == 0) write (stdout, '(1x,a/)') 'Restarting Wannier90 from wannierisation ...'
lovlp = .false.
ldsnt = .false.
lwann = .true.
lplot = .true.
ltran = .true.
elseif (restart == 'plot' .or. (restart == 'default' .and. cpstatus == 'postwann')) then
if (rank == 0) write (stdout, '(1x,a/)') 'Restarting Wannier90 from plotting routines ...'
lovlp = .false.
ldsnt = .false.
lwann = .false.
lplot = .true.
ltran = .true.
elseif (restart == 'transport') then
if (rank == 0) write (stdout, '(1x,a/)') 'Restarting Wannier90 from transport routines ...'
lovlp = .false.
ldsnt = .false.
lwann = .false.
lplot = .false.
ltran = .true.
!else
! illegitimate restart choice
endif
endif
! end restart system

if (nw < nb) then ! disentanglement reqired
call disentangle(w90main, w90dat, stdout, stderr, ierr, comm)
if (lovlp) then
call overlaps(w90main, w90dat, stdout, stderr, ierr, comm)
if (ierr /= 0) error stop
endif
call wannierise(w90main, w90dat, stdout, stderr, ierr, comm)
if (ierr /= 0) error stop

call plot_files(w90main, w90dat, stdout, stderr, ierr, comm)
if (ierr /= 0) error stop
if (ldsnt) then
if (nw < nb) then ! disentanglement reqired
call disentangle(w90main, w90dat, stdout, stderr, ierr, comm)
if (ierr /= 0) error stop
!call write_chkpt(w90main, w90dat, 'postdis', fn, stdout, stderr, ierr, comm)
!if (ierr /= 0) error stop
endif
endif

if (lwann) then
call wannierise(w90main, w90dat, stdout, stderr, ierr, comm)
if (ierr /= 0) error stop
!call write_chkpt(w90main, w90dat, 'postwann', fn, stdout, stderr, ierr, comm)
!if (ierr /= 0) error stop
endif

if (lplot) then
call plot_files(w90main, w90dat, stdout, stderr, ierr, comm)
if (ierr /= 0) error stop
endif

call print_times(w90main, stdout)
if (rank == 0) close (unit=stderr, status='delete')
Expand Down
Loading

0 comments on commit c38e419

Please sign in to comment.