-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
e7f9344
commit 0cb766d
Showing
7 changed files
with
393 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
from ._cp2k import CP2KReference # noqa: F401 | ||
from ._emt import EMTReference # noqa: F401 | ||
from ._nwchem import NWChemReference # noqa: F401 | ||
from ._pyscf import PySCFReference # noqa: F401 | ||
from .base import BaseReference # noqa: F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,261 @@ | ||
from __future__ import annotations # necessary for type-guarding class methods | ||
|
||
import logging | ||
|
||
import numpy as np | ||
import parsl | ||
from ase import Atoms | ||
from ase.data import atomic_numbers | ||
from parsl.app.app import bash_app, join_app, python_app | ||
from parsl.data_provider.files import File | ||
|
||
import psiflow | ||
from psiflow.data import FlowAtoms, NullState | ||
from psiflow.reference.base import BaseReference | ||
from psiflow.utils import copy_app_future | ||
|
||
logger = logging.getLogger(__name__) # logging per module | ||
|
||
|
||
def atoms_to_molecule(ase_atoms, basis, spin): | ||
from pyscf import gto | ||
from pyscf.pbc import gto as pbcgto | ||
|
||
atom_symbols = ase_atoms.get_chemical_symbols() | ||
atom_coords = ase_atoms.get_positions() | ||
atom_spec = [ | ||
(symbol, tuple(coord)) for symbol, coord in zip(atom_symbols, atom_coords) | ||
] | ||
|
||
if ase_atoms.get_pbc().any(): # Periodic boundary conditions | ||
cell_params = ase_atoms.get_cell() | ||
pyscf_cell = pbcgto.Cell() | ||
pyscf_cell.atom = atom_spec | ||
pyscf_cell.basis = basis | ||
pyscf_cell.spin = spin | ||
pyscf_cell.a = cell_params | ||
pyscf_cell.build() | ||
return pyscf_cell | ||
else: # Non-periodic (molecular) | ||
pyscf_mol = gto.Mole() | ||
pyscf_mol.atom = atom_spec | ||
pyscf_mol.basis = basis | ||
pyscf_mol.spin = spin | ||
pyscf_mol.verbose = 5 | ||
pyscf_mol.build() | ||
return pyscf_mol | ||
|
||
|
||
def serialize_atoms(atoms): | ||
atoms_str = "dict(symbols={}, positions={}, cell={}, pbc={})".format( | ||
atoms.get_chemical_symbols(), | ||
atoms.get_positions().tolist(), | ||
atoms.get_cell().tolist(), | ||
atoms.get_pbc().tolist(), | ||
) | ||
return atoms_str | ||
# atoms_dict = { | ||
# 'symbols': atoms.get_chemical_symbols(), | ||
# 'positions': atoms.get_positions().tolist(), # Convert numpy array to list | ||
# 'cell': atoms.get_cell().tolist(), | ||
# 'pbc': atoms.get_pbc().tolist(), | ||
# } | ||
# return repr(atoms_dict) | ||
|
||
|
||
def deserialize_atoms(atoms_dict): | ||
return Atoms( | ||
symbols=atoms_dict["symbols"], | ||
positions=np.array(atoms_dict["positions"]), # Convert list back to numpy array | ||
cell=np.array(atoms_dict["cell"]), | ||
pbc=atoms_dict["pbc"], | ||
) | ||
|
||
|
||
def generate_script(state, routine, basis, spin): | ||
# print 'energy' and 'forces' variables | ||
routine = routine.strip() | ||
routine += """ | ||
print('total energy = {}'.format(energy * Ha)) | ||
print('total forces = ') | ||
for force in forces: | ||
print(*(force * Ha / Bohr)) | ||
""" | ||
lines = routine.split("\n") # indent entire routine | ||
for i in range(len(lines)): | ||
lines[i] = " " + lines[i] | ||
routine = "\n".join(lines) | ||
|
||
script = """ | ||
from ase.units import Ha, Bohr | ||
from psiflow.reference._pyscf import deserialize_atoms, atoms_to_molecule | ||
def main(molecule): | ||
{} | ||
""".format( | ||
routine | ||
) | ||
script += """ | ||
if __name__ == '__main__': | ||
atoms_dict = {} | ||
atoms = deserialize_atoms(atoms_dict) | ||
molecule = atoms_to_molecule( | ||
atoms, | ||
basis='{}', | ||
spin={}, | ||
) | ||
main(molecule) | ||
""".format( | ||
serialize_atoms(state).strip(), | ||
basis, | ||
spin, | ||
) | ||
return script | ||
|
||
|
||
def parse_energy_forces(stdout): | ||
energy = None | ||
forces_str = None | ||
lines = stdout.split("\n") | ||
for i, line in enumerate(lines[::-1]): # start from back! | ||
if energy is None and "total energy = " in line: | ||
energy = float(line.split("total energy = ")[1]) | ||
if forces_str is None and "total forces =" in line: | ||
forces_str = "\n".join(lines[-i:]) | ||
assert energy is not None | ||
assert forces_str is not None | ||
rows = forces_str.strip().split("\n") | ||
nrows = len(rows) | ||
ncols = len(rows[0].split()) | ||
assert ncols == 3 | ||
forces = np.fromstring("\n".join(rows), sep=" ", dtype=float) | ||
return energy, np.reshape(forces, (nrows, ncols)) | ||
|
||
|
||
def pyscf_singlepoint_pre( | ||
atoms: FlowAtoms, | ||
omp_num_threads: int, | ||
stdout: str = "", | ||
stderr: str = "", | ||
walltime: int = 0, | ||
**parameters, | ||
) -> str: | ||
from psiflow.reference._pyscf import generate_script | ||
|
||
script = generate_script(atoms, **parameters) | ||
command_tmp = 'mytmpdir=$(mktemp -d 2>/dev/null || mktemp -d -t "mytmpdir");' | ||
command_cd = "cd $mytmpdir;" | ||
command_write = 'echo "{}" > generated.py;'.format(script) | ||
command_list = [ | ||
command_tmp, | ||
command_cd, | ||
command_write, | ||
"export OMP_NUM_THREADS={};".format(omp_num_threads), | ||
"timeout -s 9 {}s python generated.py || true".format(max(walltime - 2, 0)), | ||
] | ||
return " ".join(command_list) | ||
|
||
|
||
def pyscf_singlepoint_post( | ||
atoms: FlowAtoms, | ||
inputs: list[File] = [], | ||
) -> FlowAtoms: | ||
from psiflow.reference._pyscf import parse_energy_forces | ||
|
||
atoms.reference_stdout = inputs[0] | ||
atoms.reference_stderr = inputs[1] | ||
with open(atoms.reference_stdout, "r") as f: | ||
content = f.read() | ||
try: | ||
energy, forces = parse_energy_forces(content) | ||
assert forces.shape == atoms.positions.shape | ||
atoms.info["energy"] = energy | ||
atoms.arrays["forces"] = forces | ||
atoms.reference_status = True | ||
except Exception: | ||
atoms.reference_status = False | ||
return atoms | ||
|
||
|
||
class PySCFReference(BaseReference): | ||
required_files = [] | ||
|
||
def __init__(self, routine, basis, spin): | ||
assert ( | ||
"energy = " in routine | ||
), "define the total energy (in Ha) in your pyscf routine" | ||
assert ( | ||
"forces = " in routine | ||
), "define the forces (in Ha/Bohr) in your pyscf routine" | ||
assert "pyscf" in routine, "put all necessary imports inside the routine!" | ||
self.routine = routine | ||
self.basis = basis | ||
self.spin = spin | ||
super().__init__() | ||
|
||
def get_single_atom_references(self, element): | ||
number = atomic_numbers[element] | ||
references = [] | ||
for spin in range(15): | ||
config = {"spin": spin} | ||
mult = spin + 1 | ||
if number % 2 == 0 and mult % 2 == 0: | ||
continue | ||
if mult == 1 and number % 2 == 1: | ||
continue | ||
if mult - 1 > number: | ||
continue | ||
parameters = self.parameters | ||
parameters["spin"] = spin | ||
reference = self.__class__(**parameters) | ||
references.append((config, reference)) | ||
return references | ||
|
||
@property | ||
def parameters(self): | ||
return { | ||
"routine": self.routine, | ||
"basis": self.basis, | ||
"spin": self.spin, | ||
} | ||
|
||
@classmethod | ||
def create_apps(cls): | ||
context = psiflow.context() | ||
definition = context[cls] | ||
label = definition.name() | ||
ncores = definition.cores_per_worker | ||
walltime = definition.max_walltime | ||
|
||
singlepoint_pre = bash_app( | ||
pyscf_singlepoint_pre, | ||
executors=[label], | ||
) | ||
singlepoint_post = python_app( | ||
pyscf_singlepoint_post, | ||
executors=["default_threads"], | ||
) | ||
|
||
@join_app | ||
def singlepoint_wrapped(atoms, parameters, file_names, inputs=[]): | ||
assert len(file_names) == 0 | ||
if atoms == NullState: | ||
return copy_app_future(NullState) | ||
else: | ||
pre = singlepoint_pre( | ||
atoms, | ||
omp_num_threads=ncores, | ||
stdout=parsl.AUTO_LOGNAME, | ||
stderr=parsl.AUTO_LOGNAME, | ||
walltime=60 * walltime, # killed after walltime - 10s | ||
**parameters, | ||
) | ||
return singlepoint_post( | ||
atoms=atoms, | ||
inputs=[pre.stdout, pre.stderr, pre], # wait for bash app | ||
) | ||
|
||
context.register_app(cls, "evaluate_single", singlepoint_wrapped) | ||
super(PySCFReference, cls).create_apps() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.