Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes issues identified by ruff and flake8 in pbe.py #56

Merged
merged 9 commits into from
Dec 13, 2024
42 changes: 19 additions & 23 deletions src/quemb/kbe/pbe.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
from quemb.kbe.misc import print_energy, storePBE
from quemb.kbe.pfrag import Frags
from quemb.molbe._opt import BEOPT
from quemb.molbe.helper import get_eri, get_scfObj, get_veff
from quemb.molbe.be_parallel import be_func_parallel
from quemb.molbe.solver import be_func

from quemb.shared import be_var
from quemb.shared.external.optqn import (
get_be_error_jacobian as _ext_get_be_error_jacobian,
Expand Down Expand Up @@ -485,41 +487,38 @@ def initialize(self, compute_hf, restart=False):
"""
if compute_hf:
E_hf = 0.0
EH1 = 0.0
ECOUL = 0.0
EF = 0.0

# Create a file to store ERIs
if not restart:
file_eri = h5py.File(self.eri_file, "w")
lentmp = len(self.edge_idx)
transform_parallel = False # hard set for now
for I in range(self.Nfrag):
for fidx in range(self.Nfrag):
if lentmp:
fobjs_ = Frags(
self.fsites[I],
I,
edge=self.edge[I],
self.fsites[fidx],
fidx,
edge=self.edge[fidx],
eri_file=self.eri_file,
center=self.center[I],
edge_idx=self.edge_idx[I],
center_idx=self.center_idx[I],
efac=self.ebe_weight[I],
centerf_idx=self.centerf_idx[I],
center=self.center[fidx],
edge_idx=self.edge_idx[fidx],
center_idx=self.center_idx[fidx],
efac=self.ebe_weight[fidx],
centerf_idx=self.centerf_idx[fidx],
unitcell=self.unitcell,
unitcell_nkpt=self.unitcell_nkpt,
)
else:
fobjs_ = Frags(
self.fsites[I],
I,
self.fsites[fidx],
fidx,
edge=[],
center=[],
eri_file=self.eri_file,
edge_idx=[],
center_idx=[],
centerf_idx=[],
efac=self.ebe_weight[I],
efac=self.ebe_weight[fidx],
unitcell=self.unitcell,
unitcell_nkpt=self.unitcell_nkpt,
)
Expand Down Expand Up @@ -656,9 +655,8 @@ def initialize(self, compute_hf, restart=False):
* 2.0
)

# energy
if compute_hf:
eh1, ecoul, ef = self.Fobjs[frg].energy_hf(return_e1=True)
self.Fobjs[frg].update_ebe_hf() # Updates fragment HF energy.
E_hf += self.Fobjs[frg].ebe_hf

print(flush=True)
Expand Down Expand Up @@ -785,8 +783,8 @@ def update_fock(self, heff=None):
for fobj in self.Fobjs:
fobj.fock += fobj.heff
else:
for idx, fobj in self.Fobjs:
fobj.fock += heff[idx]
for fidx, fobj in self.Fobjs:
fobj.fock += heff[fidx]

def write_heff(self, heff_file="bepotfile.h5"):
"""
Expand Down Expand Up @@ -844,8 +842,8 @@ def initialize_pot(Nfrag, edge_idx):
pot_ = []

if not len(edge_idx) == 0:
for I in range(Nfrag):
ShaunWeatherly marked this conversation as resolved.
Show resolved Hide resolved
for i in edge_idx[I]:
for fidx in range(Nfrag):
for i in edge_idx[fidx]:
for j in range(len(i)):
for k in range(len(i)):
if j > k:
Expand Down Expand Up @@ -881,7 +879,6 @@ def parallel_fock_wrapper(dname, nao, dm, S, TA, hf_veff, eri_file):
"""
Wrapper for parallel Fock transformation
"""
from .helper import get_eri, get_veff

eri_ = get_eri(dname, nao, eri_file=eri_file, ignore_symm=True)
veff0, veff_ = get_veff(eri_, dm, S, TA, hf_veff, return_veff0=True)
Expand All @@ -893,7 +890,6 @@ def parallel_scf_wrapper(dname, nao, nocc, h1, dm_init, eri_file):
"""
Wrapper for performing fragment scf calculation
"""
from .helper import get_eri, get_scfObj

eri = get_eri(dname, nao, eri_file=eri_file)
mf_ = get_scfObj(h1, eri, nocc, dm_init)
Expand Down
2 changes: 1 addition & 1 deletion src/quemb/kbe/pfrag.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ def set_udim(self, cout):
cout += 1
return cout

def energy_hf(
def update_ebe_hf(
self, rdm_hf=None, mo_coeffs=None, eri=None, return_e1=False, unrestricted=False
):
if mo_coeffs is None:
Expand Down
2 changes: 1 addition & 1 deletion src/quemb/molbe/mbe.py
Original file line number Diff line number Diff line change
Expand Up @@ -897,7 +897,7 @@ def initialize(self, eri_, compute_hf, restart=False):
)

if compute_hf:
_, _, _ = fobjs_.energy_hf(return_e1=True) # eh1, ecoul, ef
fobjs_.update_ebe_hf() # Updates fragment HF energy.
E_hf += fobjs_.ebe_hf

if not restart:
Expand Down
10 changes: 5 additions & 5 deletions src/quemb/molbe/pfrag.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,12 +306,12 @@ def set_udim(self, cout):
cout += 1
return cout

def energy_hf(
def update_ebe_hf(
self,
rdm_hf=None,
mo_coeffs=None,
eri=None,
return_e1=False,
return_e=False,
unrestricted=False,
spin_ind=None,
):
Expand Down Expand Up @@ -380,12 +380,12 @@ def energy_hf(

self.ebe_hf = etmp

if return_e1:
if return_e:
e_h1 = 0.0
e_coul = 0.0
for i in self.efac[1]:
e_h1 += self.efac[0] * e1[i]
e_coul += self.efac[0] * (e2[i] + ec[i])
return (e_h1, e_coul, e1 + e2 + ec)

return e1 + e2 + ec
else:
return None
2 changes: 1 addition & 1 deletion src/quemb/molbe/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def be_func(
veff0=fobj.veff0,
)
total_e = [sum(x) for x in zip(total_e, e_f)]
fobj.energy_hf()
fobj.update_ebe_hf()

if frag_energy or eeval:
Ecorr = sum(total_e)
Expand Down
8 changes: 4 additions & 4 deletions src/quemb/molbe/ube.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,8 +313,8 @@ def initialize(self, eri_, compute_hf):
)

if compute_hf:
eh1_a, ecoul_a, ef_a = fobj_a.energy_hf(
return_e1=True, unrestricted=True, spin_ind=0
eh1_a, ecoul_a, ef_a = fobj_a.update_ebe_hf(
return_e=True, unrestricted=True, spin_ind=0
)
unused(ef_a)
EH1 += eh1_a
Expand All @@ -336,8 +336,8 @@ def initialize(self, eri_, compute_hf):
)

if compute_hf:
eh1_b, ecoul_b, ef_b = fobj_b.energy_hf(
return_e1=True, unrestricted=True, spin_ind=1
eh1_b, ecoul_b, ef_b = fobj_b.update_ebe_hf(
return_e=True, unrestricted=True, spin_ind=1
)
unused(ef_b)
EH1 += eh1_b
Expand Down
Loading