diff --git a/src/aiida_koopmans/engine/aiida.py b/src/aiida_koopmans/engine/aiida.py index c2b7d70..d58988f 100644 --- a/src/aiida_koopmans/engine/aiida.py +++ b/src/aiida_koopmans/engine/aiida.py @@ -18,11 +18,11 @@ load_profile() class AiiDAEngine(Engine): - + """ Step data is a dictionary containing the following information: step_data = {calc.directory: {'workchain': workchain, 'remote_folder': remote_folder}} - and any other info we need for AiiDA. + and any other info we need for AiiDA. """ def __init__(self, *args, **kwargs): self.blocking = kwargs.pop('blocking', True) @@ -30,28 +30,26 @@ def __init__(self, *args, **kwargs): 'configuration': kwargs.pop('configuration', None), 'steps': {} } - self.skip_message = False - # here we add the logic to populate configuration by default # 1. we look for codes stored in AiiDA at localhost, e.g. pw-version@localhost, - # 2. we look for codes in the PATH, + # 2. we look for codes in the PATH, # 3. if we don't find the code in AiiDA db but in the PATH, we store it in AiiDA db. # 4. if we don't find the code in AiiDA db and in the PATH and not configuration is provided, we raise an error. if self.step_data['configuration'] is None: raise NotImplementedError("Configuration not provided") - + # 5. if no resource info in configuration, we try to look at PARA_PREFIX env var. - + super().__init__(*args, **kwargs) - + def run(self, step: Step): - + self.get_status(step) if step.prefix in ['wannier90_preproc', 'pw2wannier90']: self.set_status(step, Status.COMPLETED) return - + self.step_data['steps'][step.uid] = {} # maybe not needed builder, self.step_data = get_builder_from_ase(calculator=step, step_data=self.step_data) # ASE to AiiDA conversion. put some error message if the conversion fails running = submit(builder) @@ -60,25 +58,25 @@ def run(self, step: Step): # The below will be passed to the context, so we will need to store also the instance of the submitted workchain, if in KoopmansWorkChain. self.step_data['steps'][step.uid] = {'workchain': running.pk, } #'remote_folder': running.outputs.remote_folder} - + self.set_status(step, Status.RUNNING) - - return - + + return + def load_step_data(self): try: with open('step_data.pkl', 'rb') as f: - # this will overwrite the step_data[configuration], - # i.e. if we change codes or res we will not see it if + # this will overwrite the step_data[configuration], + # i.e. if we change codes or res we will not see it if # the file already exists. - self.step_data = pickle.load(f) + self.step_data = pickle.load(f) except FileNotFoundError: pass - + def dump_step_data(self): with open('step_data.pkl', 'wb') as f: pickle.dump(self.step_data, f) - + def get_status(self, step: Step) -> Status: status = self.get_status_by_uid(step.uid) #print(f"Getting status for step {step.uid}: {status}") @@ -90,36 +88,35 @@ def get_status_by_uid(self, uid: str) -> Status: if uid not in self.step_data['steps']: self.step_data['steps'][uid] = {'status': Status.NOT_STARTED} return self.step_data['steps'][uid]['status'] - + def set_status(self, step: Step, status: Status): self.set_status_by_uid(step.uid, status) - #print(f"Step {step.uid} is {status}") def set_status_by_uid(self, uid: str, status: Status): self.step_data['steps'][uid]['status'] = status self.dump_step_data() - + def update_statuses(self) -> None: time.sleep(1) for uid in self.step_data['steps']: - + if not self.get_status_by_uid(uid) == Status.RUNNING: continue - + workchain = orm.load_node(self.step_data['steps'][uid]['workchain']) if workchain.is_finished_ok: self._step_completed_message_by_uid(uid) self.set_status_by_uid(uid, Status.COMPLETED) - + elif workchain.is_finished or workchain.is_excepted or workchain.is_killed: self._step_failed_message_by_uid(uid) self.set_status_by_uid(uid, Status.FAILED) - + return - + def load_results(self, step: Step) -> None: - + self.load_step_data() if step.prefix in ['wannier90_preproc', 'pw2wannier90']: @@ -141,10 +138,27 @@ def load_results(self, step: Step) -> None: step.calc = output.calc step.results = output.calc.results if step.ext_out == ".pwo": step.generate_band_structure() #nelec=int(workchain.outputs.output_parameters.get_dict()['number_of_electrons'])) - + + self._step_completed_message(step) + + if step.ext_out in [".pro"]: + + pdos_dir = dump_pdos_outputs(step, workchain.outputs.retrieved) + prev_dir = step.directory + step.directory = pdos_dir + + try: + step.generate_dos() + except ValueError: + # ValueError: Must provide energies to create a GridDOSCollection without any DOS data. + pass + finally: + from aiida_koopmans.utils import delete_directory + delete_directory(pdos_dir.parent) + step.directory = prev_dir + self.dump_step_data() - - + def load_old_calculator(self, calc: Calc): raise NotImplementedError # load_old_calculator(calc) @@ -160,6 +174,7 @@ def get_pseudopotential(self, library: str, element: str): temp_file = pathlib.Path(dirpath) / (pseudo[0].base.attributes.all['element'] + '.upf') with pseudo[0].open(pseudo[0].base.attributes.all['element'] + '.upf', 'rb') as handle: temp_file.write_bytes(handle.read()) + pseudo_data = read_pseudo_file(temp_file) if not pseudo_data: diff --git a/src/aiida_koopmans/utils.py b/src/aiida_koopmans/utils.py index df1699e..523d515 100644 --- a/src/aiida_koopmans/utils.py +++ b/src/aiida_koopmans/utils.py @@ -1,4 +1,3 @@ - import shutil import pathlib import tempfile @@ -9,6 +8,7 @@ from aiida.common.exceptions import NotExistent from aiida.orm import Code, Computer from aiida_quantumespresso.calculations.pw import PwCalculation +from aiida_quantumespresso.calculations.projwfc import ProjwfcCalculation from aiida_wannier90.calculations.wannier90 import Wannier90Calculation from ase import Atoms @@ -22,8 +22,9 @@ LOCALHOST_NAME = "localhost-test" KCW_BLOCKED_KEYWORDS = [t[1] for t in KcwCalculation._blocked_keywords] PW_BLOCKED_KEYWORDS = [t[1] for t in PwCalculation._blocked_keywords] +PROJWFC_BLOCKED_KEYWORDS = [t[1] for t in ProjwfcCalculation._blocked_keywords] WANNIER90_BLOCKED_KEYWORDS = [t[1] for t in Wannier90Calculation._BLOCKED_PARAMETER_KEYS] -ALL_BLOCKED_KEYWORDS = KCW_BLOCKED_KEYWORDS + PW_BLOCKED_KEYWORDS + WANNIER90_BLOCKED_KEYWORDS + [f'celldm({i})' for i in range (1,7)] +ALL_BLOCKED_KEYWORDS = KCW_BLOCKED_KEYWORDS + PW_BLOCKED_KEYWORDS + WANNIER90_BLOCKED_KEYWORDS + PROJWFC_BLOCKED_KEYWORDS + [f'celldm({i})' for i in range (1,7)] def get_builder_from_ase(calculator, step_data=None): return mapping_calculators[calculator.ext_out](calculator, step_data) @@ -38,7 +39,7 @@ def get_PwBaseWorkChain_from_ase(pw_calculator, step_data=None): aiida_inputs = step_data['configuration'] calc_params = pw_calculator._parameters - + structure = None parent_folder = None for step, val in step_data['steps'].items(): @@ -83,7 +84,7 @@ def get_PwBaseWorkChain_from_ase(pw_calculator, step_data=None): builder.pw.metadata = aiida_inputs["metadata"] builder.kpoints = orm.KpointsData() - + if pw_overrides["CONTROL"]["calculation"] in ["scf", "nscf"]: builder.kpoints.set_kpoints_mesh(calc_params["kpts"]) elif pw_overrides["CONTROL"]["calculation"] == "bands": @@ -123,8 +124,8 @@ def get_Wannier90BandsWorkChain_builder_from_ase(w90_calculator, step_data=None) nscf = orm.load_node(val["workchain"]) if not nscf: raise ValueError("No nscf step found.") - - + + aiida_inputs = step_data['configuration'] codes = { @@ -160,11 +161,11 @@ def get_Wannier90BandsWorkChain_builder_from_ase(w90_calculator, step_data=None) t = np.where(k_linear==coords)[0] k_labels.append([t[0],label]) k_coords.append(special_k[label].tolist()) - + kpoints_path = orm.KpointsData() kpoints_path.set_kpoints(k_path,labels=k_labels,cartesian=False) builder.kpoint_path = kpoints_path - + # Start parameters and projections setting using the Wannier90Calculator data. params = builder.wannier90.wannier90.parameters.get_dict() @@ -238,10 +239,61 @@ def get_Wannier90BandsWorkChain_builder_from_ase(w90_calculator, step_data=None) return builder, step_data + +def get_projwfc_builder_from_ase(projwfc_calculator, step_data=None): + from aiida import load_profile, orm + from aiida_quantumespresso.calculations.projwfc import ProjwfcCalculation + + load_profile() + + """ + Convert a `ProjwfcCalculator` into an AiiDA `ProjwfcCalculation + """ + + aiida_inputs = step_data["configuration"] + calc_params = projwfc_calculator._parameters + + # TODO: This is not needed, if we can just pass `orm.Dict(calc_params)` to the builder + from koopmans.settings import ProjwfcSettingsDict + + projwfc_parameters = {} + projwfcsettingsdict = ProjwfcSettingsDict() + projwfc_keys = ( + projwfcsettingsdict.valid + + list(projwfcsettingsdict.defaults.keys()) + + projwfcsettingsdict.are_paths + ) + for k in projwfc_keys: + if k in calc_params.keys() and k not in ALL_BLOCKED_KEYWORDS: + projwfc_parameters[k] = calc_params[k] + + projwfc_parameters['filpdos'] = 'aiida' + + builder = ProjwfcCalculation.get_builder() + builder.code = orm.load_code(aiida_inputs["projwfc_code"]) + builder.parameters = orm.Dict({"PROJWFC": projwfc_parameters}) + builder.metadata = aiida_inputs["metadata"] + + parent_calculators = [ + f[0].uid for f in projwfc_calculator.linked_files.values() if f[0] is not None + ] + + if len(set(parent_calculators)) > 1: + raise ValueError("More than one parent calculator found.") + elif len(set(parent_calculators)) == 1: + if "remote_folder" in step_data["steps"][parent_calculators[0]]: + builder.parent_folder = orm.load_node( + step_data["steps"][parent_calculators[0]]["remote_folder"] + ) + + return builder + + ## Here we have the mapping for the calculators initialization. used in the `aiida_calculate_trigger`. mapping_calculators = { ".pwo" : get_PwBaseWorkChain_from_ase, ".wout": get_Wannier90BandsWorkChain_builder_from_ase, + ".pro": get_projwfc_builder_from_ase, #".w2ko": from_wann2kc_to_KcwCalculation, #".kso": from_kcwscreen_to_KcwCalculation, #".kho": from_kcwham_to_KcwCalculation, @@ -250,13 +302,13 @@ def get_Wannier90BandsWorkChain_builder_from_ase(w90_calculator, step_data=None) # read the output file, mimicking the read_results method of ase-koopmans: https://github.com/elinscott/ase_koopmans/blob/master/ase/calculators/espresso/_espresso.py def read_output_file(calculator, retrieved, inner_remote_folder=None): """ - Read the output file of a calculator using ASE io.read() method but parsing the AiiDA outputs. + Read the output file of a calculator using ASE io.read() method but parsing the AiiDA outputs. NB: calculator (ASE) should contain the related AiiDA workchain as attribute. """ - #if inner_remote_folder: + # if inner_remote_folder: # retrieved = inner_remote_folder - #else: - #retrieved = workchain.outputs.retrieved + # else: + # retrieved = workchain.outputs.retrieved with tempfile.TemporaryDirectory() as dirpath: # Open the output file from the AiiDA storage and copy content to the temporary file for filename in retrieved.base.repository.list_object_names(): @@ -267,4 +319,34 @@ def read_output_file(calculator, retrieved, inner_remote_folder=None): with retrieved.open(filename, 'rb') as handle: temp_file.write_bytes(handle.read()) output = io.read(temp_file) - return output \ No newline at end of file + return output + + +def dump_pdos_outputs(calculator, retrieved): + """ + Dump the `pdos` output files of a projwfc.x calculation run via AiiDA to a temporary directory which is returned. + """ + + output_dir = calculator.directory / pathlib.Path(tempfile.mkdtemp()).parts[-1] + output_dir.mkdir(exist_ok=True, parents=True) + + for filename in retrieved.base.repository.list_object_names(): + if ".pdos" in filename: + # Create the file with the desired name + output_file = pathlib.Path(output_dir) / ( + f"{calculator.parameters.filpdos}." + filename.replace("aiida.", "") + ) + with retrieved.open(filename, "rb") as handle: + output_file.write_bytes(handle.read()) + + return output_dir + + +def delete_directory(dir_path): + dir_path = pathlib.Path(dir_path) + for child in dir_path.iterdir(): + if child.is_dir(): + delete_directory(child) + else: + child.unlink() + dir_path.rmdir()