Skip to content

Commit

Permalink
Merge pull request #21 from mcocdawc/add_fixes_for_BE
Browse files Browse the repository at this point in the history
Small changes to eri_transform
  • Loading branch information
zhcui authored Nov 13, 2024
2 parents 3afd84e + 0311d2d commit 696c536
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 40 deletions.
67 changes: 40 additions & 27 deletions libdmet/basis_transform/eri_transform.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
"""

import numpy as np
import scipy.linalg as la
from scipy import fft as scifft
import h5py

Expand All @@ -26,8 +25,8 @@
from pyscf.lib import logger

from libdmet.basis_transform.make_basis import multiply_basis
from libdmet.system.lattice import (get_phase, get_phase_R2k, round_to_FBZ, kpt_member)
from libdmet.utils.misc import mdot, max_abs, add_spin_dim
from libdmet.system.lattice import (get_phase_R2k, round_to_FBZ, kpt_member)
from libdmet.utils.misc import max_abs, add_spin_dim
from libdmet.utils import logger as log

ERI_IMAG_TOL = 1e-6
Expand Down Expand Up @@ -163,7 +162,10 @@ def get_naoaux(gdf):
"""
assert gdf._cderi is not None
with h5py.File(gdf._cderi, 'r') as f:
nkptij = len(f["j3c"])
try: # OM Aug 3 2024: change to accept "j3c-kptij"
nkptij = f["j3c-kptij"].shape[0]
except KeyError:
nkptij= len(f["j3c"])
naux_k_list = []
for k in range(nkptij):
# gdf._cderi['j3c/k_id/seg_id']
Expand Down Expand Up @@ -228,8 +230,11 @@ def load(aux_slice):
# GDF ERI construction
# *****************************************************************************

# OM Aug 3 2024: Change made to get_emb_eri_fast_gdf to allow input of C_ao_emb

def get_emb_eri_fast_gdf(cell, mydf, C_ao_lo=None, basis=None, feri=None,
kscaled_center=None, symmetry=4, max_memory=None,
C_ao_eo=None,
kconserv_tol=KPT_DIFF_TOL, unit_eri=False, swap_idx=None,
t_reversal_symm=True, incore=True, fout="H2.h5"):
"""
Expand Down Expand Up @@ -257,33 +262,43 @@ def get_emb_eri_fast_gdf(cell, mydf, C_ao_lo=None, basis=None, feri=None,
# treat the possible drop of aux-basis at some kpts.
naux = get_naoaux(mydf)

# If C_ao_lo and basis not given, this routine is k2gamma AO transformation
if C_ao_lo is None:
C_ao_lo = np.zeros((nkpts, nao, nao), dtype=np.complex128)
C_ao_lo[:, range(nao), range(nao)] = 1.0 # identity matrix for each k

# add spin dimension for restricted C_ao_lo
if C_ao_lo.ndim == 3:
C_ao_lo = C_ao_lo[np.newaxis]

# possible kpts shift
kscaled = cell.get_scaled_kpts(kpts)
if kscaled_center is not None:
kscaled -= kscaled_center

# basis related
if basis is None:
basis = np.eye(nkpts * nao).reshape(1, nkpts, nao, nkpts * nao)
if basis.shape[0] < C_ao_lo.shape[0]:
basis = add_spin_dim(basis, C_ao_lo.shape[0])
if C_ao_lo.shape[0] < basis.shape[0]:
C_ao_lo = add_spin_dim(C_ao_lo, basis.shape[0])

if unit_eri: # unit ERI for DMFT
C_ao_emb = C_ao_lo / (nkpts**0.75)
if C_ao_eo is None:
# If C_ao_lo and basis not given, this routine is k2gamma AO transformation
if C_ao_lo is None:
C_ao_lo = np.zeros((nkpts, nao, nao), dtype=np.complex128)
C_ao_lo[:, range(nao), range(nao)] = 1.0 # identity matrix for each k

# add spin dimension for restricted C_ao_lo
if C_ao_lo.ndim == 3:
C_ao_lo = C_ao_lo[np.newaxis]

# basis related
if basis is None:
basis = np.eye(nkpts * nao).reshape(1, nkpts, nao, nkpts * nao)
if basis.shape[0] < C_ao_lo.shape[0]:
basis = add_spin_dim(basis, C_ao_lo.shape[0])
if C_ao_lo.shape[0] < basis.shape[0]:
C_ao_lo = add_spin_dim(C_ao_lo, basis.shape[0])

if unit_eri: # unit ERI for DMFT
C_ao_emb = C_ao_lo / (nkpts**0.75)
else:
phase = get_phase_R2k(cell, kpts)
C_ao_emb = multiply_basis(C_ao_lo, get_basis_k(basis, phase)) / (nkpts**(0.75))
else:
phase = get_phase_R2k(cell, kpts)
C_ao_emb = multiply_basis(C_ao_lo, get_basis_k(basis, phase)) / (nkpts**(0.75))
if C_ao_lo is not None:
raise ValueError("Don't pass both `C_ao_lo` and `C_ao_eo`.")
C_ao_eo = np.asarray(C_ao_eo)
if C_ao_eo.ndim == 3:
C_ao_eo = C_ao_eo[np.newaxis]
assert (nkpts, nao) == C_ao_eo.shape[1:3]
C_ao_emb = C_ao_eo / (nkpts**(0.75))

spin, _, _, nemb = C_ao_emb.shape
nemb_pair = nemb * (nemb+1) // 2
res_shape = (spin * (spin+1) // 2, nemb_pair, nemb_pair)
Expand Down Expand Up @@ -1526,8 +1541,6 @@ def convert_eri_to_gdf(eri, norb, fname=None, tol=1e-8):
import pyscf.pbc.scf as pscf
from pyscf.pbc.lib import chkfile
from libdmet.system import lattice
import libdmet.lo.pywannier90 as pywannier90
from libdmet.utils.misc import mdot
np.set_printoptions(3, linewidth=1000)

cell = pgto.Cell()
Expand Down
52 changes: 39 additions & 13 deletions libdmet/basis_transform/test/test_eri_transform_gdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,59 @@
'''

import numpy as np
import scipy.linalg as la
import os, sys

from pyscf import lib
from pyscf.pbc.lib import chkfile
from pyscf.pbc import scf, gto, df
import pytest
from pyscf import ao2mo
from pyscf.pbc import df, gto, scf
from pyscf.pbc.lib import chkfile

from libdmet.lo import iao, pywannier90
from libdmet.basis_transform import eri_transform, make_basis
from libdmet.basis_transform.make_basis import multiply_basis
from libdmet.basis_transform.eri_transform import get_emb_eri_fast_gdf, get_basis_k
from libdmet.system import lattice
from libdmet.basis_transform import make_basis
from libdmet.basis_transform import eri_transform
from libdmet.utils.misc import mdot, max_abs
from libdmet.system.fourier import get_phase_R2k, get_phase
from libdmet.utils import logger as log

import pytest
from libdmet.utils.misc import max_abs, add_spin_dim

np.set_printoptions(4, linewidth=1000, suppress=True)
log.verbose = "DEBUG2"

def _test_ERI(cell, gdf, kpts, C_ao_lo):
# Fast ERI
# Give a very small memory to force small blksize (for test only)

eri_k2gamma = eri_transform.get_emb_eri(cell, gdf, C_ao_lo=C_ao_lo, \
max_memory=0.02, t_reversal_symm=True, symmetry=4)
eri_k2gamma = eri_k2gamma[0]

def test_passing_C_ao_eo(C_ao_lo):
"""Test if `get_emb_eri_fast_gdf` works
when passing a partially transformed `C_ao_eo` object of embedding orbitals
"""
C_ao_lo = C_ao_lo.copy()
nkpts, nao = C_ao_lo.shape[:2]
assert C_ao_lo.ndim == 3
if C_ao_lo.ndim == 3:
C_ao_lo = C_ao_lo[np.newaxis]

basis = np.eye(nkpts * nao).reshape(1, nkpts, nao, nkpts * nao)
if basis.shape[0] < C_ao_lo.shape[0]:
basis = add_spin_dim(basis, C_ao_lo.shape[0])
if C_ao_lo.shape[0] < basis.shape[0]:
C_ao_lo = add_spin_dim(C_ao_lo, basis.shape[0])
phase = get_phase_R2k(cell, kpts)
C_ao_eo = multiply_basis(C_ao_lo, get_basis_k(basis, phase))[0, :, :, :]

direct_calc = get_emb_eri_fast_gdf(
cell, gdf,
C_ao_eo=C_ao_eo,
max_memory=0.02, t_reversal_symm=True, symmetry=4)[0]
assert max_abs(np.asarray(direct_calc) - eri_k2gamma) < 1e-14

test_passing_C_ao_eo(C_ao_lo)




# outcore routine
eri_transform.ERI_SLICE = 3
eri_outcore = eri_transform.get_emb_eri(cell, gdf, C_ao_lo=C_ao_lo, \
Expand All @@ -42,7 +68,7 @@ def _test_ERI(cell, gdf, kpts, C_ao_lo):
assert diff_outcore < 1e-12

# compared to supercell
scell, phase = eri_transform.get_phase(cell, kpts)
scell, phase = get_phase(cell, kpts)
mydf_scell = df.GDF(scell)
nao = cell.nao_nr()
nkpts = len(kpts)
Expand Down

0 comments on commit 696c536

Please sign in to comment.