Skip to content

Commit

Permalink
Optional support for aiida-atomistic
Browse files Browse the repository at this point in the history
We do a try/except/else import.

Also, support in PwRelaxWorkChain.
  • Loading branch information
mikibonacci committed Dec 5, 2024

Verified

This commit was signed with the committer’s verified signature.
stgraber Stéphane Graber
1 parent 7a8254d commit 22d5ecf
Showing 5 changed files with 61 additions and 27 deletions.
2 changes: 1 addition & 1 deletion src/aiida_quantumespresso/calculations/__init__.py
Original file line number Diff line number Diff line change
@@ -128,7 +128,7 @@ def define(cls, spec):
spec.input('metadata.options.input_filename', valid_type=str, default=cls._DEFAULT_INPUT_FILE)
spec.input('metadata.options.output_filename', valid_type=str, default=cls._DEFAULT_OUTPUT_FILE)
spec.input('metadata.options.withmpi', valid_type=bool, default=True) # Override default withmpi=False
spec.input('structure', valid_type=(structures_classes),
spec.input('structure', valid_type=structures_classes,
help='The input structure.')
spec.input('parameters', valid_type=orm.Dict,
help='The input parameters that are to be used to construct the input file.')
10 changes: 8 additions & 2 deletions src/aiida_quantumespresso/calculations/pw.py
Original file line number Diff line number Diff line change
@@ -4,13 +4,19 @@
import warnings

from aiida import orm
from aiida.common import exceptions
from aiida.common.lang import classproperty
from aiida.orm import StructureData as LegacyStructureData
from aiida.plugins import factories

from aiida_quantumespresso.calculations import BasePwCpInputGenerator

StructureData = factories.DataFactory('atomistic.structure')
try:
StructureData = factories.DataFactory('atomistic.structure')
except exceptions.MissingEntryPointError:
structures_classes = (LegacyStructureData,)
else:
structures_classes = (LegacyStructureData, StructureData)


class PwCalculation(BasePwCpInputGenerator):
@@ -72,7 +78,7 @@ def define(cls, spec):
'will not fail if the XML file is missing in the retrieved folder.')
spec.input('kpoints', valid_type=orm.KpointsData,
help='kpoint mesh or kpoint path')
spec.input('hubbard_file', valid_type=(StructureData, LegacyStructureData), required=False,
spec.input('hubbard_file', valid_type=structures_classes, required=False,
help='SinglefileData node containing the output Hubbard parameters from a HpCalculation')
spec.inputs.validator = cls.validate_inputs

32 changes: 23 additions & 9 deletions src/aiida_quantumespresso/parsers/parse_raw/base.py
Original file line number Diff line number Diff line change
@@ -2,10 +2,15 @@
"""A basic parser for the common format of QE."""
import re

from aiida.plugins import DataFactory
from aiida.orm import StructureData as LegacyStructureData
from aiida.plugins import DataFactory

from aiida_atomistic import StructureDataMutable, StructureData
try:
StructureData = DataFactory('atomistic.structure')
except exceptions.MissingEntryPointError:
structures_classes = (LegacyStructureData,)
else:
structures_classes = (LegacyStructureData, StructureData)

__all__ = ('convert_qe_time_to_sec', 'convert_qe_to_aiida_structure', 'convert_qe_to_kpoints')

@@ -53,14 +58,23 @@ def convert_qe_to_aiida_structure(output_dict, input_structure=None):
# Without an input structure, try to recreate the structure from the output
if not input_structure:

structure = StructureDataMutable()
structure.set_cell=cell_dict['lattice_vectors']
if isinstance(input_structure, LegacyStructureData):
structure = LegacyStructureData()
structure.set_cell=cell_dict['lattice_vectors']

for kind_name, position in output_dict['atoms']:
symbol = re.sub(r'\d+', '', kind_name)
structure.append_atom(position=position, symbols=symbol, name=kind_name)

else:
structure = StructureDataMutable()
structure.set_cell=cell_dict['lattice_vectors']

for kind_name, position in output_dict['atoms']:
symbol = re.sub(r'\d+', '', kind_name)
structure.append_atom(positions=position, symbols=symbol, kinds=kind_name)

for kind_name, position in output_dict['atoms']:
symbol = re.sub(r'\d+', '', kind_name)
structure.add_atom(position=position, symbols=symbol, name=kind_name)

return StructureData.from_mutable(structure)
structure = StructureData.from_mutable(structure)

else:

30 changes: 18 additions & 12 deletions src/aiida_quantumespresso/utils/magnetic.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,31 @@
# -*- coding: utf-8 -*-
"""Utility class for handling the :class:`aiida_quantumespresso.data.hubbard_structure.HubbardStructureData`."""
# pylint: disable=no-name-in-module
from itertools import product
import os
from typing import Tuple, Union

from aiida import orm
from aiida.common.exceptions import MissingEntryPointError
from aiida.engine import calcfunction
from aiida.orm import StructureData as LegacyStructureData
from aiida.plugins import DataFactory
import numpy as np

StructureData = DataFactory('atomistic.structure')
try:
StructureData = DataFactory('atomistic.structure')
except MissingEntryPointError:
structures_classes = (LegacyStructureData,)
else:
structures_classes = (LegacyStructureData, StructureData)


class MagneticUtils:
class MagneticUtils: # pylint: disable=too-few-public-methods
"""Class to manage the magnetic structure of the atomistic `LegacyStructureData`.
It contains methods to manipulate the magnetic structure in such a way to produce the correct input for QuantumESPRESSO calculations.
It contains methods to manipulate the magne tic structure in such a way to produce
the correct input for QuantumESPRESSO calculations.
"""

def __init__(
self,
structure: StructureData,
structure: structures_classes,
):
"""Set a the `StructureData` to manipulate."""
if isinstance(structure, StructureData):
@@ -33,6 +37,7 @@ def __init__(

def generate_magnetic_namelist(self, parameters):
"""Generate the magnetic namelist for Quantum ESPRESSO.
:param parameters: dictionary of inputs for the Quantum ESPRESSO calculation.
"""
if 'nspin' not in parameters['SYSTEM'] and 'noncolin' not in parameters['SYSTEM']:
@@ -54,15 +59,16 @@ def generate_magnetic_namelist(self, parameters):
)
elif parameters['SYSTEM']['noncolin']:
for site in self.structure.sites:
for variable in namelist.keys():
namelist[variable][site.kinds] = site.get_magmom_coord(coord='spherical')[variable]
for variable, value in namelist.items():
value[site.kinds] = site.get_magmom_coord(coord='spherical')[variable]

return namelist


@calcfunction
def generate_structure_with_magmoms(input_structure=StructureData, input_magnetic_moments=orm.List):
def generate_structure_with_magmoms(input_structure: structures_classes, input_magnetic_moments: orm.List):
"""Generate a new structure with the magnetic moments for each site.
:param input_structure: the input structure to add the magnetic moments.
:param input_magnetic_moments: the magnetic moments for each site, represented as a float (see below).
@@ -87,4 +93,4 @@ def generate_structure_with_magmoms(input_structure=StructureData, input_magneti

output_structure = StructureData.from_mutable(mutable_structure, detect_kinds=True)

return output_structure
return output_structure
14 changes: 11 additions & 3 deletions src/aiida_quantumespresso/workflows/pw/relax.py
Original file line number Diff line number Diff line change
@@ -4,7 +4,8 @@
from aiida.common import AttributeDict, exceptions
from aiida.common.lang import type_check
from aiida.engine import ToContext, WorkChain, append_, if_, while_
from aiida.plugins import CalculationFactory, WorkflowFactory
from aiida.orm import StructureData as LegacyStructureData
from aiida.plugins import CalculationFactory, DataFactory, WorkflowFactory

from aiida_quantumespresso.common.types import RelaxType
from aiida_quantumespresso.utils.mapping import prepare_process_inputs
@@ -14,6 +15,13 @@
PwCalculation = CalculationFactory('quantumespresso.pw')
PwBaseWorkChain = WorkflowFactory('quantumespresso.pw.base')

try:
StructureData = DataFactory('atomistic.structure')
except exceptions.MissingEntryPointError:
structures_classes = (LegacyStructureData,)
else:
structures_classes = (LegacyStructureData, StructureData)


def validate_inputs(inputs, _):
"""Validate the top level namespace."""
@@ -38,7 +46,7 @@ def define(cls, spec):
exclude=('clean_workdir', 'pw.structure', 'pw.parent_folder'),
namespace_options={'required': False, 'populate_defaults': False,
'help': 'Inputs for the `PwBaseWorkChain` for the final scf.'})
spec.input('structure', valid_type=orm.StructureData, help='The inputs structure.')
spec.input('structure', valid_type=structures_classes, help='The inputs structure.')
spec.input('meta_convergence', valid_type=orm.Bool, default=lambda: orm.Bool(True),
help='If `True` the workchain will perform a meta-convergence on the cell volume.')
spec.input('max_meta_convergence_iterations', valid_type=orm.Int, default=lambda: orm.Int(5),
@@ -65,7 +73,7 @@ def define(cls, spec):
spec.exit_code(402, 'ERROR_SUB_PROCESS_FAILED_FINAL_SCF',
message='the final scf PwBaseWorkChain sub process failed')
spec.expose_outputs(PwBaseWorkChain, exclude=('output_structure',))
spec.output('output_structure', valid_type=orm.StructureData, required=False,
spec.output('output_structure', valid_type=structures_classes, required=False,
help='The successfully relaxed structure.')
# yapf: enable

0 comments on commit 22d5ecf

Please sign in to comment.