Skip to content

Commit

Permalink
do explicit copies
Browse files Browse the repository at this point in the history
  • Loading branch information
mcocdawc committed Jan 13, 2025
1 parent 87b62d3 commit d2e1c75
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 41 deletions.
35 changes: 21 additions & 14 deletions src/quemb/kbe/pbe.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,12 +545,12 @@ def initialize(self, compute_hf: bool, restart: bool = False) -> None:
result = pool_.apply_async(
eritransform_parallel,
[
self.mf.cell.a,
self.mf.cell.atom,
self.mf.cell.basis,
self.kpts,
self.Fobjs[frg].TA,
self.cderi,
self.mf.cell.a.copy(),
self.mf.cell.atom.copy(),
self.mf.cell.basis.copy(),
self.kpts.copy(),
self.Fobjs[frg].TA.copy(),
self.cderi.copy(),
],
)
results.append(result)
Expand All @@ -568,13 +568,13 @@ def initialize(self, compute_hf: bool, restart: bool = False) -> None:
result = pool_.apply_async(
parallel_fock_wrapper,
[
self.Fobjs[frg].dname,
self.Fobjs[frg].nao,
self.hf_dm,
self.S,
self.Fobjs[frg].TA,
self.hf_veff,
self.eri_file,
self.Fobjs[frg].dname.copy(),
self.Fobjs[frg].nao.copy(),
self.hf_dm.copy(),
self.S.copy(),
self.Fobjs[frg].TA.copy(),
self.hf_veff.copy(),
self.eri_file.copy(),
],
)
results.append(result)
Expand Down Expand Up @@ -608,7 +608,14 @@ def initialize(self, compute_hf: bool, restart: bool = False) -> None:
h1 = self.Fobjs[frg].fock + self.Fobjs[frg].heff
result = pool_.apply_async(
parallel_scf_wrapper,
[dname, nao, nocc, h1, self.Fobjs[frg].dm_init, self.eri_file],
[
dname.copy(),
nao.copy(),
nocc.copy(),
h1.copy(),
self.Fobjs[frg].dm_init.copy(),
self.eri_file.copy(),
],
)
results.append(result)
mo_coeffs = [result.get() for result in results]
Expand Down
54 changes: 27 additions & 27 deletions src/quemb/molbe/be_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,25 +465,25 @@ def be_func_parallel(
[
fobj.fock + fobj.heff,
fobj.dm0.copy(),
scratch_dir,
fobj.dname,
fobj.nao,
fobj.nsocc,
fobj.nfsites,
fobj.efac,
fobj.TA,
fobj.h1,
solver,
fobj.eri_file,
fobj.veff if not use_cumulant else None,
fobj.veff0,
ompnum,
writeh1,
eeval,
return_vec,
use_cumulant,
relax_density,
solver_args,
scratch_dir.copy(),
fobj.dname.copy(),
fobj.nao.copy(),
fobj.nsocc.copy(),
fobj.nfsites.copy(),
fobj.efac.copy(),
fobj.TA.copy(),
fobj.h1.copy(),
solver.copy(),
fobj.eri_file.copy(),
fobj.veff.copy() if not use_cumulant else None,
fobj.veff0.copy(),
ompnum.copy(),
writeh1.copy(),
eeval.copy(),
return_vec.copy(),
use_cumulant.copy(),
relax_density.copy(),
solver_args.copy(),
],
)

Expand Down Expand Up @@ -588,14 +588,14 @@ def be_func_parallel_u(
result = pool_.apply_async(
run_solver_u,
[
fobj_a,
fobj_b,
solver,
enuc,
hf_veff,
relax_density,
frozen,
use_cumulant,
fobj_a.copy(),
fobj_b.copy(),
solver.copy(),
enuc.copy(),
hf_veff.copy(),
relax_density.copy(),
frozen.copy(),
use_cumulant.copy(),
True,
],
)
Expand Down

0 comments on commit d2e1c75

Please sign in to comment.