From db747840810c7aa1497e14057b4a37a9e3bb2c53 Mon Sep 17 00:00:00 2001 From: mikibonacci Date: Thu, 5 Dec 2024 15:52:04 +0000 Subject: [PATCH] First partially working implementation of read, write. --- src/aiida_koopmans/engine/aiida.py | 52 +++++++++++++++++++++++++++++- src/aiida_koopmans/utils.py | 2 +- 2 files changed, 52 insertions(+), 2 deletions(-) diff --git a/src/aiida_koopmans/engine/aiida.py b/src/aiida_koopmans/engine/aiida.py index c2cd1ce..ac8e98a 100644 --- a/src/aiida_koopmans/engine/aiida.py +++ b/src/aiida_koopmans/engine/aiida.py @@ -3,6 +3,10 @@ from koopmans.calculators import Calc from koopmans.pseudopotentials import read_pseudo_file from koopmans.status import Status +from koopmans.files import FilePointer +from koopmans.processes import Process + +from typing import Generator, List from aiida.engine import run_get_node, submit @@ -46,11 +50,19 @@ def __init__(self, *args, **kwargs): def run(self, step: Step): self.get_status(step) + + if isinstance(step, Process): + step.run() + self.set_status(step, Status.COMPLETED) + self._step_completed_message(step) + return + if step.prefix in ['wannier90_preproc', 'pw2wannier90']: self.set_status(step, Status.COMPLETED) return self.step_data['steps'][step.uid] = {} # maybe not needed + builder, self.step_data = get_builder_from_ase(calculator=step, step_data=self.step_data) # ASE to AiiDA conversion. put some error message if the conversion fails running = submit(builder) print(f"Running workchain {running.pk} for step {step.uid}") @@ -98,7 +110,7 @@ def set_status_by_uid(self, uid: str, status: Status): def update_statuses(self) -> None: - time.sleep(1) + #time.sleep(1) for uid in self.step_data['steps']: if not self.get_status_by_uid(uid) == Status.RUNNING: @@ -119,6 +131,10 @@ def load_results(self, step: Step) -> None: self.load_step_data() + if isinstance(step, Process): + print(self.step_data['steps'][step.uid]) + return + if step.prefix in ['wannier90_preproc', 'pw2wannier90']: self.set_status(step, Status.COMPLETED) return @@ -185,3 +201,37 @@ def get_pseudopotential(self, library: str, element: str): self.step_data['pseudo_family'] = library return pseudo_data + + def read(self, file: FilePointer, binary=False) -> str | bytes: + workchain = orm.load_node(self.step_data['steps'][file[0].uid]['workchain']) + filename = str(file[1]).replace(file[0].prefix, 'aiida') + if 'wannier90' in file[0].prefix: + content = workchain.outputs.wannier90.retrieved.get_object_content(filename, mode='r') + else: + content = workchain.outputs.retrieved.get_object_content(filename, mode='r') + # maybe unnecessary content post-processing + '''content = content.split("\n") + for line in range(len(content)): + content[line] += "\n" + ''' + return content + + def write(self, content: str | bytes, file: FilePointer) -> None: + if 'inputs.pkl' in str(file[1]): + return + if isinstance(file[0], Process): + filename = file[0].inputs.dst_file + else: + filename = str(file[1]).replace(file[0].prefix, 'aiida') + + if isinstance(content, bytes): + # skip the dumping of the *out.pkl file, we don't want as SinglefileData + return + singlefile = orm.SinglefileData.from_string(content, filename) + singlefile.store() + self.step_data['steps'][file[0].uid][str(filename)] = singlefile.pk + return singlefile + + def glob(self, pattern: FilePointer, recursive=False) -> Generator[FilePointer, None, None]: + raise NotImplementedError() + diff --git a/src/aiida_koopmans/utils.py b/src/aiida_koopmans/utils.py index 499b5e4..007af83 100644 --- a/src/aiida_koopmans/utils.py +++ b/src/aiida_koopmans/utils.py @@ -65,7 +65,7 @@ def get_PwBaseWorkChain_from_ase(pw_calculator, step_data=None): pw_overrides["CONTROL"][k] = calc_params[k] for k in pw_keys['system']: - if k in calc_params.keys() and k not in [ALL_BLOCKED_KEYWORDS, 'tot_magnetization']: + if k in calc_params.keys() and k not in [ALL_BLOCKED_KEYWORDS]: pw_overrides["SYSTEM"][k] = calc_params[k] for k in pw_keys['electrons']: