Skip to content

Commit

Permalink
Fixing some type errors in the input generation.
Browse files Browse the repository at this point in the history
  • Loading branch information
Jonathan Chico committed Apr 24, 2024
1 parent affe42f commit 7fd875c
Showing 1 changed file with 52 additions and 44 deletions.
96 changes: 52 additions & 44 deletions src/aiida_lammps/parsers/inputfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@
fixes block is never called, on the other hand the control block is always
called since it is necessary for the functioning of LAMMPS.
"""

from builtins import ValueError
import json
import os
import re
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union, Tuple

from aiida import orm
import numpy as np
Expand Down Expand Up @@ -274,7 +275,7 @@ def write_structure_block(
parameters_structure: Dict[str, Any],
structure: orm.StructureData,
structure_filename: str,
) -> Union[str, list]:
) -> Tuple[str, List[str]]:
"""
Generate the input block with the structure options.
Expand All @@ -285,19 +286,19 @@ def write_structure_block(
:param parameters_structure: set of user defined parameters relating to the
structure.
:type parameters_structure: dict
:type parameters_structure: Dict[str, Any]
:param structure: structure that will be studied
:type structure: orm.StructureData
:param structure_filename: name of the file where the structure will be
written so that LAMMPS can read it
:type structure_filename: str
:return: block with the structural information and list of groups present
:rtype: Union[str, list]
:rtype: Tuple[str, List[str]]
"""

group_names = []
group_names: List[str] = []

kind_name_id_map = {}
kind_name_id_map: Dict[str, int] = {}
for site in structure.sites:
if site.kind_name not in kind_name_id_map:
kind_name_id_map[site.kind_name] = len(kind_name_id_map) + 1
Expand Down Expand Up @@ -455,7 +456,7 @@ def write_final_variables_block(
return variables_block


def generate_velocity_string(parameters_velocity: Dict[str, Any]) -> str:
def generate_velocity_string(parameters_velocity: List[Dict[str, Any]]) -> str:
"""
Generate the velocity string for the MD block.
Expand All @@ -474,7 +475,7 @@ def generate_velocity_string(parameters_velocity: Dict[str, Any]) -> str:
options += f'velocity {entry.get("group", "all")} create'
options += f' {entry["create"].get("temp")}'
options += (
f' {entry["create"].get("seed", np.random.randint(1e4))} {_options}\n'
f' {entry["create"].get("seed", np.random.randint(10000))} {_options}\n'
)
if "set" in entry:
options += f'velocity {entry.get("group", "all")} set'
Expand Down Expand Up @@ -611,22 +612,25 @@ def generate_integration_options(
for _option in temperature_options:
if _option in integration_parameters:
_value = integration_parameters.get(_option)
_value = [str(val) for val in _value]
options += f' {_option} {" ".join(_value) if isinstance(_value, list) else _value} '
if _value:
_value = [str(val) for val in _value]
options += f' {_option} {" ".join(_value) if isinstance(_value, list) else _value} '
# Set the options that depend on the pressure
if style in pressure_dependent:
for _option in pressure_options:
if _option in integration_parameters:
_value = integration_parameters.get(_option)
_value = [str(val) for val in _value]
options += f' {_option} {" ".join(_value) if isinstance(_value, list) else _value} '
if _value:
_value = [str(val) for val in _value]
options += f' {_option} {" ".join(_value) if isinstance(_value, list) else _value} '
# Set the options that depend on the 'uef' parameters
if style in uef_dependent:
for _option in uef_options:
if _option in integration_parameters:
_value = integration_parameters.get(_option)
_value = [str(val) for val in _value]
options += f' {_option} {" ".join(_value) if isinstance(_value, list) else _value} '
if _value:
_value = [str(val) for val in _value]
options += f' {_option} {" ".join(_value) if isinstance(_value, list) else _value} '
# Set the options that depend on the 'nve/limit' parameters
if style in ["nve/limit"]:
options += f' {integration_parameters.get("xmax", 0.1)} '
Expand Down Expand Up @@ -722,8 +726,8 @@ def write_dump_block(
parameters_dump: Dict[str, Any],
trajectory_filename: str,
atom_style: str,
kind_symbols: List[str],
parameters_compute: Optional[Dict[str, Any]] = None,
kind_symbols: Optional[List[str]] = None,
) -> str:
"""Generate the block with dumps commands.
Expand Down Expand Up @@ -752,19 +756,20 @@ def write_dump_block(

computes_list = []

for key, value in parameters_compute.items():
for entry in value:
_locality = _compute_variables[key]["locality"]
_printable = _compute_variables[key]["printable"]

if _locality == "local" and _printable:
computes_list.append(
generate_printing_string(
name=key,
group=entry["group"],
calculation_type="compute",
if parameters_compute:
for key, value in parameters_compute.items():
for entry in value:
_locality = _compute_variables[key]["locality"]
_printable = _compute_variables[key]["printable"]

if _locality == "local" and _printable:
computes_list.append(
generate_printing_string(
name=key,
group=entry["group"],
calculation_type="compute",
)
)
)

num_double = len(list(flatten([compute.split() for compute in computes_list])))
num_double += 3
Expand All @@ -787,7 +792,7 @@ def write_dump_block(
def write_thermo_block(
parameters_thermo: Dict[str, Any],
parameters_compute: Optional[Dict[str, Any]] = None,
) -> Union[str, List[str]]:
) -> Tuple[str, List[str]]:
"""Generate the block with the thermo command.
This will take all the global computes which were generated during the calculation
Expand All @@ -814,19 +819,20 @@ def write_thermo_block(

computes_list = []

for key, value in parameters_compute.items():
for entry in value:
_locality = _compute_variables[key]["locality"]
_printable = _compute_variables[key]["printable"]

if _locality == "global" and _printable:
computes_list.append(
generate_printing_string(
name=key,
group=entry["group"],
calculation_type="compute",
if parameters_compute:
for key, value in parameters_compute.items():
for entry in value:
_locality = _compute_variables[key]["locality"]
_printable = _compute_variables[key]["printable"]

if _locality == "global" and _printable:
computes_list.append(
generate_printing_string(
name=key,
group=entry["group"],
calculation_type="compute",
)
)
)

computes_printing = parameters_thermo.get("thermo_printing", None)

Expand Down Expand Up @@ -978,7 +984,7 @@ def generate_printing_string(
return " ".join(_string)


def generate_id_tag(name: Optional[str] = None, group: Optional[str] = None) -> str:
def generate_id_tag(name: str, group: str) -> str:
"""Generate an id tag for fixes and/or computes.
To standardize the naming of computes and/or fixes and to ensure that one
Expand Down Expand Up @@ -1017,9 +1023,11 @@ def join_keywords(value: List[Any]) -> str:

return " ".join(
[
f"{entry['keyword']} {entry['value']}"
if isinstance(entry, dict)
else f"{entry}"
(
f"{entry['keyword']} {entry['value']}"
if isinstance(entry, dict)
else f"{entry}"
)
for entry in value
]
)

0 comments on commit 7fd875c

Please sign in to comment.