Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] clearer trexio eri interface #105

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 25 additions & 7 deletions pyscf/tools/trexio.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from pyscf import scf
from pyscf import pbc
from pyscf import fci
from pyscf import ao2mo

import trexio

Expand Down Expand Up @@ -307,7 +308,8 @@ def scf_from_trexio(filename):
mf.mo_occ = mo_occ
return mf

def write_eri(eri, filename, backend='h5'):
def write_eri(eri, filename, backend='h5', basis='mo'):
assert basis.upper() in ['MO','AO']
num_integrals = eri.size
if eri.ndim == 4:
n = eri.shape[0]
Expand All @@ -330,15 +332,31 @@ def write_eri(eri, filename, backend='h5'):
idx = idx[np.tril_indices(npair)]

with trexio.File(filename, 'w', back_end=_mode(backend)) as tf:
trexio.write_mo_2e_int_eri(tf, 0, num_integrals, idx, eri.ravel())
if basis.upper() == 'MO':
trexio.write_mo_2e_int_eri(tf, 0, num_integrals, idx, eri.ravel())
else:
trexio.write_ao_2e_int_eri(tf, 0, num_integrals, idx, eri.ravel())

def write_scf_eri(mf, filename, backend='h5', basis='mo'):
assert basis.upper() in ['MO','AO']
if basis.upper() == 'MO':
write_eri(ao2mo.kernel(mf._eri, mf.mo_coeff), filename, backend, basis)
else:
write_eri(mf._eri, filename, backend, basis)

def read_eri(filename):

def read_eri(filename, basis='mo'):
'''Read ERIs in AO basis, 8-fold symmetry is assumed'''
assert basis.upper() in ['MO','AO']
basis_is_mo = (basis.upper() == 'MO')
with trexio.File(filename, 'r', back_end=trexio.TREXIO_AUTO) as tf:
nmo = trexio.read_mo_num(tf)
nao_pair = nmo * (nmo+1) // 2
eri_size = nao_pair * (nao_pair+1) // 2
idx, data, n_read, eof_flag = trexio.read_mo_2e_int_eri(tf, 0, eri_size)
norb = trexio.read_mo_num(tf) if basis_is_mo else trexio.read_ao_num(tf)
norb_pair = norb * (norb+1) // 2
eri_size = norb_pair * (norb_pair+1) // 2
if basis_is_mo:
idx, data, n_read, eof_flag = trexio.read_mo_2e_int_eri(tf, 0, eri_size)
else:
idx, data, n_read, eof_flag = trexio.read_ao_2e_int_eri(tf, 0, eri_size)
eri = np.zeros(eri_size)
x = idx[:,0]*(idx[:,0]+1)//2 + idx[:,1]
y = idx[:,2]*(idx[:,2]+1)//2 + idx[:,3]
Expand Down