Skip to content

Commit

Permalink
First partially working implementation of read, write.
Browse files Browse the repository at this point in the history
  • Loading branch information
mikibonacci committed Dec 5, 2024
1 parent 3f30402 commit db74784
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 2 deletions.
52 changes: 51 additions & 1 deletion src/aiida_koopmans/engine/aiida.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
from koopmans.calculators import Calc
from koopmans.pseudopotentials import read_pseudo_file
from koopmans.status import Status
from koopmans.files import FilePointer
from koopmans.processes import Process

from typing import Generator, List

from aiida.engine import run_get_node, submit

Expand Down Expand Up @@ -46,11 +50,19 @@ def __init__(self, *args, **kwargs):
def run(self, step: Step):

self.get_status(step)

if isinstance(step, Process):
step.run()
self.set_status(step, Status.COMPLETED)
self._step_completed_message(step)
return

if step.prefix in ['wannier90_preproc', 'pw2wannier90']:
self.set_status(step, Status.COMPLETED)
return

self.step_data['steps'][step.uid] = {} # maybe not needed

builder, self.step_data = get_builder_from_ase(calculator=step, step_data=self.step_data) # ASE to AiiDA conversion. put some error message if the conversion fails
running = submit(builder)
print(f"Running workchain {running.pk} for step {step.uid}")
Expand Down Expand Up @@ -98,7 +110,7 @@ def set_status_by_uid(self, uid: str, status: Status):

def update_statuses(self) -> None:

time.sleep(1)
#time.sleep(1)
for uid in self.step_data['steps']:

if not self.get_status_by_uid(uid) == Status.RUNNING:
Expand All @@ -119,6 +131,10 @@ def load_results(self, step: Step) -> None:

self.load_step_data()

if isinstance(step, Process):
print(self.step_data['steps'][step.uid])
return

if step.prefix in ['wannier90_preproc', 'pw2wannier90']:
self.set_status(step, Status.COMPLETED)
return
Expand Down Expand Up @@ -185,3 +201,37 @@ def get_pseudopotential(self, library: str, element: str):
self.step_data['pseudo_family'] = library

return pseudo_data

def read(self, file: FilePointer, binary=False) -> str | bytes:
workchain = orm.load_node(self.step_data['steps'][file[0].uid]['workchain'])
filename = str(file[1]).replace(file[0].prefix, 'aiida')
if 'wannier90' in file[0].prefix:
content = workchain.outputs.wannier90.retrieved.get_object_content(filename, mode='r')
else:
content = workchain.outputs.retrieved.get_object_content(filename, mode='r')
# maybe unnecessary content post-processing
'''content = content.split("\n")
for line in range(len(content)):
content[line] += "\n"
'''
return content

def write(self, content: str | bytes, file: FilePointer) -> None:
if 'inputs.pkl' in str(file[1]):
return
if isinstance(file[0], Process):
filename = file[0].inputs.dst_file
else:
filename = str(file[1]).replace(file[0].prefix, 'aiida')

if isinstance(content, bytes):
# skip the dumping of the *out.pkl file, we don't want as SinglefileData
return
singlefile = orm.SinglefileData.from_string(content, filename)
singlefile.store()
self.step_data['steps'][file[0].uid][str(filename)] = singlefile.pk
return singlefile

def glob(self, pattern: FilePointer, recursive=False) -> Generator[FilePointer, None, None]:
raise NotImplementedError()

2 changes: 1 addition & 1 deletion src/aiida_koopmans/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def get_PwBaseWorkChain_from_ase(pw_calculator, step_data=None):
pw_overrides["CONTROL"][k] = calc_params[k]

for k in pw_keys['system']:
if k in calc_params.keys() and k not in [ALL_BLOCKED_KEYWORDS, 'tot_magnetization']:
if k in calc_params.keys() and k not in [ALL_BLOCKED_KEYWORDS]:
pw_overrides["SYSTEM"][k] = calc_params[k]

for k in pw_keys['electrons']:
Expand Down

0 comments on commit db74784

Please sign in to comment.