Skip to content

Commit

Permalink
Dataclass for solver kwargs (#77)
Browse files Browse the repository at this point in the history
- Introduced a dataclass for solver kwargs. This is makes it much more explicit what arguments are supported by DMRG and SHCI

- Could simplify the interface of BE by removing a couple of arguments
  • Loading branch information
mcocdawc authored Jan 10, 2025
1 parent cca67f9 commit a6cad31
Show file tree
Hide file tree
Showing 7 changed files with 326 additions and 266 deletions.
9 changes: 5 additions & 4 deletions example/molbe_dmrg_block2.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pyscf import cc, fci, gto, scf

from quemb.molbe import BE, fragpart
from quemb.molbe.solver import DMRG_ArgsUser

# We'll consider the dissociation curve for a 1D chain of 8 H-atoms:
num_points = 3
Expand Down Expand Up @@ -52,7 +53,7 @@
# Next, run BE-DMRG with default parameters and maxM=100.
mybe.oneshot(
solver="block2", # or 'DMRG', 'DMRGSCF', 'DMRGCI'
DMRG_solver_kwargs=dict(
solver_args=DMRG_ArgsUser(
maxM=100, # Max fragment bond dimension
force_cleanup=True, # Remove all fragment DMRG tmpfiles
),
Expand Down Expand Up @@ -100,7 +101,7 @@
solver="block2", # or 'DMRG', 'DMRGSCF', 'DMRGCI'
max_iter=60, # Max number of sweeps
only_chem=True,
DMRG_solver_kwargs=dict(
solver_args=DMRG_ArgsUser(
startM=20, # Initial fragment bond dimension (1st sweep)
maxM=200, # Maximum fragment bond dimension
twodot_to_onedot=50, # Sweep num to switch from two- to one-dot algo.
Expand All @@ -113,7 +114,7 @@
)

# Or, alternatively, we can construct a full schedule by hand:
schedule = {
schedule: dict[str, list[int] | list[float]] = {
"scheduleSweeps": [0, 10, 20, 30, 40, 50], # Sweep indices
"scheduleMaxMs": [25, 50, 100, 200, 500, 500], # Sweep maxMs
"scheduleTols": [1e-5, 1e-5, 1e-6, 1e-6, 1e-8, 1e-8], # Sweep Davidson tolerances
Expand All @@ -124,7 +125,7 @@
mybe.optimize(
solver="block2",
only_chem=True,
DMRG_solver_kwargs=dict(
solver_args=DMRG_ArgsUser(
schedule_kwargs=schedule,
block_extra_keyword=["fiedler"],
force_cleanup=True,
Expand Down
69 changes: 25 additions & 44 deletions src/quemb/kbe/pbe.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import h5py
import numpy
from libdmet.basis_transform.eri_transform import get_emb_eri_fast_gdf
from numpy import array, floating
from pyscf import ao2mo, pbc
from pyscf.pbc import df, gto
from pyscf.pbc.df.df_jk import _ewald_exxdiv_for_G0
Expand All @@ -18,13 +19,13 @@
from quemb.molbe.be_parallel import be_func_parallel
from quemb.molbe.helper import get_eri, get_scfObj, get_veff
from quemb.molbe.opt import BEOPT
from quemb.molbe.solver import be_func
from quemb.molbe.solver import UserSolverArgs, be_func
from quemb.shared.external.optqn import (
get_be_error_jacobian as _ext_get_be_error_jacobian,
)
from quemb.shared.helper import copy_docstring
from quemb.shared.manage_scratch import WorkDir
from quemb.shared.typing import KwargDict, PathLike
from quemb.shared.typing import Matrix, PathLike


class BE(Mixin_k_Localize):
Expand Down Expand Up @@ -58,12 +59,8 @@ def __init__(
save: bool = False,
restart_file: PathLike = "storebe.pk",
save_file: PathLike = "storebe.pk",
hci_pt: bool = False,
nproc: int = 1,
ompnum: int = 4,
hci_cutoff: float = 0.001,
ci_coeff_cutoff: float | None = None,
select_cutoff: float | None = None,
iao_val_core: bool = True,
exxdiv: str = "ewald",
kpts: list[list[float]] | None = None,
Expand Down Expand Up @@ -159,12 +156,6 @@ def __init__(
self.nkpt = nkpts_
self.kpts = kpts

# HCI parameters
self.hci_cutoff = hci_cutoff
self.ci_coeff_cutoff = ci_coeff_cutoff
self.select_cutoff = select_cutoff
self.hci_pt = hci_pt

if not restart:
self.mo_energy = mf.mo_energy
mf.exxdiv = None
Expand Down Expand Up @@ -325,17 +316,17 @@ def __init__(

def optimize(
self,
solver="MP2",
method="QN",
only_chem=False,
use_cumulant=True,
conv_tol=1.0e-6,
relax_density=False,
J0=None,
nproc=1,
ompnum=4,
max_iter=500,
):
solver: str = "MP2",
method: str = "QN",
only_chem: bool = False,
use_cumulant: bool = True,
conv_tol: float = 1.0e-6,
relax_density: bool = False,
J0: Matrix[floating] | None = None,
nproc: int = 1,
ompnum: int = 4,
max_iter: int = 500,
) -> None:
"""BE optimization function
Interfaces BEOPT to perform bootstrap embedding optimization.
Expand Down Expand Up @@ -393,20 +384,17 @@ def optimize(
conv_tol=conv_tol,
only_chem=only_chem,
use_cumulant=use_cumulant,
hci_cutoff=self.hci_cutoff,
ci_coeff_cutoff=self.ci_coeff_cutoff,
relax_density=relax_density,
select_cutoff=self.select_cutoff,
solver=solver,
ebe_hf=self.ebe_hf,
)

if method == "QN":
# Prepare the initial Jacobian matrix
if only_chem:
J0 = [[0.0]]
J0 = array([[0.0]])
J0 = self.get_be_error_jacobian(jac_solver="HF")
J0 = [[J0[-1, -1]]]
J0 = J0[-1:, -1:]
else:
J0 = self.get_be_error_jacobian(jac_solver="HF")

Expand All @@ -429,10 +417,10 @@ def optimize(
raise ValueError("This optimization method for BE is not supported")

@copy_docstring(_ext_get_be_error_jacobian)
def get_be_error_jacobian(self, jac_solver="HF"):
def get_be_error_jacobian(self, jac_solver: str = "HF") -> Matrix[floating]:
return _ext_get_be_error_jacobian(self.Nfrag, self.Fobjs, jac_solver)

def print_ini(self):
def print_ini(self) -> None:
"""
Print initialization banner for the kBE calculation.
"""
Expand Down Expand Up @@ -683,7 +671,7 @@ def oneshot(
use_cumulant: bool = True,
nproc: int = 1,
ompnum: int = 4,
DMRG_solver_kwargs: KwargDict | None = None,
solver_args: UserSolverArgs | None = None,
) -> None:
"""
Perform a one-shot bootstrap embedding calculation.
Expand Down Expand Up @@ -711,14 +699,11 @@ def oneshot(
solver,
self.enuc,
nproc=ompnum,
use_cumulant=use_cumulant,
eeval=True,
return_vec=False,
hci_cutoff=self.hci_cutoff,
ci_coeff_cutoff=self.ci_coeff_cutoff,
select_cutoff=self.select_cutoff,
scratch_dir=self.scratch_dir,
DMRG_solver_kwargs=DMRG_solver_kwargs,
solver_args=solver_args,
use_cumulant=use_cumulant,
return_vec=False,
)
else:
rets = be_func_parallel(
Expand All @@ -727,15 +712,13 @@ def oneshot(
self.Nocc,
solver,
self.enuc,
eeval=True,
nproc=nproc,
ompnum=ompnum,
scratch_dir=self.scratch_dir,
solver_args=solver_args,
use_cumulant=use_cumulant,
eeval=True,
return_vec=False,
hci_cutoff=self.hci_cutoff,
ci_coeff_cutoff=self.ci_coeff_cutoff,
select_cutoff=self.select_cutoff,
scratch_dir=self.scratch_dir,
)

print("-----------------------------------------------------", flush=True)
Expand All @@ -759,8 +742,6 @@ def oneshot(
flush=True,
)

self.ebe_tot = rets[0]

def update_fock(self, heff=None):
"""
Update the Fock matrix for each fragment with the effective Hamiltonian.
Expand Down
66 changes: 35 additions & 31 deletions src/quemb/molbe/be_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
)
from quemb.molbe.pfrag import Frags
from quemb.molbe.solver import (
SHCI_ArgsUser,
UserSolverArgs,
_SHCI_Args,
make_rdm1_ccsd_t1,
make_rdm2_urlx,
solve_ccsd,
Expand Down Expand Up @@ -46,15 +49,13 @@ def run_solver(
eri_file: str = "eri_file.h5",
veff: Matrix[float64] | None = None,
veff0: Matrix[float64] | None = None,
hci_cutoff: float = 0.001,
ci_coeff_cutoff: float | None = None,
select_cutoff: float | None = None,
ompnum: int = 4,
writeh1: bool = False,
eeval: bool = True,
ret_vec: bool = False,
use_cumulant: bool = True,
relax_density: bool = False,
solver_args: UserSolverArgs | None = None,
):
"""
Run a quantum chemistry solver to compute the reduced density matrices.
Expand All @@ -67,9 +68,12 @@ def run_solver(
Initial guess for the density matrix.
scratch_dir :
The scratch dir root.
Fragment files will be stored in :code:`scratch_dir / dname`.
dname :
Directory name for storing intermediate files.
Fragment files will be stored in :code:`scratch_dir / dname`.
scratch_dir :
The scratch directory.
Fragment files will be stored in :code:`scratch_dir / dname`.
nao :
Number of atomic orbitals.
nocc :
Expand All @@ -95,12 +99,12 @@ def run_solver(
Number of OpenMP threads. Default is 4.
writeh1 :
If True, write the one-electron integrals to a file. Default is False.
use_cumulant :
If True, use the cumulant approximation for RDM2. Default is True.
eeval :
If True, evaluate the electronic energy. Default is True.
ret_vec :
If True, return vector with error and rdms. Default is True.
use_cumulant :
If True, use the cumulant approximation for RDM2. Default is True.
relax_density :
If True, use CCSD relaxed density. Default is False
Expand Down Expand Up @@ -144,19 +148,17 @@ def run_solver(
# pylint: disable-next=E0611
from pyscf import hci # noqa: PLC0415 # hci is an optional module

assert isinstance(solver_args, SHCI_ArgsUser)
SHCI_args = _SHCI_Args.from_user_input(solver_args)

nao, nmo = mf_.mo_coeff.shape
eri = ao2mo.kernel(mf_._eri, mf_.mo_coeff, aosym="s4", compact=False).reshape(
4 * ((nmo),)
)
ci_ = hci.SCI(mf_.mol)
if select_cutoff is None and ci_coeff_cutoff is None:
select_cutoff = hci_cutoff
ci_coeff_cutoff = hci_cutoff
elif select_cutoff is None or ci_coeff_cutoff is None:
raise ValueError

ci_.select_cutoff = select_cutoff
ci_.ci_coeff_cutoff = ci_coeff_cutoff
ci_.select_cutoff = SHCI_args.select_cutoff
ci_.ci_coeff_cutoff = SHCI_args.ci_coeff_cutoff

nelec = (nocc, nocc)
h1_ = multi_dot((mf_.mo_coeff.T, h1, mf_.mo_coeff))
Expand All @@ -174,6 +176,9 @@ def run_solver(

frag_scratch = WorkDir(scratch_dir / dname)

assert isinstance(solver_args, SHCI_ArgsUser)
SHCI_args = _SHCI_Args.from_user_input(solver_args)

nao, nmo = mf_.mo_coeff.shape
nelec = (nocc, nocc)
mch = shci.SHCISCF(mf_, nmo, nelec, orbpath=frag_scratch)
Expand All @@ -182,7 +187,7 @@ def run_solver(
mch.fcisolver.nPTiter = 0
mch.fcisolver.sweep_iter = [0]
mch.fcisolver.DoRDM = True
mch.fcisolver.sweep_epsilon = [hci_cutoff]
mch.fcisolver.sweep_epsilon = [solver_args.hci_cutoff]
mch.fcisolver.scratchDirectory = frag_scratch
if not writeh1:
mch.fcisolver.restart = True
Expand All @@ -193,6 +198,9 @@ def run_solver(
# pylint: disable-next=E0611
from pyscf import cornell_shci # noqa: PLC0415 # optional module

assert isinstance(solver_args, SHCI_ArgsUser)
SHCI_args = _SHCI_Args.from_user_input(solver_args)

frag_scratch = WorkDir(scratch_dir / dname)

nao, nmo = mf_.mo_coeff.shape
Expand All @@ -208,7 +216,7 @@ def run_solver(
ci.runtimedir = frag_scratch
ci.restart = True
ci.config["var_only"] = True
ci.config["eps_vars"] = [hci_cutoff]
ci.config["eps_vars"] = [solver_args.hci_cutoff]
ci.config["get_1rdm_csv"] = True
ci.config["get_2rdm_csv"] = True
ci.kernel(h1, eri, nmo, nelec)
Expand Down Expand Up @@ -382,16 +390,14 @@ def be_func_parallel(
solver: str,
enuc: float, # noqa: ARG001
scratch_dir: WorkDir,
only_chem: bool = False,
solver_args: UserSolverArgs | None,
nproc: int = 1,
ompnum: int = 4,
only_chem: bool = False,
relax_density: bool = False,
use_cumulant: bool = True,
eeval: bool = True,
return_vec: bool = True,
hci_cutoff: float = 0.001,
ci_coeff_cutoff: float | None = None,
select_cutoff: float | None = None,
eeval: bool = False,
return_vec: bool = False,
writeh1: bool = False,
):
"""
Expand All @@ -416,21 +422,21 @@ def be_func_parallel(
'FCI', 'HCI', 'SHCI', and 'SCI'.
enuc :
Nuclear component of the energy.
scratch_dir :
Scratch directory root
only_chem :
Whether to perform chemical potential optimization only.
Refer to bootstrap embedding literature. Defaults to False.
nproc :
Total number of processors assigned for the optimization. Defaults to 1.
When nproc > 1, Python multithreading is invoked.
ompnum :
If nproc > 1, sets the number of cores for OpenMP parallelization.
Defaults to 4.
use_cumulant :
Use cumulant energy expression. Defaults to True
only_chem :
Whether to perform chemical potential optimization only.
Refer to bootstrap embedding literature. Defaults to False.
eeval :
Whether to evaluate energies. Defaults to False.
scratch_dir :
Scratch directory root
use_cumulant :
Use cumulant energy expression. Defaults to True
return_vec :
Whether to return the error vector. Defaults to False.
writeh1 :
Expand Down Expand Up @@ -472,15 +478,13 @@ def be_func_parallel(
fobj.eri_file,
fobj.veff if not use_cumulant else None,
fobj.veff0,
hci_cutoff,
ci_coeff_cutoff,
select_cutoff,
ompnum,
writeh1,
eeval,
return_vec,
use_cumulant,
relax_density,
solver_args,
],
)

Expand Down
Loading

0 comments on commit a6cad31

Please sign in to comment.