Skip to content

Commit

Permalink
introducing support in bands.py
Browse files Browse the repository at this point in the history
  • Loading branch information
mikibonacci committed Dec 5, 2024
1 parent 22d5ecf commit 1dfd4ed
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions src/aiida_quantumespresso/workflows/pw/bands.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# -*- coding: utf-8 -*-
"""Workchain to compute a band structure for a given structure using Quantum ESPRESSO pw.x."""
from aiida import orm
from aiida.common import AttributeDict
from aiida.common import AttributeDict, exceptions
from aiida.engine import ToContext, WorkChain, if_
from aiida.orm import StructureData as LegacyStructureData
from aiida.plugins import DataFactory

from aiida_quantumespresso.calculations.functions.seekpath_structure_analysis import seekpath_structure_analysis
from aiida_quantumespresso.utils.mapping import prepare_process_inputs
Expand All @@ -11,6 +13,13 @@

from ..protocols.utils import ProtocolMixin

try:
StructureData = DataFactory('atomistic.structure')
except exceptions.MissingEntryPointError:
structures_classes = (LegacyStructureData,)
else:
structures_classes = (LegacyStructureData, StructureData)


def validate_inputs(inputs, ctx=None): # pylint: disable=unused-argument
"""Validate the inputs of the entire input namespace."""
Expand Down Expand Up @@ -61,7 +70,7 @@ def define(cls, spec):
spec.expose_inputs(PwBaseWorkChain, namespace='bands',
exclude=('clean_workdir', 'pw.structure', 'pw.kpoints', 'pw.kpoints_distance', 'pw.parent_folder'),
namespace_options={'help': 'Inputs for the `PwBaseWorkChain` for the BANDS calculation.'})
spec.input('structure', valid_type=orm.StructureData, help='The inputs structure.')
spec.input('structure', valid_type=structures_classes, help='The inputs structure.')
spec.input('clean_workdir', valid_type=orm.Bool, default=lambda: orm.Bool(False),
help='If `True`, work directories of all called calculation will be cleaned at the end of execution.')
spec.input('nbands_factor', valid_type=orm.Float, required=False,
Expand Down Expand Up @@ -97,7 +106,7 @@ def define(cls, spec):
message='The scf PwBasexWorkChain sub process failed')
spec.exit_code(403, 'ERROR_SUB_PROCESS_FAILED_BANDS',
message='The bands PwBasexWorkChain sub process failed')
spec.output('primitive_structure', valid_type=orm.StructureData,
spec.output('primitive_structure', valid_type=structures_classes,
required=False,
help='The normalized and primitivized structure for which the bands are computed.')
spec.output('seekpath_parameters', valid_type=orm.Dict,
Expand Down

0 comments on commit 1dfd4ed

Please sign in to comment.