Skip to content

Commit

Permalink
Initial tests of serial python wrappers with changes to library
Browse files Browse the repository at this point in the history
  • Loading branch information
sstgfbc committed Jan 25, 2024
1 parent 63fc650 commit a86803d
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 49 deletions.
39 changes: 15 additions & 24 deletions wrap/mpi-example.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,13 @@
ftn_error = wan90.w90_library.get_fortran_stderr()

data = wan90.w90_library.lib_common_type()
w90data = wan90.w90_library.lib_wannier_type()
#w90data = wan90.w90_library.lib_wannier_type()

wan90.w90_library.set_parallel_comms(data, MPI.COMM_WORLD.py2f())
#data.comm.comm = MPI.COMM_WORLD.py2f()

status = wan90.w90_library.input_reader(data, w90data, "diamond", ftn_output, ftn_error)
#status = wan90.w90_library.input_reader(data, w90data, "cnt55", ftn_output, ftn_error)

exit
status = wan90.w90_library.input_reader_special(data, "diamond", ftn_output, ftn_error)
status = wan90.w90_library.input_reader(data, ftn_output, ftn_error)

if not data.kmesh_info.explicit_nnkpts :
status = wan90.w90_library.create_kmesh(data, ftn_output, ftn_error)
Expand Down Expand Up @@ -69,41 +67,34 @@
#m_matrix = numpy.zeros((data.num_wann, data.num_wann, data.kmesh_info.nntot, data.num_kpts), dtype=numpy.cdouble, order='F')
u_matrix = numpy.zeros((data.num_wann, data.num_wann, data.num_kpts), dtype=numpy.cdouble, order='F')

#wan90.w90_library.set_m_matrix(w90data, m_matrix)
#wan90.w90_library.set_m_matrix(m_matrix)
wan90.w90_library.set_u_matrix(data, u_matrix)

m_matrix_loc = numpy.zeros((data.num_wann, data.num_wann, data.kmesh_info.nntot, counts[my_proc]), dtype=numpy.cdouble, order='F')
wan90.w90_library.set_m_matrix_local(w90data, m_matrix_loc)
m_matrix = numpy.zeros((data.num_bands, data.num_bands, data.kmesh_info.nntot, counts[my_proc]), dtype=numpy.cdouble, order='F')
wan90.w90_library.set_m_orig(data, m_matrix)
u_opt = numpy.zeros((data.num_bands, data.num_wann, data.num_kpts), dtype=numpy.cdouble, order='F')
wan90.w90_library.set_u_opt(data, u_opt)
status = wan90.w90_library.overlaps(data, ftn_output, ftn_error)

if data.num_wann == data.num_bands:
u_opt = numpy.zeros((1, 1, 1), dtype=numpy.cdouble, order='F')
wan90.w90_library.set_u_opt(data, u_opt)
status = wan90.w90_library.overlaps(data, w90data, ftn_output, ftn_error)
status = wan90.w90_library.projovlp(data, w90data, ftn_output, ftn_error)
status = wan90.w90_library.projovlp(data, ftn_output, ftn_error)
else:
a_matrix = numpy.zeros((data.num_bands, data.num_wann, data.num_kpts), dtype=numpy.cdouble, order='F')
wan90.w90_library.set_a_matrix(w90data, a_matrix)
u_opt = numpy.zeros((data.num_bands, data.num_wann, data.num_kpts), dtype=numpy.cdouble, order='F')
wan90.w90_library.set_u_opt(data, u_opt)
m_orig = numpy.zeros((data.num_bands, data.num_bands, data.kmesh_info.nntot, data.num_kpts), dtype=numpy.cdouble, order='F')
wan90.w90_library.set_m_orig(w90data, m_orig)
status = wan90.w90_library.overlaps(data, w90data, ftn_output, ftn_error)
eigval = numpy.zeros((data.num_bands, data.num_kpts), dtype=numpy.double, order='F')
status = wan90.w90_library.read_eigvals(data, eigval, ftn_output, ftn_error)
wan90.w90_library.set_eigval(data, eigval)
status = wan90.w90_library.disentangle(data, w90data, ftn_output, ftn_error)
status = wan90.w90_library.disentangle(data, ftn_output, ftn_error)
if status == 1:
exit

status = wan90.w90_library.wannierise(data, w90data, ftn_output, ftn_error)
status = wan90.w90_library.wannierise(data, ftn_output, ftn_error)

#wan90.w90_library.checkpoint(data, w90data, "postwann", ftn_output, ftn_error)
#wan90.w90_library.checkpoint(data, "postwann", ftn_output, ftn_error)

#wan90.w90_library.transport(helper, w90data, ftn_output, ftn_error, status)
#wan90.w90_library.transport(helper, ftn_output, ftn_error, status)

#print (data.num_wann)

#wan90.w90_library.plot_files(data, w90data, ftn_output, ftn_error, status)
#wan90.w90_library.plot_files(data, ftn_output, ftn_error, status)

if my_proc == 0:
wan90.w90_library.print_times(data, ftn_output)
38 changes: 13 additions & 25 deletions wrap/serial-example.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
ftn_error = wan90.w90_library.get_fortran_stderr()

data = wan90.w90_library.lib_common_type()
w90data = wan90.w90_library.lib_wannier_type()
status = wan90.w90_library.input_reader(data, w90data, "diamond", ftn_output, ftn_error)
status = wan90.w90_library.input_reader_special(data, "diamond", ftn_output, ftn_error)
status = wan90.w90_library.input_reader(data, ftn_output, ftn_error)

if not data.kmesh_info.explicit_nnkpts :
status = wan90.w90_library.create_kmesh(data, ftn_output, ftn_error)
Expand All @@ -22,41 +22,29 @@
#displs = numpy.zeros(1, dtype=numpy.int32)
#wan90.w90_library.set_kpoint_block(data, counts, displs)

#m_matrix = numpy.zeros((data.num_wann, data.num_wann, data.kmesh_info.nntot, data.num_kpts), dtype=numpy.cdouble, order='F')
m_matrix = numpy.zeros((data.num_bands, data.num_bands, data.kmesh_info.nntot, data.num_kpts), dtype=numpy.cdouble, order='F')
wan90.w90_library.set_m_orig(data, m_matrix)
u_opt = numpy.zeros((data.num_bands, data.num_wann, data.num_kpts), dtype=numpy.cdouble, order='F')
wan90.w90_library.set_u_opt(data, u_opt)
u_matrix = numpy.zeros((data.num_wann, data.num_wann, data.num_kpts), dtype=numpy.cdouble, order='F')
#wan90.w90_library.set_m_matrix(w90data, m_matrix)
wan90.w90_library.set_u_matrix(data, u_matrix)
m_matrix_loc = numpy.zeros((data.num_wann, data.num_wann, data.kmesh_info.nntot, data.num_kpts), dtype=numpy.cdouble, order='F')
wan90.w90_library.set_m_matrix_local(w90data, m_matrix_loc)

#m_matrix.flags.f_contiguous should be true
status = wan90.w90_library.overlaps(data, ftn_output, ftn_error)

if data.num_wann == data.num_bands:
u_opt = numpy.zeros((1, 1, 1), dtype=numpy.cdouble, order='F')
wan90.w90_library.set_u_opt(data, u_opt)
status = wan90.w90_library.overlaps(data, w90data, ftn_output, ftn_error)
status = wan90.w90_library.projovlp(data, w90data, ftn_output, ftn_error)
status = wan90.w90_library.projovlp(data, ftn_output, ftn_error)
else:
a_matrix = numpy.zeros((data.num_bands, data.num_wann, data.num_kpts), dtype=numpy.cdouble, order='F')
wan90.w90_library.set_a_matrix(w90data, a_matrix)
u_opt = numpy.zeros((data.num_bands, data.num_wann, data.num_kpts), dtype=numpy.cdouble, order='F')
wan90.w90_library.set_u_opt(data, u_opt)
m_orig = numpy.zeros((data.num_bands, data.num_bands, data.kmesh_info.nntot, data.num_kpts), dtype=numpy.cdouble, order='F')
wan90.w90_library.set_m_orig(w90data, m_orig)
status = wan90.w90_library.overlaps(data, w90data, ftn_output, ftn_error)
# allocate problem here (and seedname, and ierr not out) + set_eigval() "cnt55"
eigval = numpy.zeros((data.num_bands, data.num_kpts), dtype=numpy.double, order='F')
status = wan90.w90_library.read_eigvals(data, eigval, ftn_output, ftn_error)
wan90.w90_library.set_eigval(data, eigval)
status = wan90.w90_library.disentangle(data, w90data, ftn_output, ftn_error)

status = wan90.w90_library.wannierise(data, w90data, ftn_output, ftn_error)
status = wan90.w90_library.disentangle(data, ftn_output, ftn_error)

status = wan90.w90_library.wannierise(data, ftn_output, ftn_error)

#status = wan90.w90_library.checkpoint(data, w90data, "postwann", ftn_output, ftn_error)
#status = wan90.w90_library.checkpoint(data, "postwann", ftn_output, ftn_error)

status = wan90.w90_library.plot_files(data, w90data, ftn_output, ftn_error)
status = wan90.w90_library.plot_files(data, ftn_output, ftn_error)

#status = wan90.w90_library.transport(data, w90data, ftn_output, ftn_error)
#status = wan90.w90_library.transport(data, ftn_output, ftn_error)

wan90.w90_library.print_times(data, ftn_output)

0 comments on commit a86803d

Please sign in to comment.