Mypy fixes for engine module #6641

Draft: wants to merge 3 commits into base: main
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
@@ -116,7 +116,6 @@ repos:
src/aiida/cmdline/utils/echo.py|
src/aiida/common/extendeddicts.py|
src/aiida/common/utils.py|
src/aiida/engine/daemon/execmanager.py|
src/aiida/engine/processes/calcjobs/manager.py|
src/aiida/engine/processes/calcjobs/monitors.py|
src/aiida/engine/processes/calcjobs/tasks.py|
8 changes: 5 additions & 3 deletions src/aiida/common/datastructures.py
@@ -141,11 +141,13 @@ class CalcInfo(DefaultFieldsAttributeDict):
)

if TYPE_CHECKING:
from aiida.orm.nodes.process.calculation.calcjob import RetrievedList

job_environment: None | dict[str, str]
email: None | str
email_on_started: bool
email_on_terminated: bool
uuid: None | str
uuid: str
prepend_text: None | str
append_text: None | str
num_machines: None | int
@@ -154,8 +156,8 @@ class CalcInfo(DefaultFieldsAttributeDict):
max_wallclock_seconds: None | int
max_memory_kb: None | int
rerunnable: bool
retrieve_list: None | list[str | tuple[str, str, str]]
retrieve_temporary_list: None | list[str | tuple[str, str, str]]
retrieve_list: RetrievedList
retrieve_temporary_list: RetrievedList
local_copy_list: None | list[tuple[str, str, str]]
remote_copy_list: None | list[tuple[str, str, str]]
remote_symlink_list: None | list[tuple[str, str, str]]
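Both retrieval fields now reuse the `RetrievedList` alias defined in `calcjob.py` (see that file below). As a minimal sketch, values accepted by the new annotation look like this; the directive contents are hypothetical, not taken from this PR:

```python
from aiida.common.datastructures import CalcInfo

# Hypothetical directives satisfying RetrievedList, i.e.
# MutableSequence[str | tuple[str, str, int | None]] | None.
calc_info = CalcInfo()
calc_info.retrieve_list = [
    'aiida.out',                        # bare filename, relative to the workdir
    ('path/*.dat', 'data', 1),          # glob: keep the last path component locally
    ('nested/dir/output.xml', '.', 2),  # keep the last two path components
]
```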
74 changes: 47 additions & 27 deletions src/aiida/engine/daemon/execmanager.py
@@ -20,7 +20,7 @@
from logging import LoggerAdapter
from pathlib import Path
from tempfile import NamedTemporaryFile, TemporaryDirectory
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Optional, cast
from typing import Mapping as MappingType

from aiida.common import AIIDA_LOGGER, exceptions
@@ -29,12 +29,13 @@
from aiida.common.links import LinkType
from aiida.engine.processes.exit_code import ExitCode
from aiida.manage.configuration import get_config_option
from aiida.orm import CalcJobNode, Code, FolderData, Node, PortableCode, RemoteData, load_node
from aiida.orm import CalcJobNode, Code, Computer, FolderData, Node, PortableCode, RemoteData, load_node
from aiida.orm.utils.log import get_dblogger_extra
from aiida.repository.common import FileType
from aiida.schedulers.datastructures import JobState

if TYPE_CHECKING:
from aiida.orm.nodes.process.calculation.calcjob import RetrievedList
from aiida.transports import Transport

REMOTE_WORK_DIRECTORY_LOST_FOUND = 'lost+found'
@@ -84,12 +85,13 @@ def upload_calculation(
link_label = 'remote_folder'
if node.base.links.get_outgoing(RemoteData, link_label_filter=link_label).first():
EXEC_LOGGER.warning(f'CalcJobNode<{node.pk}> already has a `{link_label}` output: skipping upload')
return calc_info
return None

computer = node.computer
# cast is safe since a CalcJobNode is guaranteed to have a computer attached
computer = cast(Computer, node.computer)

codes_info = calc_info.codes_info
input_codes = [load_node(_.code_uuid, sub_classes=(Code,)) for _ in codes_info]
input_codes = [load_node(_.code_uuid, sub_classes=(Code,)) for _ in codes_info] if codes_info else []

logger_extra = get_dblogger_extra(node)
transport.set_logger_extra(logger_extra)
@@ -182,7 +184,7 @@ def upload_calculation(
# Since the content of the node could potentially be binary, we read the raw bytes and pass them on
for filename in filenames:
with NamedTemporaryFile(mode='wb+') as handle:
content = code.base.repository.get_object_content(Path(root) / filename, mode='rb')
content = code.base.repository.get_object_content(root / filename, mode='rb')
handle.write(content)
handle.flush()
transport.put(handle.name, str(workdir.joinpath(root, filename)))
@@ -222,7 +224,7 @@ def upload_calculation(
if dry_run:
if remote_copy_list:
filepath = os.path.join(str(workdir), '_aiida_remote_copy_list.txt')
with open(filepath, 'w', encoding='utf-8') as handle: # type: ignore[assignment]
with open(filepath, 'w', encoding='utf-8') as handle:
for _, remote_abs_path, dest_rel_path in remote_copy_list:
handle.write(
f'would have copied {remote_abs_path} to {dest_rel_path} in working '
@@ -231,7 +233,7 @@

if remote_symlink_list:
filepath = os.path.join(str(workdir), '_aiida_remote_symlink_list.txt')
with open(filepath, 'w', encoding='utf-8') as handle: # type: ignore[assignment]
with open(filepath, 'w', encoding='utf-8') as handle:
for _, remote_abs_path, dest_rel_path in remote_symlink_list:
handle.write(
f'would have created symlinks from {remote_abs_path} to {dest_rel_path} in working'
@@ -265,7 +267,7 @@ def upload_calculation(
if relpath not in provenance_exclude_list and all(
dirname not in provenance_exclude_list for dirname in dirnames
):
with open(filepath, 'rb') as handle: # type: ignore[assignment]
with open(filepath, 'rb') as handle:
node.base.repository._repository.put_object_from_filelike(handle, relpath)

# Since the node is already stored, we cannot use the normal repository interface since it will raise a
@@ -333,14 +335,15 @@ def _copy_local_files(logger, node, transport, inputs, local_copy_list, workdir:
for uuid, filename, target in local_copy_list:
logger.debug(f'[submission of calculation {node.uuid}] copying local file/folder to {target}')

data_node = None
try:
data_node = load_node(uuid=uuid)
except exceptions.NotExistent:
data_node = _find_data_node(inputs, uuid) if inputs else None

if data_node is None:
logger.warning(f'failed to load Node<{uuid}> specified in the `local_copy_list`')
continue
finally:
if data_node is None:
logger.warning(f'failed to load Node<{uuid}> specified in the `local_copy_list`')
continue

# The transport class can only copy files directly from the file system, so the files in the source node's repo
# have to first be copied to a temporary directory on disk.
@@ -410,20 +413,27 @@ def submit_calculation(calculation: CalcJobNode, transport: Transport) -> str |
if job_id is not None:
return job_id

scheduler = calculation.computer.get_scheduler()
computer = cast(Computer, calculation.computer)
scheduler = computer.get_scheduler()
scheduler.set_transport(transport)

submit_script_filename = calculation.get_option('submit_script_filename')
# corresponds to metadata.options.submit_script_filename in the CalcJob inputs
submit_script_filename: str = cast(str, calculation.get_option('submit_script_filename'))
workdir = calculation.get_remote_workdir()
result = scheduler.submit_job(workdir, submit_script_filename)
if workdir is not None:
result = scheduler.submit_job(workdir, submit_script_filename)
else:
# FIXME: a dedicated exit code is needed for the case where the remote_workdir of the calculation is not set.
# Return an ExitCode, since this is something the user can fix.
return ExitCode(-1)

if isinstance(result, str):
calculation.set_job_id(result)

return result


def stash_calculation(calculation: CalcJobNode, transport: Transport) -> None:
def stash_calculation(calculation: CalcJobNode, transport: Transport) -> None | ExitCode:
"""Stash files from the working directory of a completed calculation to a permanent remote folder.

After a calculation has been completed, optionally stash files from the work directory to a storage location on the
@@ -439,23 +449,29 @@ def stash_calculation(calculation: CalcJobNode, transport: Transport) -> None:

logger_extra = get_dblogger_extra(calculation)

stash_options = calculation.get_option('stash')
stash_options = cast(dict[str, Any], calculation.get_option('stash'))
stash_mode = stash_options.get('mode', StashMode.COPY.value)
source_list = stash_options.get('source_list', [])

if not source_list:
return
return None

if stash_mode != StashMode.COPY.value:
EXEC_LOGGER.warning(f'stashing mode {stash_mode} is not implemented yet.')
return
return None

cls = RemoteStashFolderData

EXEC_LOGGER.debug(f'stashing files for calculation<{calculation.pk}>: {source_list}', extra=logger_extra)

uuid = calculation.uuid
source_basepath = Path(calculation.get_remote_workdir())
workdir = calculation.get_remote_workdir()
if workdir is not None:
source_basepath = Path(workdir)
else:
# FIXME: a dedicated exit code is needed for the case where the remote_workdir of the calculation is not set.
# Return an ExitCode, since this is something the user can fix.
return ExitCode(-1)
target_basepath = Path(stash_options['target_base']) / uuid[:2] / uuid[2:4] / uuid[4:]

for source_filename in source_list:
@@ -487,6 +503,8 @@
).store()
remote_stash.base.links.add_incoming(calculation, link_type=LinkType.CREATE, link_label='remote_stash')

return None


def retrieve_calculation(
calculation: CalcJobNode, transport: Transport, retrieved_temporary_folder: str
@@ -518,7 +536,7 @@
EXEC_LOGGER.warning(
f'CalcJobNode<{calculation.pk}> already has a `{link_label}` output folder: skipping retrieval'
)
return
return None

# Create the FolderData node into which to store the files that are to be retrieved
retrieved_files = FolderData()
@@ -567,7 +585,8 @@ def kill_calculation(calculation: CalcJobNode, transport: Transport) -> None:
return

# Get the scheduler plugin class and initialize it with the correct transport
scheduler = calculation.computer.get_scheduler()
computer = cast(Computer, calculation.computer)
scheduler = computer.get_scheduler()
scheduler.set_transport(transport)

# Call the proper kill method for the job ID of this calculation
@@ -576,7 +595,7 @@ def kill_calculation(calculation: CalcJobNode, transport: Transport) -> None:
if result is not True:
# Failed to kill because the job might have already been completed
running_jobs = scheduler.get_jobs(jobs=[job_id], as_dict=True)
job = running_jobs.get(job_id, None)
job = running_jobs.get(job_id, None) # type: ignore[union-attr]

# If the job is returned it is still running and the kill really failed, so we raise
if job is not None and job.job_state != JobState.DONE:
Expand All @@ -591,7 +610,7 @@ def retrieve_files_from_list(
calculation: CalcJobNode,
transport: Transport,
folder: str,
retrieve_list: List[Union[str, Tuple[str, str, int], list]],
retrieve_list: RetrievedList,
) -> None:
"""Retrieve all the files in the retrieve_list from the remote into the
local folder instance through the transport. The entries in the retrieve_list
@@ -621,7 +640,7 @@
tmp_rname, tmp_lname, depth = item
# if there are more than one file I do something differently
if transport.has_magic(tmp_rname):
remote_names = transport.glob(str(workdir.joinpath(tmp_rname)))
remote_names = transport.glob(str(workdir / tmp_rname))
local_names = []
for rem in remote_names:
# get the relative path so to make local_names relative
@@ -633,6 +652,7 @@
local_names.append(os.path.sep.join([tmp_lname] + to_append))
else:
remote_names = [tmp_rname]
# FIXME: will raise a TypeError if depth is None
to_append = tmp_rname.split(os.path.sep)[-depth:] if depth > 0 else []
local_names = [os.path.sep.join([tmp_lname] + to_append)]
if depth is None or depth > 1: # create directories in the folder, if needed
@@ -641,7 +661,7 @@
if not os.path.exists(new_folder):
os.makedirs(new_folder)
else:
abs_item = item if item.startswith('/') else str(workdir.joinpath(item))
abs_item = item if item.startswith('/') else str(workdir / item)

if transport.has_magic(abs_item):
remote_names = transport.glob(abs_item)
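The FIXME above flags the non-glob branch of `retrieve_files_from_list`, where `depth > 0` is evaluated before `depth is None` has been ruled out. The following standalone sketch shows the depth handling with an explicit guard; treating `None` as "keep the full remote path" is an assumption here, not something this PR decides:

```python
import os


def _local_name(remote_name: str, local_name: str, depth: int | None) -> str:
    """Sketch of the retrieval depth logic: keep the last ``depth``
    components of the remote path underneath ``local_name``."""
    if depth is None:
        # Guard for the case the FIXME points at: without it,
        # ``depth > 0`` raises a TypeError when depth is None.
        to_append = remote_name.split(os.path.sep)
    elif depth > 0:
        to_append = remote_name.split(os.path.sep)[-depth:]
    else:
        to_append = []
    return os.path.sep.join([local_name] + to_append)


# _local_name('some/path/out.dat', '.', 1) -> './out.dat' (on POSIX)
```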
16 changes: 10 additions & 6 deletions src/aiida/orm/nodes/process/calculation/calcjob.py
@@ -8,8 +8,10 @@
###########################################################################
"""Module with `Node` sub class for calculation job processes."""

from __future__ import annotations

import datetime
from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Sequence, Tuple, Type, Union
from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, MutableSequence, Optional, Tuple, Type, Union

from aiida.common import exceptions
from aiida.common.datastructures import CalcJobState
@@ -30,6 +32,8 @@

__all__ = ('CalcJobNode',)

RetrievedList = MutableSequence[str | tuple[str, str, int | None]] | None


class CalcJobNodeCaching(ProcessNodeCaching):
"""Interface to control caching of a node instance."""
@@ -274,7 +278,7 @@ def get_remote_workdir(self) -> Optional[str]:
return self.base.attributes.get(self.REMOTE_WORKDIR_KEY, None)

@staticmethod
def _validate_retrieval_directive(directives: Sequence[Union[str, Tuple[str, str, str]]]) -> None:
def _validate_retrieval_directive(directives: RetrievedList) -> None:
"""Validate a list or tuple of file retrieval directives.

:param directives: a list or tuple of file retrieval directives
@@ -301,7 +305,7 @@ def _validate_retrieval_directive(directives: Sequence[Union[str, Tuple[str, str
if not isinstance(directive[2], (int, type(None))):
raise ValueError('invalid directive, third element has to be an integer representing the depth')

def set_retrieve_list(self, retrieve_list: Sequence[Union[str, Tuple[str, str, str]]]) -> None:
def set_retrieve_list(self, retrieve_list: RetrievedList) -> None:
"""Set the retrieve list.

This list of directives will instruct the daemon what files to retrieve after the calculation has completed.
@@ -312,14 +316,14 @@ def set_retrieve_list(self, retrieve_list: Sequence[Union[str, Tuple[str, str, s
self._validate_retrieval_directive(retrieve_list)
self.base.attributes.set(self.RETRIEVE_LIST_KEY, retrieve_list)

def get_retrieve_list(self) -> Optional[Sequence[Union[str, Tuple[str, str, str]]]]:
def get_retrieve_list(self) -> RetrievedList:
"""Return the list of files/directories to be retrieved on the cluster after the calculation has completed.

:return: a list of file directives
"""
return self.base.attributes.get(self.RETRIEVE_LIST_KEY, None)

def set_retrieve_temporary_list(self, retrieve_temporary_list: Sequence[Union[str, Tuple[str, str, str]]]) -> None:
def set_retrieve_temporary_list(self, retrieve_temporary_list: RetrievedList) -> None:
"""Set the retrieve temporary list.

The retrieve temporary list stores files that are retrieved after completion and made available during parsing
@@ -330,7 +334,7 @@ def set_retrieve_temporary_list(self, retrieve_temporary_list: Sequence[Union[st
self._validate_retrieval_directive(retrieve_temporary_list)
self.base.attributes.set(self.RETRIEVE_TEMPORARY_LIST_KEY, retrieve_temporary_list)

def get_retrieve_temporary_list(self) -> Optional[Sequence[Union[str, Tuple[str, str, str]]]]:
def get_retrieve_temporary_list(self) -> RetrievedList:
"""Return list of files to be retrieved from the cluster which will be available during parsing.

:return: a list of file directives
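A short usage sketch of the retyped setters and getters; the directive values are hypothetical, and round-tripping through an unstored node is assumed to preserve them:

```python
from aiida.orm import CalcJobNode

node = CalcJobNode()
# Both plain names and (remote, local, depth) tuples satisfy RetrievedList.
node.set_retrieve_list(['aiida.out', ('logs/*.log', 'logs', 1)])

directives = node.get_retrieve_list()  # typed as RetrievedList
assert directives is not None
assert node.get_retrieve_temporary_list() is None  # never set
```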
6 changes: 3 additions & 3 deletions src/aiida/orm/nodes/repository.py
@@ -237,12 +237,12 @@ def get_object(self, path: FilePath | None = None) -> File:
return self._repository.get_object(path)

@t.overload
def get_object_content(self, path: str, mode: t.Literal['r']) -> str: ...
def get_object_content(self, path: FilePath, mode: t.Literal['r']) -> str: ...

@t.overload
def get_object_content(self, path: str, mode: t.Literal['rb']) -> bytes: ...
def get_object_content(self, path: FilePath, mode: t.Literal['rb']) -> bytes: ...

def get_object_content(self, path: str, mode: t.Literal['r', 'rb'] = 'r') -> str | bytes:
def get_object_content(self, path: FilePath, mode: t.Literal['r', 'rb'] = 'r') -> str | bytes:
"""Return the content of a object identified by key.

:param path: the relative path of the object within the repository.
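With `path` widened from `str` to `FilePath`, both `str` and `pathlib.Path` arguments now type-check, while the `mode` literal still selects the return type through the overloads. A minimal sketch, with a hypothetical file name and content:

```python
import io
from pathlib import Path

from aiida.orm import FolderData

folder = FolderData()
folder.base.repository.put_object_from_filelike(io.BytesIO(b'hello'), 'greeting.txt')

text = folder.base.repository.get_object_content(Path('greeting.txt'), mode='r')  # str
raw = folder.base.repository.get_object_content('greeting.txt', mode='rb')  # bytes
assert isinstance(text, str) and isinstance(raw, bytes)
```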