Skip to content

Commit

Permalink
CLI: Simplify cli.utils.validate_hubbard_parameters and add tests (#…
Browse files Browse the repository at this point in the history
…909)

The `validate_hubbard_parameters` method defined `hubbard_file_pk` as
an argument which can take a `SinglefileData` with Hubbard parameters or
a pk. However, the only methods calling it, are CLI commands that pass
in a value provided by the `options.HUBBARD_FILE` reusable option which
already loads the node corresponding to the identifier specified in the
command line arguments. It also ensures that it is a `SinglefileData`.

Therefore, the signature can be simplified to just accept a node and do
away with the manual loading and type checking. This also allows to
change the return type to `None` as the node is no longer loaded by the
validation method but is passed in as an argument.

A test is added that runs the `calculation launch pw` command with the
`--hubbard-u` or `--hubbard-file` options, both with valid and invalid
options to test the validation function.

Co-authored-by: Sebastiaan Huber <[email protected]>
  • Loading branch information
bastonero and sphuber authored Apr 13, 2023
1 parent 89d39a4 commit 74d25d1
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 40 deletions.
6 changes: 2 additions & 4 deletions src/aiida_quantumespresso/cli/calculations/pw.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
)
@decorators.with_dbenv()
def launch_calculation(
code, structure, pseudo_family, kpoints_mesh, ecutwfc, ecutrho, hubbard_u, hubbard_v, hubbard_file_pk,
code, structure, pseudo_family, kpoints_mesh, ecutwfc, ecutrho, hubbard_u, hubbard_v, hubbard_file,
starting_magnetization, smearing, max_num_machines, max_wallclock_seconds, with_mpi, daemon, parent_folder, dry_run,
mode, unfolded_kpoints
):
Expand All @@ -73,9 +73,7 @@ def launch_calculation(
raise click.BadParameter(f"calculation '{mode}' requires a parent folder", param_hint='--parent-folder')

try:
hubbard_file = validate.validate_hubbard_parameters(
structure, parameters, hubbard_u, hubbard_v, hubbard_file_pk
)
validate.validate_hubbard_parameters(structure, parameters, hubbard_u, hubbard_v, hubbard_file)
except ValueError as exception:
raise click.BadParameter(str(exception))

Expand Down
2 changes: 1 addition & 1 deletion src/aiida_quantumespresso/cli/utils/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def convert(self, value, param, ctx):
HUBBARD_FILE = OverridableOption(
'-H',
'--hubbard-file',
'hubbard_file_pk',
'hubbard_file',
type=types.DataParamType(sub_classes=('aiida.data:core.singlefile',)),
help='SinglefileData containing Hubbard parameters from a HpCalculation to use as input for Hubbard V.'
)
Expand Down
28 changes: 5 additions & 23 deletions src/aiida_quantumespresso/cli/utils/validate.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# -*- coding: utf-8 -*-
"""Utility functions for validation of command line interface parameter inputs."""
from aiida.cmdline.utils import decorators
from aiida.common import exceptions
import click


Expand Down Expand Up @@ -36,35 +35,20 @@ def validate_kpoints_mesh(ctx, param, value):


@decorators.with_dbenv()
def validate_hubbard_parameters(structure, parameters, hubbard_u=None, hubbard_v=None, hubbard_file_pk=None):
def validate_hubbard_parameters(structure, parameters, hubbard_u=None, hubbard_v=None, hubbard_file=None):
"""Validate Hubbard input parameters and update the parameters input node accordingly.
If a valid hubbard_file_pk is provided, the node will be loaded and returned.
:param structure: the StructureData node that will be used in the inputs
:param parameters: the Dict node that will be used in the inputs
:param hubbard_u: the Hubbard U inputs values from the cli
:param hubbard_v: the Hubbard V inputs values from the cli
:param hubbard_file_pk: a pk referencing a SinglefileData with Hubbard parameters
:returns: the loaded SinglefileData node with Hubbard parameters if valid pk was defined, None otherwise
:param hubbard_file: a SinglefileData with Hubbard parameters
:raises ValueError: if the input is invalid
"""
from aiida.orm import SinglefileData, load_node

if len([value for value in [hubbard_u, hubbard_v, hubbard_file_pk] if value]) > 1:
raise ValueError('the hubbard_u, hubbard_v and hubbard_file_pk options are mutually exclusive')

hubbard_file = None
if len([value for value in [hubbard_u, hubbard_v, hubbard_file] if value]) > 1:
raise ValueError('the hubbard_u, hubbard_v and hubbard_file options are mutually exclusive')

if hubbard_file_pk:

try:
hubbard_file = load_node(pk=hubbard_file_pk)
except exceptions.NotExistent as exc:
raise ValueError(f'{hubbard_file_pk} is not a valid pk') from exc
else:
if not isinstance(hubbard_file, SinglefileData):
raise ValueError(f'Node<{hubbard_file_pk}> is not a SinglefileData but {type(hubbard_file)}')
if hubbard_file:

parameters['SYSTEM']['lda_plus_u'] = True
parameters['SYSTEM']['lda_plus_u_kind'] = 2
Expand Down Expand Up @@ -95,8 +79,6 @@ def validate_hubbard_parameters(structure, parameters, hubbard_u=None, hubbard_v
for kind, value in hubbard_u:
parameters['SYSTEM']['hubbard_u'][kind] = value

return hubbard_file


def validate_starting_magnetization(structure, parameters, starting_magnetization=None):
"""Validate starting magnetization parameters and update the parameters input node accordingly.
Expand Down
6 changes: 2 additions & 4 deletions src/aiida_quantumespresso/cli/workflows/pw/bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
@options.DAEMON()
@decorators.with_dbenv()
def launch_workflow(
code, structure, pseudo_family, kpoints_distance, ecutwfc, ecutrho, hubbard_u, hubbard_v, hubbard_file_pk,
code, structure, pseudo_family, kpoints_distance, ecutwfc, ecutrho, hubbard_u, hubbard_v, hubbard_file,
starting_magnetization, smearing, clean_workdir, max_num_machines, max_wallclock_seconds, with_mpi, daemon
):
"""Run a `PwBandsWorkChain`."""
Expand All @@ -53,9 +53,7 @@ def launch_workflow(
}

try:
hubbard_file = validate.validate_hubbard_parameters(
structure, parameters, hubbard_u, hubbard_v, hubbard_file_pk
)
validate.validate_hubbard_parameters(structure, parameters, hubbard_u, hubbard_v, hubbard_file)
except ValueError as exception:
raise click.BadParameter(str(exception))

Expand Down
6 changes: 2 additions & 4 deletions src/aiida_quantumespresso/cli/workflows/pw/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
@options.DAEMON()
@decorators.with_dbenv()
def launch_workflow(
code, structure, pseudo_family, kpoints_distance, ecutwfc, ecutrho, hubbard_u, hubbard_v, hubbard_file_pk,
code, structure, pseudo_family, kpoints_distance, ecutwfc, ecutrho, hubbard_u, hubbard_v, hubbard_file,
starting_magnetization, smearing, clean_workdir, max_num_machines, max_wallclock_seconds, with_mpi, daemon
):
"""Run a `PwBaseWorkChain`."""
Expand All @@ -49,9 +49,7 @@ def launch_workflow(
}

try:
hubbard_file = validate.validate_hubbard_parameters(
structure, parameters, hubbard_u, hubbard_v, hubbard_file_pk
)
validate.validate_hubbard_parameters(structure, parameters, hubbard_u, hubbard_v, hubbard_file)
except ValueError as exception:
raise click.BadParameter(str(exception))

Expand Down
6 changes: 2 additions & 4 deletions src/aiida_quantumespresso/cli/workflows/pw/relax.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
)
@decorators.with_dbenv()
def launch_workflow(
code, structure, pseudo_family, kpoints_distance, ecutwfc, ecutrho, hubbard_u, hubbard_v, hubbard_file_pk,
code, structure, pseudo_family, kpoints_distance, ecutwfc, ecutrho, hubbard_u, hubbard_v, hubbard_file,
starting_magnetization, smearing, clean_workdir, max_num_machines, max_wallclock_seconds, with_mpi, daemon,
final_scf
):
Expand All @@ -61,9 +61,7 @@ def launch_workflow(
}

try:
hubbard_file = validate.validate_hubbard_parameters(
structure, parameters, hubbard_u, hubbard_v, hubbard_file_pk
)
validate.validate_hubbard_parameters(structure, parameters, hubbard_u, hubbard_v, hubbard_file)
except ValueError as exception:
raise click.BadParameter(str(exception))

Expand Down
45 changes: 45 additions & 0 deletions tests/cli/calculations/test_pw.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# -*- coding: utf-8 -*-
"""Tests for the ``calculation launch pw`` command."""
import re

import pytest

from aiida_quantumespresso.cli.calculations.pw import launch_calculation


Expand All @@ -8,3 +12,44 @@ def test_command_base(run_cli_process_launch_command, fixture_code, sssp):
code = fixture_code('quantumespresso.pw').store()
options = ['-X', code.full_label, '-F', sssp.label]
run_cli_process_launch_command(launch_calculation, options=options)


# yapf: disable
@pytest.mark.parametrize(('cmd_options', 'match'), (
(
['--hubbard-u', 'Mg'],
".*Option '--hubbard-u' requires 2 arguments.*"),
(
['--hubbard-u', 'Mg', '5.0'],
'.*kinds in the specified Hubbard U is not a strict subset of the structure kinds.*'),
(
['--hubbard-file', '1000000'],
'.*no SinglefileData found with ID*'
),
))
# yapf: enable
def test_invalid_hubbard_parameters(run_cli_process_launch_command, fixture_code, sssp, cmd_options, match):
"""Test invoking the calculation launch command with invalid Hubbard inputs."""
code = fixture_code('quantumespresso.pw').store()
options = ['-X', code.full_label, '-F', sssp.label] + cmd_options
result = run_cli_process_launch_command(launch_calculation, options=options, raises=ValueError)
assert re.match(match, ' '.join(result.output_lines))


@pytest.mark.usefixtures('aiida_profile')
def test_valid_hubbard_parameters(run_cli_process_launch_command, fixture_code, sssp):
"""Test invoking the calculation launch command with valid Hubbard inputs."""
import io

from aiida.orm import SinglefileData

code = fixture_code('quantumespresso.pw').store()

options = ['-X', code.full_label, '-F', sssp.label, '--hubbard-u', 'Si', '5.0']
run_cli_process_launch_command(launch_calculation, options=options)

content_original = 'for sure some correct Hubbard parameters'
filepk = SinglefileData(io.StringIO(content_original)).store().pk

options = ['-X', code.full_label, '-F', sssp.label, '--hubbard-file', filepk]
run_cli_process_launch_command(launch_calculation, options=options)

0 comments on commit 74d25d1

Please sign in to comment.