diff --git a/src/aiida_koopmans/helpers.py b/src/aiida_koopmans/helpers.py index a0c00cd..c2154f6 100644 --- a/src/aiida_koopmans/helpers.py +++ b/src/aiida_koopmans/helpers.py @@ -24,6 +24,10 @@ from aiida_koopmans.calculations.kcw import KcwCalculation +""" +ASE calculator MUST have `wchain` attribute (the related AiiDA WorkChain) to be able to use these functions! +""" + LOCALHOST_NAME = "localhost-test" KCW_BLOCKED_KEYWORDS = [t[1] for t in KcwCalculation._blocked_keywords] PW_BLOCKED_KEYWORDS = [t[1] for t in PwCalculation._blocked_keywords] @@ -571,7 +575,7 @@ def aiida_pre_calculate_trigger(_pre_calculate): @functools.wraps(_pre_calculate) def wrapper_aiida_trigger(self): if self.parameters.mode == "ase": - return self._pre_calculate() + return _pre_calculate(self,) else: pass return wrapper_aiida_trigger @@ -581,7 +585,7 @@ def aiida_calculate_trigger(_calculate): @functools.wraps(_calculate) def wrapper_aiida_trigger(self): if self.parameters.mode == "ase": - return self._calculate() + return _calculate(self,) else: builder = mapping_calculators[self.ext_out](self) from aiida.engine import run_get_node, submit @@ -595,7 +599,7 @@ def aiida_post_calculate_trigger(_post_calculate): @functools.wraps(_post_calculate) def wrapper_aiida_trigger(self): if self.parameters.mode == "ase": - return self._post_calculate() + return _post_calculate(self,) else: pass return wrapper_aiida_trigger @@ -606,16 +610,18 @@ def aiida_read_results_trigger(read_results): @functools.wraps(read_results) def wrapper_aiida_trigger(self): if self.parameters.mode == "ase": - return self.read_results() + return read_results(self,) else: if self.ext_out == ".wout": output = read_output_file(self, self.wchain.outputs.wannier90.retrieved) elif self.ext_out == ".pwo": output = read_output_file(self) - self.calc = output.calc - self.results = output.calc.results if hasattr(output.calc, 'kpts'): self.kpts = output.calc.kpts + + self.calc = output.calc + self.results = output.calc.results + return wrapper_aiida_trigger def aiida_link_trigger(link): @@ -623,7 +629,7 @@ def aiida_link_trigger(link): @functools.wraps(link) def wrapper_aiida_trigger(self,src_calc, src_path, dest_calc, dest_path): if self.parameters.mode == "ase": - return self.link(src_calc, src_path, dest_calc, dest_path) + return link(self, src_calc, src_path, dest_calc, dest_path) elif src_calc: # if pseudo linking, src_calc = None dest_calc.parent_folder = src_calc.wchain.outputs.remote_folder return wrapper_aiida_trigger \ No newline at end of file