@scheduler.command()
def worker():
    """Start a single scheduler worker for the current profile.

    This is the command that ``SchedulerClient.cmd_start_daemon_worker`` hands
    to circus for every worker process; use ``scheduler start`` to launch the
    scheduler daemon itself.
    """
    from aiida_workgraph.engine.scheduler.client import start_scheduler_worker

    click.echo("Starting the scheduler worker...")

    # Hands control to the worker start-up routine of the scheduler client module.
    start_scheduler_worker()
@scheduler.command()
@click.option("--no-wait", is_flag=True, help="Do not wait for confirmation.")
@click.option("--all", "all_profiles", is_flag=True, help="Stop all daemons.")
@options.TIMEOUT(default=None, required=False, type=int)
@decorators.requires_broker
@click.pass_context
def stop(ctx, no_wait, all_profiles, timeout):
    """Stop the scheduler daemon.

    Returns exit code 0 if the daemon was shut down successfully (or was not running), non-zero if there was an error.
    """
    if all_profiles is True:
        profiles = [
            profile
            for profile in ctx.obj.config.profiles
            if not profile.is_test_profile
        ]
    else:
        profiles = [ctx.obj.profile]

    # Remember failures so that every profile is still processed before exiting.
    failed = False

    for profile in profiles:
        echo.echo("Profile: ", fg=echo.COLORS["report"], bold=True, nl=False)
        echo.echo(f"{profile.name}", bold=True)

        # Bind the client to the profile of this iteration; without the name
        # the client of the currently loaded profile is returned every time.
        client = get_scheduler_client(profile.name)

        # A daemon that is not running counts as success (see the docstring).
        if not client.is_daemon_running:
            echo.echo("Daemon was not running")
            continue

        echo.echo("Stopping the daemon... ", nl=False)
        try:
            client.stop_daemon(wait=not no_wait, timeout=timeout)
        except Exception as exception:
            failed = True
            # Terminate the progress line before printing the error.
            echo.echo("FAILED", fg=echo.COLORS["error"], bold=True)
            echo.echo_error(f"Failed to stop the daemon: {exception}")
        else:
            echo.echo("OK", fg=echo.COLORS["success"], bold=True)

    if failed:
        sys.exit(1)
+ """ + + get_scheduler_client()._start_daemon(number_workers=number, foreground=foreground) + + +@scheduler.command() +@click.option("--all", "all_profiles", is_flag=True, help="Show status of all daemons.") +@options.TIMEOUT(default=None, required=False, type=int) +@click.pass_context +@decorators.requires_loaded_profile() +@decorators.requires_broker +def status(ctx, all_profiles, timeout): + """Print the status of the scheduler daemon. + + Returns exit code 0 if all requested daemons are running, else exit code 3. + """ + from tabulate import tabulate + + from aiida.cmdline.utils.common import format_local_time + from aiida.engine.daemon.client import DaemonException + + if all_profiles is True: + profiles = [ + profile + for profile in ctx.obj.config.profiles + if not profile.is_test_profile + ] + else: + profiles = [ctx.obj.profile] + + daemons_running = [] + + for profile in profiles: + client = get_scheduler_client(profile.name) + echo.echo("Profile: ", fg=echo.COLORS["report"], bold=True, nl=False) + echo.echo(f"{profile.name}", bold=True) + + try: + client.get_status(timeout=timeout) + except DaemonException as exception: + echo.echo_error(str(exception)) + daemons_running.append(False) + continue + + worker_response = client.get_worker_info() + daemon_response = client.get_daemon_info() + + workers = [] + for pid, info in worker_response["info"].items(): + if isinstance(info, dict): + row = [ + pid, + info["mem"], + info["cpu"], + format_local_time(info["create_time"]), + ] + else: + row = [pid, "-", "-", "-"] + workers.append(row) + + if workers: + workers_info = tabulate( + workers, headers=["PID", "MEM %", "CPU %", "started"], tablefmt="simple" + ) + else: + workers_info = ( + "--> No workers are running. 
def run_get_node(
    process_class, *args, **kwargs
) -> tuple[dict[str, t.Any] | None, "ProcessNode"]:
    """Run the FunctionProcess with the supplied inputs in a local runner.

    :param process_class: the process class that is instantiated and executed
    :param args: input arguments to construct the FunctionProcess
    :param kwargs: input keyword arguments to construct the FunctionProcess. The special key ``parent_pid`` is
        removed from the inputs and forwarded to the process constructor instead.
    :return: tuple of the outputs of the process and the process node
    :raises ValueError: if keyword arguments remain that are not input ports and the input namespace is not dynamic
    """
    parent_pid = kwargs.pop("parent_pid", None)
    frame_delta = 1000
    frame_count = get_stack_size()
    stack_limit = sys.getrecursionlimit()
    LOGGER.info(
        "Executing process function, current stack status: %d frames of %d",
        frame_count,
        stack_limit,
    )
    # If the current frame count is more than 80% of the stack limit, or comes within 200 frames, increase the
    # stack limit by ``frame_delta``.
    if frame_count > min(0.8 * stack_limit, stack_limit - 200):
        LOGGER.warning(
            "Current stack contains %d frames which is close to the limit of %d. Increasing the limit by %d",
            frame_count,
            stack_limit,
            frame_delta,
        )
        sys.setrecursionlimit(stack_limit + frame_delta)
    manager = get_manager()
    runner = manager.get_runner()
    inputs = process_class.create_inputs(*args, **kwargs)
    # Remove all the known inputs from the kwargs
    for port in process_class.spec().inputs:
        kwargs.pop(port, None)
    # If any kwargs remain, the spec should be dynamic, so we raise if it isn't
    if kwargs and not process_class.spec().inputs.dynamic:
        # Fall back to the type name in case an instance rather than a class was passed.
        process_name = getattr(process_class, "__name__", type(process_class).__name__)
        raise ValueError(
            f"{process_name} does not support these kwargs: {kwargs.keys()}"
        )
    process = process_class(inputs=inputs, runner=runner, parent_pid=parent_pid)
    # Only add handlers for interrupt signal to kill the process if we are in a local and not a daemon runner.
    # Without this check, running process functions in a daemon worker would be killed if the daemon is shutdown
    current_runner = manager.get_runner()
    original_handler = None
    kill_signal = signal.SIGINT
    if not current_runner.is_daemon_runner:

        def kill_process(_num, _frame):
            """Send the kill signal to the process in the current scope."""
            LOGGER.critical(
                "runner received interrupt, killing process %s", process.pid
            )
            result = process.kill(
                msg="Process was killed because the runner received an interrupt"
            )
            return result

        # Store the current handler on the signal such that it can be restored after process has terminated
        original_handler = signal.getsignal(kill_signal)
        signal.signal(kill_signal, kill_process)
    try:
        result = process.execute()
    finally:
        # ``original_handler`` is only set when ``kill_process`` was bound. Compare against ``None`` instead of
        # testing truthiness: ``signal.SIG_DFL`` is falsy and would otherwise never be restored.
        if original_handler is not None:
            signal.signal(kill_signal, original_handler)
    store_provenance = inputs.get("metadata", {}).get("store_provenance", True)
    if not store_provenance:
        process.node._storable = False
        process.node._unstorable_message = (
            "cannot store node because it was run with `store_provenance=False`"
        )
    return result, process.node
+ + :param process: the process class, instance or builder to submit + :param inputs: the input dictionary to be passed to the process + :param wait: when set to ``True``, the submission will be blocking and wait for the process to complete at which + point the function returns the calculation node. + :param wait_interval: the number of seconds to wait between checking the state of the process when ``wait=True``. + :param kwargs: inputs to be passed to the process. This is an alternative to the positional ``inputs`` argument. + :return: the calculation node of the process + """ + inputs = prepare_inputs(inputs, **kwargs) + + # Submitting from within another process requires ``self.submit``` unless it is a work function, in which case the + # current process in the scope should be an instance of ``FunctionProcess``. + # if is_process_scoped() and not isinstance(Process.current(), FunctionProcess): + # raise InvalidOperation('Cannot use top-level `submit` from within another process, use `self.submit` instead') + + if not runner: + runner = manager.get_manager().get_runner() + assert runner.persister is not None, "runner does not have a persister" + assert runner.controller is not None, "runner does not have a controller" + + process_inited = instantiate_process( + runner, process, parent_pid=parent_pid, **inputs + ) + + # If a dry run is requested, simply forward to `run`, because it is not compatible with `submit`. We choose for this + # instead of raising, because in this way the user does not have to change the launcher when testing. The same goes + # for if `remote_folder` is present in the inputs, which means we are importing an already completed calculation. 
def create_daemon_runner(
    manager,
    queue_name: Optional[str] = None,
    loop: Optional["asyncio.AbstractEventLoop"] = None,
) -> "Runner":
    """Create and return a new daemon runner.

    This is used by workers when the daemon is running and in testing.

    :param manager: the manager used to create the runner and to obtain the persister
    :param queue_name: optional name of a dedicated task queue. When given, the runner subscribes to that queue
        (with a prefetch count of 1) instead of the communicator's default task queue.
    :param loop: the (optional) asyncio event loop to use
    :return: a runner configured to work in the daemon configuration
    """
    from plumpy.persistence import LoadSaveContext
    from aiida.engine import persistence
    from aiida.engine.processes.launcher import ProcessLauncher
    from plumpy.communications import convert_to_comm

    runner = manager.create_runner(broker_submit=True, loop=loop)
    runner_loop = runner.loop
    # Listen for incoming launch requests
    task_receiver = ProcessLauncher(
        loop=runner_loop,
        persister=manager.get_persister(),
        load_context=LoadSaveContext(runner=runner),
        loader=persistence.get_object_loader(),
    )

    assert runner.communicator is not None, "communicator not set for runner"
    if queue_name is not None:
        print("queue_name: {}".format(queue_name))
        queue = runner.communicator._communicator.task_queue(
            queue_name, prefetch_count=1
        )
        # The launcher must be converted before it is subscribed to the raw queue.
        # NOTE(review): presumably because the raw queue calls back outside the runner loop - confirm.
        converted = convert_to_comm(task_receiver, runner.communicator._loop)
        queue.add_task_subscriber(converted)
    else:
        runner.communicator.add_task_subscriber(task_receiver)
    return runner
class SchedulerClient(DaemonClient):
    """Client for interacting with the scheduler daemon.

    Specialises :class:`aiida.engine.daemon.client.DaemonClient` so that the scheduler uses its own set of circus
    and log files (all prefixed with ``scheduler``) and launches its workers through the ``workgraph`` command line
    instead of ``verdi``.
    """

    # NOTE(review): presumably formatted with the profile name by the base class to build ``daemon_name`` - confirm.
    _DAEMON_NAME = "scheduler-{name}"

    @property
    def _workgraph_bin(self) -> str:
        """Return the absolute path to the ``workgraph`` binary.

        :raises ConfigurationError: If the path to ``workgraph`` could not be found
        """
        if WORKGRAPH_BIN is None:
            raise ConfigurationError(
                "Unable to find 'workgraph' in the path. Make sure that you are working "
                "in a virtual environment, or that at least the 'workgraph' executable is on the PATH"
            )

        return WORKGRAPH_BIN

    @property
    def filepaths(self):
        """Return the filepaths used by this profile.

        :return: a dictionary of filepaths
        """
        from aiida.manage.configuration.settings import DAEMON_DIR, DAEMON_LOG_DIR

        return {
            "circus": {
                "log": str(
                    DAEMON_LOG_DIR / f"circus-scheduler-{self.profile.name}.log"
                ),
                "pid": str(DAEMON_DIR / f"circus-scheduler-{self.profile.name}.pid"),
                "port": str(DAEMON_DIR / f"circus-scheduler-{self.profile.name}.port"),
                "socket": {
                    "file": str(
                        DAEMON_DIR / f"circus-scheduler-{self.profile.name}.sockets"
                    ),
                    "controller": "circus.c.sock",
                    "pubsub": "circus.p.sock",
                    "stats": "circus.s.sock",
                },
            },
            "daemon": {
                "log": str(DAEMON_LOG_DIR / f"aiida-scheduler-{self.profile.name}.log"),
                "pid": str(DAEMON_DIR / f"aiida-scheduler-{self.profile.name}.pid"),
            },
        }

    @property
    def circus_log_file(self) -> str:
        """Return the path of the circus log file."""
        return self.filepaths["circus"]["log"]

    @property
    def circus_pid_file(self) -> str:
        """Return the path of the circus pid file."""
        return self.filepaths["circus"]["pid"]

    @property
    def circus_port_file(self) -> str:
        """Return the path of the circus port file."""
        return self.filepaths["circus"]["port"]

    @property
    def circus_socket_file(self) -> str:
        """Return the path of the circus sockets file."""
        return self.filepaths["circus"]["socket"]["file"]

    @property
    def circus_socket_endpoints(self) -> dict[str, str]:
        """Return the ``socket`` section of :attr:`filepaths`."""
        return self.filepaths["circus"]["socket"]

    @property
    def daemon_log_file(self) -> str:
        """Return the path of the log file of the scheduler workers."""
        return self.filepaths["daemon"]["log"]

    @property
    def daemon_pid_file(self) -> str:
        """Return the path of the pid file of the scheduler workers."""
        return self.filepaths["daemon"]["pid"]

    def cmd_start_daemon(
        self, number_workers: int = 1, foreground: bool = False
    ) -> list[str]:
        """Return the command to start the daemon.

        :param number_workers: Number of daemon workers to start.
        :param foreground: Whether to launch the subprocess in the background or not.
        """
        command = [
            self._workgraph_bin,
            "-p",
            self.profile.name,
            "scheduler",
            "start-circus",
            str(number_workers),
        ]

        if foreground:
            command.append("--foreground")

        return command

    @property
    def cmd_start_daemon_worker(self) -> list[str]:
        """Return the command to start a daemon worker process."""
        return [self._workgraph_bin, "-p", self.profile.name, "scheduler", "worker"]

    def _start_daemon(self, number_workers: int = 1, foreground: bool = False) -> None:
        """Start the daemon.

        .. warning:: This will daemonize the current process and put it in the background. It is most likely not what
            you want to call if you want to start the daemon from the Python API. Instead you probably will want to use
            the :meth:`aiida.engine.daemon.client.DaemonClient.start_daemon` function instead.

        :param number_workers: Number of daemon workers to start.
        :param foreground: Whether to launch the subprocess in the background or not.
        """
        from circus import get_arbiter
        from circus import logger as circus_logger
        from circus.circusd import daemonize
        from circus.pidfile import Pidfile
        from circus.util import check_future_exception_and_log, configure_logger

        loglevel = self.loglevel
        # "-" is used as log output unless the daemon is put in the background.
        logoutput = "-"

        if not foreground:
            logoutput = self.circus_log_file

        arbiter_config = {
            "controller": self.get_controller_endpoint(),
            "pubsub_endpoint": self.get_pubsub_endpoint(),
            "stats_endpoint": self.get_stats_endpoint(),
            "logoutput": logoutput,
            "loglevel": loglevel,
            "debug": False,
            "statsd": True,
            "pidfile": self.circus_pid_file,
            "watchers": [
                {
                    "cmd": " ".join(self.cmd_start_daemon_worker),
                    "name": self.daemon_name,
                    "numprocesses": number_workers,
                    "virtualenv": self.virtualenv,
                    "copy_env": True,
                    "stdout_stream": {
                        "class": "FileStream",
                        "filename": self.daemon_log_file,
                    },
                    "stderr_stream": {
                        "class": "FileStream",
                        "filename": self.daemon_log_file,
                    },
                    "env": self.get_env(),
                }
            ],
        }

        if not foreground:
            daemonize()

        arbiter = get_arbiter(**arbiter_config)
        pidfile = Pidfile(arbiter.pidfile)
        pidfile.create(os.getpid())

        # Configure the logger
        loggerconfig = arbiter.loggerconfig or None
        configure_logger(circus_logger, loglevel, logoutput, loggerconfig)

        # Main loop
        should_restart = True

        while should_restart:
            try:
                future = arbiter.start()
                should_restart = False
                if check_future_exception_and_log(future) is None:
                    should_restart = arbiter._restarting
            except Exception as exception:
                # Emergency stop
                arbiter.loop.run_sync(arbiter._emergency_stop)
                raise exception
            except KeyboardInterrupt:
                pass
            finally:
                arbiter = None
                if pidfile is not None:
                    pidfile.unlink()
def start_scheduler_process(number: int = 1) -> None:
    """Start or restart the specified number of scheduler processes.

    Unsealed ``WorkGraphScheduler`` processes that already exist are counted first; new scheduler processes are
    only created for the remainder.

    :param number: the total number of scheduler processes that should be running.
    :raises RuntimeError: if anything goes wrong while starting the schedulers; the original exception is chained
        as the cause.
    """
    from aiida_workgraph.engine.scheduler import WorkGraphScheduler
    from aiida_workgraph.engine.scheduler.client import get_scheduler
    from aiida_workgraph.utils.control import create_scheduler_action
    from aiida_workgraph.engine.utils import instantiate_process

    try:
        schedulers: List[int] = get_scheduler()
        existing_schedulers_count = len(schedulers)
        print(
            "Found {} existing scheduler(s): {}".format(
                existing_schedulers_count, " ".join([str(pk) for pk in schedulers])
            )
        )

        count = 0

        # The first ``number`` existing schedulers are treated as running again.
        for pk in schedulers[:number]:
            # When the runner stops it does not ack back to RabbitMQ, so the message is still in the queue
            # unacknowledged; there is no need to send a continue message again.
            print(f"Scheduler with pk {pk} restart and running.")
            count += 1
        # Existing schedulers beyond the requested number are only reported.
        for pk in schedulers[number:]:
            print(f"Scheduler with pk {pk} not running.")

        # Start new schedulers if more are needed
        runner = get_manager().get_runner()
        for _ in range(count, number):
            process_inited = instantiate_process(runner, WorkGraphScheduler)
            process_inited.runner.persister.save_checkpoint(process_inited)
            process_inited.close()
            create_scheduler_action(process_inited.node.pk)
            print(f"Scheduler with pk {process_inited.node.pk} running.")

    except Exception as e:
        # Raising a plain string is itself a ``TypeError``; raise a real exception and keep the cause.
        raise RuntimeError(f"An error occurred while starting schedulers: {e}") from e
# Public names of this module. ``__all__`` has to be a sequence of names that exist in the module:
# a bare string is iterated character by character by ``from module import *``.
__all__ = ("WorkGraphScheduler",)


# Message template with two positional ``{}`` placeholders.
MAX_NUMBER_AWAITABLES_MSG = "The maximum number of subprocesses has been reached: {}. Cannot launch the job: {}."
+ + :param inputs: work graph inputs + :param logger: aiida logger + :param runner: work graph runner + :param enable_persistence: whether to persist this work graph + + """ + + super().__init__( + inputs, logger, runner, enable_persistence=enable_persistence, **kwargs + ) + + self._awaitables: list[Awaitable] = [] + self._context = AttributeDict() + + @classmethod + def define(cls, spec: WorkChainSpec) -> None: + super().define(spec) + spec.input("input_file", valid_type=orm.SinglefileData, required=False) + spec.input_namespace( + "wg", dynamic=True, required=False, help="WorkGraph inputs" + ) + spec.input_namespace("input_tasks", dynamic=True, required=False) + spec.exit_code(2, "ERROR_SUBPROCESS", message="A subprocess has failed.") + + spec.outputs.dynamic = True + + spec.output_namespace("new_data", dynamic=True) + spec.output( + "execution_count", + valid_type=orm.Int, + required=False, + help="The number of time the WorkGraph runs.", + ) + # + spec.exit_code( + 201, "UNKNOWN_MESSAGE_TYPE", message="The message type is unknown." + ) + spec.exit_code(202, "UNKNOWN_TASK_TYPE", message="The task type is unknown.") + # + spec.exit_code( + 301, + "OUTPUS_NOT_MATCH_RESULTS", + message="The outputs of the process do not match the results.", + ) + spec.exit_code( + 302, + "TASK_FAILED", + message="Some of the tasks failed.", + ) + spec.exit_code( + 303, + "TASK_NON_ZERO_EXIT_STATUS", + message="Some of the tasks exited with non-zero status.", + ) + + @property + def ctx(self) -> AttributeDict: + """Get the context.""" + return self._context + + @override + def save_instance_state( + self, out_state: t.Dict[str, t.Any], save_context: t.Any + ) -> None: + """Save instance state. 
+ + :param out_state: state to save in + + :param save_context: + :type save_context: :class:`!plumpy.persistence.LoadSaveContext` + + """ + super().save_instance_state(out_state, save_context) + # Save the context + out_state[self._CONTEXT] = self.ctx + + @override + def load_instance_state( + self, saved_state: t.Dict[str, t.Any], load_context: t.Any + ) -> None: + super().load_instance_state(saved_state, load_context) + # Load the context + self._context = saved_state[self._CONTEXT] + self._temp = {"awaitables": {}} + + self.set_logger(self.node.logger) + self.add_workgraph_subsriber() + + if self._awaitables: + # For the "ascyncio.tasks.Task" awaitable, because there are only in-memory, + # we need to reset the tasks and so that they can be re-run again. + should_resume = False + for awaitable in self._awaitables: + if awaitable.target == "asyncio.tasks.Task": + self._resolve_awaitable(awaitable, None) + self.report(f"reset awaitable task: {awaitable.key}") + self.reset_task(awaitable.key) + should_resume = True + if should_resume: + self._update_process_status() + self.resume() + # For other awaitables, because they exist in the db, we only need to re-register the callbacks + self._action_awaitables() + # load checkpoint + launched_workgraphs = self.node.base.extras.get("_launched_workgraphs", []) + for pk in launched_workgraphs: + print("load workgraph: ", pk) + node = load_node(pk) + wgdata = node.base.extras.get("_checkpoint", None) + if wgdata is None: + self.launch_workgraph(pk) + else: + self.ctx._workgraph[pk] = deserialize_unsafe(wgdata) + print("continue workgraph: ", pk) + self.continue_workgraph(pk) + + def _resolve_nested_context(self, key: str) -> tuple[AttributeDict, str]: + """ + Returns a reference to a sub-dictionary of the context and the last key, + after resolving a potentially segmented key where required sub-dictionaries are created as needed. 
def _insert_awaitable(self, awaitable: Awaitable) -> None:
    """Insert an awaitable that should be terminated before continuing to the next step.

    :param awaitable: the thing to await
    """
    container, last_key = self._resolve_nested_context(awaitable.key)
    action = awaitable.action

    # The awaitable itself is stored as a placeholder at the location where its
    # resolved value will end up. For `APPEND` this keeps the insertion order,
    # since awaitables are not necessarily resolved in the order they were added;
    # `_resolve_awaitable` later finds the placeholder and replaces it.
    if action == AwaitableAction.APPEND:
        container.setdefault(last_key, []).append(awaitable)
    elif action == AwaitableAction.ASSIGN:
        container[last_key] = awaitable
    else:
        raise AssertionError(f"Unsupported awaitable action: {awaitable.action}")

    # Register the callback to be called when the awaitable is resolved
    self._add_callback_to_awaitable(awaitable)
    # add only if everything went ok, otherwise we end up in an inconsistent state
    self._awaitables.append(awaitable)
    self._update_process_status()
def _resolve_awaitable(self, awaitable: Awaitable, value: t.Any) -> None:
    """Resolve an awaitable by storing ``value`` at its location in the context.

    Precondition: must be an awaitable that was previously inserted.

    :param awaitable: the awaitable to resolve
    :param value: the value to assign to the awaitable
    :raises AssertionError: if the action is unsupported or, for ``APPEND``, the placeholder is not found
    """
    ctx, key = self._resolve_nested_context(awaitable.key)

    if awaitable.action == AwaitableAction.ASSIGN:
        ctx[key] = value
    elif awaitable.action == AwaitableAction.APPEND:
        # Find the same awaitable inserted in the context
        container = ctx[key]
        for index, placeholder in enumerate(container):
            if isinstance(placeholder, Awaitable) and placeholder.pk == awaitable.pk:
                container[index] = value
                break
        else:
            raise AssertionError(
                f"Awaitable `{awaitable.pk}` was not in `ctx.{awaitable.key}`"
            )
    else:
        raise AssertionError(f"Unsupported awaitable action: {awaitable.action}")

    awaitable.resolved = True
    # remove the awaitable from the list of pending awaitables
    self._awaitables = [a for a in self._awaitables if a.pk != awaitable.pk]

    if not self.has_terminated():
        # the process may be terminated, for example, if the process was killed or excepted
        # then we should not try to update it
        self._update_process_status()


@Protect.final
def to_context(self, **kwargs: Awaitable | ProcessNode) -> None:
    """Add a dictionary of awaitables to the context.

    Each keyword becomes the context key of the awaitable constructed from its
    value. The ``workgraph_pk`` attribute of the value is copied onto the
    awaitable so callbacks can find the owning workgraph.
    """
    for key, value in kwargs.items():
        awaitable = construct_awaitable(value)
        awaitable.key = key
        awaitable.workgraph_pk = value.workgraph_pk
        self._insert_awaitable(awaitable)
def _update_process_status(self) -> None:
    """Set the process status with a message listing the sub processes that are being waited for."""
    if not self._awaitables:
        self.node.set_process_status(None)
        return
    pending = ", ".join(str(awaitable.pk) for awaitable in self._awaitables)
    self.node.set_process_status(f"Waiting for child processes: {pending}")


@override
def run(self) -> t.Any:
    """Set up the context and execute the first step."""
    self.setup()
    return self._do_step()


def _do_step(self) -> t.Any:
    """Execute the next step in the workgraph and return the result.

    If any awaitables are registered the process enters the Wait state,
    otherwise it goes to Continue.
    """
    # The awaitables are intentionally not cleared here: the workgraph is
    # resumed from the callback function even if some awaitables are left.
    return (
        Wait(self._do_step, "Waiting before next step")
        if self._awaitables
        else Continue(self._do_step)
    )
def _store_nodes(self, data: t.Any) -> None:
    """Recurse through a data structure and store any unstored nodes that are found along the way

    :param data: a data structure potentially containing unstored nodes
    """
    if isinstance(data, Node) and not data.is_stored:
        data.store()
        return
    if isinstance(data, collections.abc.Mapping):
        for value in data.values():
            self._store_nodes(value)
        return
    # Strings are sequences too, but recursing into them makes no sense.
    if isinstance(data, collections.abc.Sequence) and not isinstance(data, str):
        for item in data:
            self._store_nodes(item)


@override
@Protect.final
def on_exiting(self) -> None:
    """Ensure that any unstored nodes in the context are stored, before the state is exited

    After the state is exited the next state will be entered and if persistence is enabled, a checkpoint will
    be saved. If the context contains unstored nodes, the serialization necessary for checkpointing will fail.
    """
    super().on_exiting()
    try:
        self._store_nodes(self.ctx)
    except Exception:  # pylint: disable=broad-except
        # An uncaught exception here will have bizarre and disastrous consequences
        self.logger.exception("exception in _store_nodes called in on_exiting")


@Protect.final
def on_wait(self, awaitables: t.Sequence[t.Awaitable]):
    """Entering the WAITING state."""
    super().on_wait(awaitables)
    if not self._awaitables:
        # Nothing to wait for: schedule an immediate resume.
        self.call_soon(self.resume)
        return
    self._action_awaitables()
    self.report(f"Process status: {self.node.process_status}")


def _action_awaitables(self) -> None:
    """Register callbacks for the awaitables that do not have one yet.

    An awaitable whose pk is already listed in the ``_awaitable_actions`` of
    its workgraph is skipped.
    """
    for awaitable in self._awaitables:
        registered = self.ctx._workgraph[awaitable.workgraph_pk]["_awaitable_actions"]
        if awaitable.pk not in registered:
            self._add_callback_to_awaitable(awaitable)
def _add_callback_to_awaitable(self, awaitable: Awaitable) -> None:
    """Add a callback to the awaitable and record it in its workgraph's ``_awaitable_actions``."""
    pk = awaitable.workgraph_pk
    if awaitable.target == AwaitableTarget.PROCESS:
        callback = functools.partial(
            self.call_soon, self._on_awaitable_finished, awaitable
        )
        self.runner.call_on_process_finish(awaitable.pk, callback)
        self.ctx._workgraph[pk]["_awaitable_actions"].append(awaitable.pk)
    elif awaitable.target == "asyncio.tasks.Task":
        # this is an awaitable task, the callback function is already set
        self.ctx._workgraph[pk]["_awaitable_actions"].append(awaitable.pk)
    else:
        # The previous `assert "<non-empty string>"` could never fail, so any
        # other target was silently ignored. `setup` relies on this for its
        # placeholder awaitable with target "scheduler", so this stays a no-op
        # and is only made visible in the debug log.
        self.logger.debug(
            f"no callback registered for awaitable target '{awaitable.target}'"
        )


def _on_awaitable_finished(self, awaitable: Awaitable) -> None:
    """Callback function, for when an awaitable process instance is completed.

    The awaitable is resolved on the context and removed from the internal
    list, after which the state of the corresponding task is updated.

    :param awaitable: an Awaitable instance
    :raises ValueError: if an integer pk cannot be resolved to a node
    """
    import time

    print(f"Awaitable {awaitable.key} finished.")
    self.logger.debug(f"Awaitable {awaitable.key} finished.")
    pk = awaitable.workgraph_pk

    if isinstance(awaitable.pk, int):
        self.logger.info(
            "received callback that awaitable with key {} and pk {} has terminated".format(
                awaitable.key, awaitable.pk
            )
        )
        # Only integer pks refer to stored nodes; loading is done here rather
        # than unconditionally so that in-memory awaitables (whose pk equals
        # their key) are not passed to `load_node`.
        try:
            node = load_node(awaitable.pk)
        except (exceptions.MultipleObjectsError, exceptions.NotExistent) as exc:
            raise ValueError(
                f"provided pk<{awaitable.pk}> could not be resolved to a valid Node instance"
            ) from exc

        # there is a bug in aiida.engine.process, it sends the msg before setting the process state and outputs
        # so we need to wait for a while
        # TODO make a PR to fix the aiida-core, the `super().on_entered()` should be
        # called after setting the process state
        tstart = time.time()
        while not node.is_finished and time.time() - tstart < 5:
            time.sleep(0.1)

        if awaitable.outputs:
            value = {
                entry.link_label: entry.node
                for entry in node.base.links.get_outgoing()
            }
        else:
            value = node  # type: ignore
    else:
        # In this case, the pk and key are the same.
        self.logger.info(
            "received callback that awaitable {} has terminated".format(
                awaitable.key
            )
        )
        try:
            # if awaitable is cancelled, the result is None
            if awaitable.cancelled():
                self.set_task_state_info(pk, awaitable.key, "state", "KILLED")
                # set child tasks state to SKIPPED
                self.set_tasks_state(
                    pk,
                    self.ctx._workgraph[pk]["_connectivity"]["child_node"][
                        awaitable.key
                    ],
                    "SKIPPED",
                )
                self.report(f"Task: {awaitable.key} cancelled.")
            else:
                results = awaitable.result()
                self.set_normal_task_results(
                    awaitable.workgraph_pk, awaitable.key, results
                )
        except Exception as e:
            self.logger.error(f"Error in awaitable {awaitable.key}: {e}")
            self.set_task_state_info(pk, awaitable.key, "state", "FAILED")
            # set child tasks state to SKIPPED
            self.set_tasks_state(
                pk,
                self.ctx._workgraph[pk]["_connectivity"]["child_node"][
                    awaitable.key
                ],
                "SKIPPED",
            )
            self.report(f"Task: {awaitable.key} failed.")
            self.run_error_handlers(pk, awaitable.key)
        value = None

    self._resolve_awaitable(awaitable, value)

    # node finished, update the task state and result
    print(f"Update task state: {awaitable.key}")
    self.update_task_state(awaitable.workgraph_pk, awaitable.key)
+ self.logger.info( + "received callback that awaitable {} has terminated".format( + awaitable.key + ) + ) + try: + # if awaitable is cancelled, the result is None + if awaitable.cancelled(): + self.set_task_state_info(pk, awaitable.key, "state", "KILLED") + # set child tasks state to SKIPPED + self.set_tasks_state( + pk, + self.ctx._workgraph[pk]["_connectivity"]["child_node"][ + awaitable.key + ], + "SKIPPED", + ) + self.report(f"Task: {awaitable.key} cancelled.") + else: + results = awaitable.result() + self.set_normal_task_results( + awaitable.workgraph_pk, awaitable.key, results + ) + except Exception as e: + self.logger.error(f"Error in awaitable {awaitable.key}: {e}") + self.set_task_state_info(pk, awaitable.key, "state", "FAILED") + # set child tasks state to SKIPPED + self.set_tasks_state( + pk, + self.ctx._workgraph[pk]["_connectivity"]["child_node"][ + awaitable.key + ], + "SKIPPED", + ) + self.report(f"Task: {awaitable.key} failed.") + self.run_error_handlers(pk, awaitable.key) + value = None + + self._resolve_awaitable(awaitable, value) + + # node finished, update the task state and result + # udpate the task state + print(f"Update task state: {awaitable.key}") + self.update_task_state(awaitable.workgraph_pk, awaitable.key) + # try to resume the workgraph, if the workgraph is already resumed + # by other awaitable, this will not work + # try: + # self.resume() + # except Exception as e: + # print(e) + + def _build_process_label(self) -> str: + """Use the workgraph name as the process label.""" + return "Scheduler" + + def on_create(self) -> None: + """Called when a Process is created.""" + + super().on_create() + self.node.label = "Scheduler" + + def setup(self) -> None: + """Setup the variables in the context.""" + # track if the awaitable callback is added to the runner + + self.ctx.launched_workgraphs = [] + self.ctx._workgraph = {} + self.ctx._max_number_awaitables = 10000 + awaitable = Awaitable( + **{ + "workgraph_pk": self.node.pk, + "pk": 
def launch_workgraph(self, pk: int) -> None:
    """Launch the workgraph stored in the node with the given pk.

    :param pk: pk of the node holding the workgraph data
    """
    self.report(f"Launch workgraph: {pk}")
    # Record the pk once only: `load_instance_state` re-launches workgraphs
    # that are already listed, and `continue_workgraph` removes a single entry.
    if pk not in self.ctx.launched_workgraphs:
        self.ctx.launched_workgraphs.append(pk)
    # mirror the list in the extras so it is available when the state is reloaded
    self.node.base.extras.set("_launched_workgraphs", self.ctx.launched_workgraphs)
    self.init_ctx_workgraph(pk)
    self.ctx._workgraph[pk]["_node"].set_process_state(Running.LABEL)
    self.init_task_results(pk)
    self.continue_workgraph(pk)


def init_ctx_workgraph(self, pk: int) -> None:
    """Init the context entry of a workgraph from its stored data.

    :param pk: pk of the node holding the workgraph data
    """
    from aiida_workgraph.utils import update_nested_dict

    # read the latest workgraph data
    wgdata, node = self.read_wgdata_from_base(pk)
    self.ctx._workgraph[pk] = {
        # a list, consistent with `setup_ctx_workgraph` and the `.append` calls on it
        "_awaitable_actions": [],
        "_new_data": {},
        "_execution_count": 1,
        "_executed_tasks": [],
        "_count": 0,
        "_context": {},
        "_node": node,
    }
    for key, value in wgdata["context"].items():
        key = key.replace("__", ".")
        update_nested_dict(self.ctx._workgraph[pk], key, value)
    # set up the workgraph
    self.setup_ctx_workgraph(pk, wgdata)


def setup_ctx_workgraph(self, pk: int, wgdata: t.Dict[str, t.Any]) -> None:
    """Move the sections of ``wgdata`` into the context entry of the workgraph.

    The popped sections are removed from ``wgdata``; what remains is stored
    under ``_workgraph``.

    :param pk: pk of the workgraph
    :param wgdata: the workgraph data; it is modified in place
    """
    import cloudpickle as pickle

    workgraph_ctx = self.ctx._workgraph[pk]
    workgraph_ctx["_tasks"] = wgdata.pop("tasks")
    workgraph_ctx["_links"] = wgdata.pop("links")
    workgraph_ctx["_connectivity"] = wgdata.pop("connectivity")
    workgraph_ctx["_ctrl_links"] = wgdata.pop("ctrl_links")
    # NOTE(review): unpickling executes arbitrary code; this is only safe if the
    # error handlers always originate from a trusted source - confirm.
    workgraph_ctx["_error_handlers"] = pickle.loads(wgdata.pop("error_handlers"))
    workgraph_ctx["_metadata"] = wgdata.pop("metadata")
    workgraph_ctx["_workgraph"] = wgdata
    workgraph_ctx["_awaitable_actions"] = []
def read_wgdata_from_base(self, pk: int) -> t.Tuple[t.Dict[str, t.Any], t.Any]:
    """Read workgraph data from base.extras.

    :param pk: pk of the node holding the workgraph data
    :return: a tuple of the deserialized workgraph data and the loaded node
    """
    from aiida_workgraph.orm.function_data import PickledLocalFunction

    node = load_node(pk)

    wgdata = node.base.extras.get("_workgraph")
    for name, task in wgdata["tasks"].items():
        wgdata["tasks"][name] = deserialize_unsafe(task)
        for prop in wgdata["tasks"][name]["properties"].values():
            # unwrap pickled local functions to the wrapped value
            if isinstance(prop["value"], PickledLocalFunction):
                prop["value"] = prop["value"].value
    wgdata["error_handlers"] = deserialize_unsafe(wgdata["error_handlers"])
    wgdata["context"] = deserialize_unsafe(wgdata["context"])

    return wgdata, node


def update_workgraph_from_base(self, pk: int) -> None:
    """Update the ctx from base.extras, keeping the task results already in the ctx.

    :param pk: pk of the workgraph
    """
    # `read_wgdata_from_base` requires the pk of the node to read from.
    wgdata, _ = self.read_wgdata_from_base(pk)
    for name, task in wgdata["tasks"].items():
        task["results"] = self.ctx._workgraph[pk]["_tasks"][name].get("results")
    self.setup_ctx_workgraph(pk, wgdata)


def get_task(self, name: str, pk: t.Optional[int] = None):
    """Get task from the context.

    :param name: name of the task
    :param pk: pk of the workgraph that owns the task
    :raises ValueError: if ``pk`` is not given; the task cannot be located without it
    """
    if pk is None:
        raise ValueError(
            f"the pk of the workgraph is required to get task `{name}`"
        )
    return Task.from_dict(self.ctx._workgraph[pk]["_tasks"][name])


def update_task(self, pk, task: "Task"):
    """Update task in the context and reset it.

    This is used in error handlers to update the task parameters.

    :param pk: pk of the workgraph that owns the task
    :param task: the task whose properties replace the stored ones
    """
    self.ctx._workgraph[pk]["_tasks"][task.name][
        "properties"
    ] = task.properties_to_dict()
    # `reset_task` takes the workgraph pk as first argument.
    self.reset_task(pk, task.name)


def get_task_state_info(self, pk: int, name: str, key: str) -> t.Any:
    """Get task state info from ctx.

    :return: the stored value, deserialized when ``key`` is "process", or None if absent
    """
    value = self.ctx._workgraph[pk]["_tasks"][name].get(key, None)
    if key == "process" and value is not None:
        value = deserialize_unsafe(value)
    return value


def set_task_state_info(self, pk: int, name: str, key: str, value: t.Any) -> None:
    """Set task state info to ctx and base.extras.

    The state is also written to the extras of the workgraph node so that it
    can be accessed outside the engine. Values for key "process" are
    serialized first.
    """
    if key == "process":
        value = serialize(value)
    self.ctx._workgraph[pk]["_node"].base.extras.set(f"_task_{key}_{name}", value)
    self.ctx._workgraph[pk]["_tasks"][name][key] = value
def init_task_results(self, pk) -> None:
    """Init the task results of all tasks of a workgraph.

    Tasks whose stored action is "RESET" are reset first.
    """
    for name, task in self.ctx._workgraph[pk]["_tasks"].items():
        # the action may be unset (None), which is treated as no action
        action = self.get_task_state_info(pk, name, "action") or ""
        if action.upper() == "RESET":
            self.reset_task(pk, task["name"])
        # only init the task results, and do not need to continue the workgraph
        self.update_task_state(pk, name, continue_workgraph=False)


def apply_action(self, msg: dict) -> None:
    """Dispatch a message by its ``catalog``; only "task" is handled."""
    if msg["catalog"] == "task":
        self.apply_task_actions(msg)
    else:
        self.report(f"Unknow message type {msg}")


def apply_task_actions(self, msg: dict) -> None:
    """Apply task actions to the workgraph.

    :param msg: a message with an ``action`` and a list of ``tasks``. For
        "LAUNCH_WORKGRAPH" the entries of ``tasks`` are workgraph pks, for the
        other actions they are task names.
    """
    action = msg["action"].upper()
    tasks = msg["tasks"]
    # NOTE(review): the task-level helpers need the pk of the owning workgraph,
    # but nothing visible here defines where the message carries it; this
    # assumes a "pk" entry - confirm against the sender of the message.
    pk = msg.get("pk")
    self.report(f"Action: {msg['action']}. {tasks}")
    if action == "RESET":
        for name in tasks:
            self.reset_task(pk, name)
    elif action == "PAUSE":
        for name in tasks:
            self.pause_task(name, pk=pk)
    elif action == "PLAY":
        for name in tasks:
            self.play_task(name, pk=pk)
    elif action == "SKIP":
        for name in tasks:
            self.skip_task(name, pk=pk)
    elif action == "KILL":
        for name in tasks:
            self.kill_task(pk, name)
    elif action == "LAUNCH_WORKGRAPH":
        for workgraph_pk in tasks:
            self.launch_workgraph(workgraph_pk)


def reset_task(
    self,
    pk: int,
    name: str,
    reset_process: bool = True,
    recursive: bool = True,
    reset_execution_count: bool = True,
) -> None:
    """Reset task state and remove it from the executed tasks.

    :param pk: pk of the workgraph that owns the task
    :param name: name of the task
    :param reset_process: also clear the stored process of the task
    :param recursive: also reset the direct child tasks (non-recursively)
    :param reset_execution_count: for a WHILE task, set its execution count back to 0
    """
    self.set_task_state_info(pk, name, "state", "PLANNED")
    if reset_process:
        self.set_task_state_info(pk, name, "process", None)
    self.remove_executed_task(pk, name)

    task = self.ctx._workgraph[pk]["_tasks"][name]
    # if the task is a while task, reset the tasks inside it
    if task["metadata"]["node_type"].upper() == "WHILE":
        if reset_execution_count:
            task["execution_count"] = 0
        for child_task in task["children"]:
            self.reset_task(pk, child_task, reset_process=False, recursive=False)
    if recursive:
        # reset its child tasks
        child_names = self.ctx._workgraph[pk]["_connectivity"]["child_node"][name]
        for child_name in child_names:
            self.reset_task(pk, child_name, recursive=False)
def _require_workgraph_pk(pk, name: str, action: str):
    """Return ``pk``, raising a clear error when the owning workgraph is not specified."""
    if pk is None:
        raise ValueError(
            f"the pk of the workgraph is required to {action} task `{name}`"
        )
    return pk


def pause_task(self, name: str, pk: t.Optional[int] = None) -> None:
    """Pause task.

    :param name: name of the task
    :param pk: pk of the workgraph that owns the task
    :raises ValueError: if ``pk`` is not given
    """
    pk = _require_workgraph_pk(pk, name, "pause")
    self.set_task_state_info(pk, name, "action", "PAUSE")
    self.report(f"Task {name} action: PAUSE.")


def play_task(self, name: str, pk: t.Optional[int] = None) -> None:
    """Play task, i.e. clear its pending action.

    :param name: name of the task
    :param pk: pk of the workgraph that owns the task
    :raises ValueError: if ``pk`` is not given
    """
    pk = _require_workgraph_pk(pk, name, "play")
    self.set_task_state_info(pk, name, "action", "")
    self.report(f"Task {name} action: PLAY.")


def skip_task(self, name: str, pk: t.Optional[int] = None) -> None:
    """Skip task.

    :param name: name of the task
    :param pk: pk of the workgraph that owns the task
    :raises ValueError: if ``pk`` is not given
    """
    pk = _require_workgraph_pk(pk, name, "skip")
    self.set_task_state_info(pk, name, "state", "SKIPPED")
    self.report(f"Task {name} action: SKIP.")


def kill_task(self, pk, name: str) -> None:
    """Kill task.

    This is used to kill the awaitable and monitor task. Tasks that are not
    RUNNING, or that are of another node type, are left untouched.
    """
    if self.get_task_state_info(pk, name, "state") != "RUNNING":
        return
    node_type = self.ctx._workgraph[pk]["_tasks"][name]["metadata"]["node_type"]
    if node_type.upper() not in ("AWAITABLE", "MONITOR"):
        return
    try:
        self._temp["awaitables"][name].cancel()
        self.set_task_state_info(pk, name, "state", "KILLED")
        self.report(f"Task {name} action: KILLED.")
    except Exception as e:
        # best effort: a failure to cancel is logged, not raised
        self.logger.error(f"Error in killing task {name}: {e}")
def report(self, msg, pk=None):
    """Report the message.

    :param msg: the message to log
    :param pk: if given, the message is logged on the node of that workgraph
        instead of on the scheduler itself
    """
    if pk:
        self.ctx._workgraph[pk]["_node"].logger.log(LOG_LEVEL_REPORT, msg)
    else:
        super().report(msg)


def continue_workgraph(self, pk: int) -> None:
    """Finalize the workgraph if it is finished, otherwise run the tasks that are ready.

    :param pk: pk of the workgraph
    """
    if pk not in self.ctx._workgraph:
        # the workgraph is finished and was already removed from the context
        return
    is_finished, _ = self.is_workgraph_finished(pk)
    if is_finished:
        self.finalize_workgraph(pk)
        # remove the workgraph from the context
        del self.ctx._workgraph[pk]
        self.ctx.launched_workgraphs.remove(pk)
        self.node.base.extras.set(
            "_launched_workgraphs", self.ctx.launched_workgraphs
        )
        return

    self.report("Continue.", pk)
    settled_states = ("CREATED", "RUNNING", "FINISHED", "FAILED", "SKIPPED")
    task_to_run = []
    for name, task in self.ctx._workgraph[pk]["_tasks"].items():
        # skip tasks that are settled or were already executed
        if (
            self.get_task_state_info(pk, task["name"], "state") in settled_states
            or name in self.ctx._workgraph[pk]["_executed_tasks"]
        ):
            continue
        ready, _ = self.is_task_ready_to_run(pk, name)
        if ready:
            task_to_run.append(name)

    self.report("tasks ready to run: {}".format(",".join(task_to_run)), pk)
    if task_to_run:
        self.run_tasks(pk, task_to_run)


def update_task_state(
    self, pk: int, name: str, continue_workgraph: bool = True
) -> None:
    """Update task state when the task is finished.

    :param pk: pk of the workgraph that owns the task
    :param name: name of the task
    :param continue_workgraph: whether to continue the workgraph afterwards
    """
    task = self.ctx._workgraph[pk]["_tasks"][name]
    node = self.get_task_state_info(pk, name, "process")
    if isinstance(node, orm.ProcessNode):
        state = node.process_state.value.upper()
        if node.is_finished_ok:
            self.set_task_state_info(pk, task["name"], "state", state)
            if task["metadata"]["node_type"].upper() == "WORKGRAPH":
                # expose the outputs of all the tasks in the workgraph
                task["results"] = {}
                outgoing = node.base.links.get_outgoing()
                for link in outgoing.all():
                    if isinstance(link.node, ProcessNode) and getattr(
                        link.node, "process_state", False
                    ):
                        task["results"][link.link_label] = link.node.outputs
            else:
                task["results"] = node.outputs
            self.set_task_state_info(pk, task["name"], "state", "FINISHED")
            self.task_set_context(pk, name)
            self.report(f"Task: {name} finished.", pk=pk)
        # all other states are considered as failed
        else:
            print(f"set task result: {name} failed")
            task["results"] = node.outputs
            self.set_task_state_info(pk, task["name"], "state", "FAILED")
            # set child tasks state to SKIPPED
            self.set_tasks_state(
                pk,
                self.ctx._workgraph[pk]["_connectivity"]["child_node"][name],
                "SKIPPED",
            )
            self.report(f"Task: {name} failed.", pk)
            self.run_error_handlers(pk, name)
    elif isinstance(node, orm.Data):
        task["results"] = {task["outputs"][0]["name"]: node}
        self.set_task_state_info(pk, task["name"], "state", "FINISHED")
        self.task_set_context(pk, name)
        self.report(f"Task: {name} finished.", pk)
    else:
        task.setdefault("results", None)

    self.update_parent_task_state(pk, name)
    self.save_workgraph_checkpoint(pk)
    if continue_workgraph:
        try:
            self.continue_workgraph(pk)
        except Exception:
            # Still best effort, but keep the traceback instead of printing
            # only the message.
            self.logger.exception(f"Error while continuing workgraph {pk}")
def set_normal_task_results(self, pk, name, results):
    """Set the results of a normal task.

    A normal task is created by decorating a function with @task().

    :param pk: pk of the workgraph that owns the task
    :param name: name of the task
    :param results: a tuple (one value per output, in order), a dict (used as
        the results as is) or a single value (assigned to the first output)
    :return: the ``OUTPUS_NOT_MATCH_RESULTS`` exit code if a tuple does not have
        one value per output, otherwise None
    """
    task = self.ctx._workgraph[pk]["_tasks"][name]
    outputs = task["outputs"]
    if isinstance(results, tuple) and len(outputs) != len(results):
        return self.exit_codes.OUTPUS_NOT_MATCH_RESULTS

    if isinstance(results, dict):
        task["results"] = results
    else:
        # `results` may be missing or None (`update_task_state` defaults it to
        # None), in which case item assignment would fail.
        if not isinstance(task.get("results"), dict):
            task["results"] = {}
        if isinstance(results, tuple):
            for output, value in zip(outputs, results):
                task["results"][output["name"]] = value
        else:
            task["results"][outputs[0]["name"]] = results

    self.task_set_context(pk, name)
    self.set_task_state_info(pk, name, "state", "FINISHED")
    self.report(f"Task: {name} finished.", pk=pk)
    self.update_parent_task_state(pk, name)
def save_workgraph_checkpoint(self, pk: int):
    """Save the serialized context entry of the workgraph in the extras of its node."""
    self.ctx._workgraph[pk]["_node"].base.extras.set(
        "_checkpoint", serialize(self.ctx._workgraph[pk])
    )


def update_parent_task_state(self, pk, name: str) -> None:
    """Update the state of the parent task (WHILE, IF or ZONE) of a task, if it has one."""
    tasks = self.ctx._workgraph[pk]["_tasks"]
    parent_name = tasks[name]["parent_task"][0]
    if not parent_name:
        return
    task_type = tasks[parent_name]["metadata"]["node_type"].upper()
    if task_type == "WHILE":
        self.update_while_task_state(pk, parent_name)
    elif task_type in ("IF", "ZONE"):
        self.update_zone_task_state(pk, parent_name)


def update_while_task_state(self, pk: int, name: str) -> None:
    """Update while task state.

    Once all children are finished, the condition tasks and the while task
    itself are reset so that the next iteration can run.
    """
    finished, _ = self.are_childen_finished(pk, name)
    if not finished:
        return

    self.report(
        f"Wihle Task {name}: this iteration finished. Try to reset for the next iteration.",
        pk,
    )
    # reset the condition tasks
    for task_input in self.ctx._workgraph[pk]["_tasks"][name]["inputs"]:
        if task_input["name"].upper() == "CONDITIONS":
            for link in task_input["links"]:
                self.reset_task(pk, link["from_node"], recursive=False)
    # reset the task and all its children, so that the task can run again
    # do not reset the execution count
    self.reset_task(pk, name, reset_execution_count=False)
def update_zone_task_state(self, pk: int, name: str) -> None:
    """Update zone task state: mark it FINISHED once all its children are finished."""
    finished, _ = self.are_childen_finished(pk, name)
    if finished:
        self.set_task_state_info(pk, name, "state", "FINISHED")
        self.update_parent_task_state(pk, name)
        self.report(f"Task: {name} finished.", pk)


def should_run_while_task(self, pk: int, name: str) -> bool:
    """Check if the while task should run.

    :return: True if the maximum number of iterations is not reached and none
        of the conditions is False
    """
    task = self.ctx._workgraph[pk]["_tasks"][name]
    not_excess_max_iterations = (
        task["execution_count"] < task["properties"]["max_iterations"]["value"]
    )
    conditions = [not_excess_max_iterations]
    _, kwargs, _, _, _ = self.get_inputs(pk, name)
    raw_conditions = kwargs["conditions"]
    if isinstance(raw_conditions, list):
        # entries of a list are keys that are looked up in the context
        for condition in raw_conditions:
            conditions.append(get_nested_dict(self.ctx, condition))
    elif isinstance(raw_conditions, dict):
        conditions.extend(raw_conditions.values())
    else:
        conditions.append(raw_conditions)
    return False not in conditions


def should_run_if_task(self, name: str, pk: t.Optional[int] = None) -> t.Any:
    """Check if the IF task should run.

    :param name: name of the IF task
    :param pk: pk of the workgraph that owns the task
    :return: the ``conditions`` input, negated when ``invert_condition`` is set
    :raises ValueError: if ``pk`` is not given
    """
    if pk is None:
        raise ValueError(
            f"the pk of the workgraph is required to evaluate task `{name}`"
        )
    _, kwargs, _, _, _ = self.get_inputs(pk, name)
    flag = kwargs["conditions"]
    if kwargs["invert_condition"]:
        return not flag
    return flag


def are_childen_finished(self, pk, name: str) -> t.Tuple[bool, t.Any]:
    """Check if the child tasks are finished.

    :return: a tuple whose first item is True if every child is FINISHED,
        SKIPPED or FAILED; the second item is always None
    """
    done_states = ("FINISHED", "SKIPPED", "FAILED")
    children = self.ctx._workgraph[pk]["_tasks"][name]["children"]
    finished = all(
        self.get_task_state_info(pk, child, "state") in done_states
        for child in children
    )
    return finished, None
def run_error_handlers(self, pk: int, task_name: str) -> None:
    """Run the error handlers registered for a task.

    A handler runs when the exit status of the task's process is listed in the
    handler's ``exit_codes`` for that task and its retry budget is not used up.
    """
    node = self.get_task_state_info(pk, task_name, "process")
    if not node or not node.exit_status:
        return
    for data in self.ctx._workgraph[pk]["_error_handlers"].values():
        if task_name not in data["tasks"]:
            continue
        handler = data["handler"]
        metadata = data["tasks"][task_name]
        if node.exit_code.status in metadata.get("exit_codes", []):
            self.report(f"Run error handler: {metadata}", pk)
            metadata.setdefault("retry", 0)
            # a handler without `max_retries` gets no retries instead of a KeyError
            if metadata["retry"] < metadata.get("max_retries", 0):
                handler(self, task_name, **metadata.get("kwargs", {}))
                metadata["retry"] += 1


def is_workgraph_finished(self, pk) -> t.Tuple[bool, t.Any]:
    """Check if the workgraph is finished.

    For `while` and `for` workgraphs the conditions are checked as well.

    :return: a tuple ``(is_finished, result)``; ``result`` is an exit code with
        status 302 if the workgraph finished with failed tasks, otherwise None
    """
    is_finished = True
    failed_tasks = []
    for name, task in self.ctx._workgraph[pk]["_tasks"].items():
        state = self.get_task_state_info(pk, task["name"], "state")
        if state in ("RUNNING", "CREATED", "PLANNED", "READY"):
            is_finished = False
        elif state == "FAILED":
            failed_tasks.append(name)

    if is_finished:
        workgraph_type = self.ctx._workgraph[pk]["_workgraph"]["workgraph_type"].upper()
        if workgraph_type == "WHILE":
            is_finished = not self.check_while_conditions(pk)
        if workgraph_type == "FOR":
            is_finished = not self.check_for_conditions(pk)

    if is_finished and failed_tasks:
        message = f"WorkGraph finished, but tasks: {failed_tasks} failed. Thus all their child tasks are skipped."
        self.report(message, pk)
        result = ExitCode(302, message)
    else:
        result = None
    return is_finished, result
def check_while_conditions(self, pk: int) -> bool:
    """Check while conditions.

    Run all condition tasks and check if all the conditions are True.

    :return: True if the workgraph should run another iteration
    """
    workgraph = self.ctx._workgraph[pk]
    if workgraph["_execution_count"] >= workgraph["_max_iteration"]:
        self.report("Max iteration reached.", pk)
        return False

    # Each condition is "<task name>.<socket name>"; "context" refers to the
    # context instead of a task, so there is nothing to run for it.
    parsed_conditions = [c.split(".") for c in workgraph["conditions"]]
    condition_tasks = [
        task_name for task_name, _ in parsed_conditions if task_name != "context"
    ]
    self.run_tasks(pk, condition_tasks, continue_workgraph=False)

    conditions = []
    for task_name, socket_name in parsed_conditions:
        if task_name == "context":
            conditions.append(self.ctx[socket_name])
        else:
            conditions.append(workgraph["_tasks"][task_name]["results"][socket_name])
    should_run = False not in conditions
    if should_run:
        self.reset_workgraph(pk)
        self.set_tasks_state(pk, condition_tasks, "SKIPPED")
    return should_run


def check_for_conditions(self, pk: int) -> bool:
    """Check the conditions of a `for` workgraph and advance its counter.

    :return: True if the workgraph should run another iteration
    """
    condition_tasks = [c[0] for c in self.ctx._workgraph[pk]["conditions"]]
    self.run_tasks(pk, condition_tasks)
    # NOTE(review): `_count` and `_sequence` are read from the scheduler context,
    # while `init_ctx_workgraph` stores a per-workgraph "_count" - confirm which
    # one is intended.
    conditions = [self.ctx._count < len(self.ctx._sequence)] + [
        self.ctx._workgraph[pk]["_tasks"][c[0]]["results"][c[1]]
        for c in self.ctx._workgraph[pk]["conditions"]
    ]
    should_run = False not in conditions
    if should_run:
        self.reset_workgraph(pk)
        self.set_tasks_state(pk, condition_tasks, "SKIPPED")
        self.ctx["i"] = self.ctx._sequence[self.ctx._count]
    self.ctx._count += 1
    return should_run


def remove_executed_task(self, pk, name: str) -> None:
    """Remove labels with name from executed tasks.

    A label matches when its part before the first dot equals ``name``.
    """
    workgraph = self.ctx._workgraph[pk]
    workgraph["_executed_tasks"] = [
        label
        for label in workgraph["_executed_tasks"]
        if label.split(".")[0] != name
    ]


def add_task_link(self, pk, node: "ProcessNode") -> None:
    """Add a CALL link from the workgraph node to a calculation or workflow node.

    Nodes of any other type are left untouched.
    """
    from aiida.common.links import LinkType

    parent_calc = self.ctx._workgraph[pk]["_node"]
    if isinstance(node, orm.CalculationNode):
        node.base.links.add_incoming(
            parent_calc, LinkType.CALL_CALC, "CALL"
        )  # TODO, self.metadata.call_link_label)
    elif isinstance(node, orm.WorkflowNode):
        node.base.links.add_incoming(
            parent_calc, LinkType.CALL_WORK, "CALL"
        )  # TODO, self.metadata.call_link_label)
ProcessNode) -> None: + from aiida.common.links import LinkType + + parent_calc = self.ctx._workgraph[pk]["_node"] + if isinstance(node, orm.CalculationNode): + node.base.links.add_incoming( + parent_calc, LinkType.CALL_CALC, "CALL" + ) # TODO, self.metadata.call_link_label) + elif isinstance(node, orm.WorkflowNode): + node.base.links.add_incoming( + parent_calc, LinkType.CALL_WORK, "CALL" + ) # TODO, self.metadata.call_link_label) + + def run_tasks( + self, pk: int, names: t.List[str], continue_workgraph: bool = True + ) -> None: + """Run tasks. + Task type includes: Node, Data, CalcFunction, WorkFunction, CalcJob, WorkChain, GraphBuilder, + WorkGraph, PythonJob, ShellJob, While, If, Zone, FromContext, ToContext, Normal. + + Here we use ToContext to pass the results of the run to the next step. + This will force the engine to wait for all the submitted processes to + finish before continuing to the next step. + """ + from aiida_workgraph.utils import ( + get_executor, + create_data_node, + update_nested_dict_with_special_keys, + ) + from aiida_workgraph.engine.workgraph import WorkGraphEngine + from aiida_workgraph.engine import launch + + for name in names: + # skip if the max number of awaitables is reached + task = self.ctx._workgraph[pk]["_tasks"][name] + if task["metadata"]["node_type"].upper() in [ + "CALCJOB", + "WORKCHAIN", + "GRAPH_BUILDER", + "WORKGRAPH", + "PYTHONJOB", + "SHELLJOB", + ]: + if len(self._awaitables) >= self.ctx._max_number_awaitables: + print( + MAX_NUMBER_AWAITABLES_MSG.format( + self.ctx._max_number_awaitables, name + ) + ) + continue + # skip if the task is already executed + if name in self.ctx._workgraph[pk]["_executed_tasks"]: + continue + self.ctx._workgraph[pk]["_executed_tasks"].append(name) + print("-" * 60) + + self.report(f"Run task: {name}, type: {task['metadata']['node_type']}", pk) + executor, _ = get_executor(task["executor"]) + # print("executor: ", executor) + args, kwargs, var_args, var_kwargs, args_dict = 
self.get_inputs(pk, name) + for i, key in enumerate( + self.ctx._workgraph[pk]["_tasks"][name]["metadata"]["args"] + ): + kwargs[key] = args[i] + # update the port namespace + kwargs = update_nested_dict_with_special_keys(kwargs) + # print("args: ", args) + # print("kwargs: ", kwargs) + # print("var_kwargs: ", var_kwargs) + # kwargs["meta.label"] = name + # output must be a Data type or a mapping of {string: Data} + task["results"] = {} + if task["metadata"]["node_type"].upper() == "NODE": + results = self.run_executor(executor, [], kwargs, var_args, var_kwargs) + self.set_task_state_info(pk, name, "process", results) + self.update_task_state(pk, name) + if continue_workgraph: + self.continue_workgraph(pk) + elif task["metadata"]["node_type"].upper() == "DATA": + for key in self.ctx._workgraph[pk]["_tasks"][name]["metadata"]["args"]: + kwargs.pop(key, None) + results = create_data_node(executor, args, kwargs) + self.set_task_state_info(pk, name, "process", results) + self.update_task_state(pk, name) + self.ctx._workgraph[pk]["_new_data"][name] = results + if continue_workgraph: + self.continue_workgraph(pk) + elif task["metadata"]["node_type"].upper() in [ + "CALCFUNCTION", + "WORKFUNCTION", + ]: + kwargs.setdefault("metadata", {}) + kwargs["metadata"].update({"call_link_label": name}) + kwargs["parent_pid"] = pk + try: + # since aiida 2.5.0, we need to use args_dict to pass the args to the run_get_node + if var_kwargs is None: + results, process = launch.run_get_node( + executor.process_class, **kwargs + ) + else: + results, process = launch.run_get_node( + executor.process_class, **kwargs, **var_kwargs + ) + process.label = name + # print("results: ", results) + self.set_task_state_info(pk, name, "process", process) + self.update_task_state(pk, name) + except Exception as e: + self.logger.error(f"Error in task {name}: {e}") + self.set_task_state_info(pk, name, "state", "FAILED") + # set child state to FAILED + self.set_tasks_state( + pk, + 
self.ctx._workgraph[pk]["_connectivity"]["child_node"][name], + "SKIPPED", + ) + self.report(f"Task: {name} failed.", pk) + # exclude the current tasks from the next run + if continue_workgraph: + self.continue_workgraph(pk) + elif task["metadata"]["node_type"].upper() in ["CALCJOB", "WORKCHAIN"]: + # process = run_get_node(executor, *args, **kwargs) + kwargs.setdefault("metadata", {}) + kwargs["metadata"].update({"call_link_label": name}) + kwargs["parent_pid"] = pk + try: + # transfer the args to kwargs + if self.get_task_state_info(pk, name, "action").upper() == "PAUSE": + self.set_task_state_info(pk, name, "action", "") + self.report(f"Task {name} is created and paused.", pk) + process = create_and_pause_process( + self.runner, + executor, + kwargs, + state_msg="Paused through WorkGraph", + ) + self.set_task_state_info(pk, name, "state", "CREATED") + process = process.node + else: + process = launch.submit(executor, **kwargs) + self.set_task_state_info(pk, name, "state", "RUNNING") + process.label = name + process.workgraph_pk = pk + self.set_task_state_info(pk, name, "process", process) + self.to_context(**{name: process}) + except Exception as e: + self.set_task_state_info(pk, name, "state", "FAILED") + # set child state to FAILED + self.set_tasks_state( + pk, + self.ctx._workgraph[pk]["_connectivity"]["child_node"][name], + "SKIPPED", + ) + self.report(f"Error in task {name}: {e}", pk) + self.continue_workgraph(pk) + elif task["metadata"]["node_type"].upper() in ["GRAPH_BUILDER"]: + from aiida_workgraph.utils.control import create_workgraph_action + + wg = self.run_executor(executor, [], kwargs, var_args, var_kwargs) + wg.name = name + wg.group_outputs = self.ctx._workgraph[pk]["_tasks"][name]["metadata"][ + "group_outputs" + ] + metadata = {"call_link_label": name} + try: + wg.save(metadata=metadata, parent_pid=pk) + process = wg.process + create_workgraph_action(process.pk) + process.workgraph_pk = pk + self.set_task_state_info(pk, name, "process", 
process) + self.set_task_state_info(pk, name, "state", "RUNNING") + self.report(f"GraphBuilder {name} is created.", pk) + self.to_context(**{name: process}) + except Exception as e: + print("Error in launching the workgraph: ", e) + elif task["metadata"]["node_type"].upper() in ["WORKGRAPH"]: + from .utils import prepare_for_workgraph_task + + inputs, _ = prepare_for_workgraph_task(task, kwargs) + process = launch.submit(WorkGraphEngine, inputs=inputs, parent_pid=pk) + process.workgraph_pk = pk + self.set_task_state_info(pk, name, "process", process) + self.set_task_state_info(pk, name, "state", "RUNNING") + self.to_context(**{name: process}) + elif task["metadata"]["node_type"].upper() in ["PYTHONJOB"]: + from aiida_workgraph.calculations.python import PythonJob + from .utils import prepare_for_python_task + + inputs = prepare_for_python_task(task, kwargs, var_kwargs) + # since aiida 2.5.0, we can pass inputs directly to the submit, no need to use **inputs + if self.get_task_state_info(pk, name, "action").upper() == "PAUSE": + self.set_task_state_info(pk, name, "action", "") + self.report(f"Task {name} is created and paused.", pk) + process = create_and_pause_process( + self.runner, + PythonJob, + inputs, + state_msg="Paused through WorkGraph", + ) + self.set_task_state_info(pk, name, "state", "CREATED") + process = process.node + else: + process = launch.submit(PythonJob, inputs=inputs, parent_pid=pk) + self.set_task_state_info(pk, name, "state", "RUNNING") + process.label = name + process.workgraph_pk = pk + self.set_task_state_info(pk, name, "process", process) + self.to_context(**{name: process}) + elif task["metadata"]["node_type"].upper() in ["SHELLJOB"]: + from aiida_shell.calculations.shell import ShellJob + from .utils import prepare_for_shell_task + + inputs = prepare_for_shell_task(task, kwargs) + if self.get_task_state_info(pk, name, "action").upper() == "PAUSE": + self.set_task_state_info(pk, name, "action", "") + self.report(f"Task {name} is created 
and paused.", pk) + process = create_and_pause_process( + self.runner, + ShellJob, + inputs, + state_msg="Paused through WorkGraph", + ) + self.set_task_state_info(pk, name, "state", "CREATED") + process = process.node + else: + process = launch.submit(ShellJob, inputs=inputs, parent_pid=pk) + self.set_task_state_info(pk, name, "state", "RUNNING") + process.label = name + process.workgraph_pk = pk + self.set_task_state_info(pk, name, "process", process) + self.to_context(**{name: process}) + elif task["metadata"]["node_type"].upper() in ["WHILE"]: + # check the conditions of the while task + should_run = self.should_run_while_task(name) + if not should_run: + self.set_task_state_info(pk, name, "state", "FINISHED") + self.set_tasks_state( + pk, + self.ctx._workgraph[pk]["_tasks"][name]["children"], + "SKIPPED", + ) + self.update_parent_task_state(pk, name) + self.report( + f"While Task {name}: Condition not fullilled, task finished. Skip all its children.", + pk, + ) + else: + task["execution_count"] += 1 + self.set_task_state_info(pk, name, "state", "RUNNING") + self.continue_workgraph(pk) + elif task["metadata"]["node_type"].upper() in ["IF"]: + should_run = self.should_run_if_task(name) + if should_run: + self.set_task_state_info(pk, name, "state", "RUNNING") + else: + self.set_tasks_state(pk, task["children"], "SKIPPED") + self.update_zone_task_state(pk, name) + self.continue_workgraph(pk) + elif task["metadata"]["node_type"].upper() in ["ZONE"]: + self.set_task_state_info(pk, name, "state", "RUNNING") + self.continue_workgraph(pk) + elif task["metadata"]["node_type"].upper() in ["FROM_CONTEXT"]: + # get the results from the context + results = {"result": getattr(self.ctx, kwargs["key"])} + task["results"] = results + self.set_task_state_info(pk, name, "state", "FINISHED") + self.update_parent_task_state(pk, name) + self.continue_workgraph(pk) + elif task["metadata"]["node_type"].upper() in ["TO_CONTEXT"]: + # get the results from the context + setattr(self.ctx, 
kwargs["key"], kwargs["value"]) + self.set_task_state_info(pk, name, "state", "FINISHED") + self.update_parent_task_state(pk, name) + self.continue_workgraph(pk) + elif task["metadata"]["node_type"].upper() in ["AWAITABLE"]: + for key in self.ctx._workgraph[pk]["_tasks"][name]["metadata"]["args"]: + kwargs.pop(key, None) + awaitable_target = asyncio.ensure_future( + self.run_executor(executor, args, kwargs, var_args, var_kwargs), + loop=self.loop, + ) + awaitable = self.construct_awaitable_function(name, awaitable_target) + self.set_task_state_info(pk, name, "state", "RUNNING") + self.to_context(**{name: awaitable}) + elif task["metadata"]["node_type"].upper() in ["MONITOR"]: + + for key in self.ctx._workgraph[pk]["_tasks"][name]["metadata"]["args"]: + kwargs.pop(key, None) + # add function and interval to the args + args = [executor, kwargs.pop("interval"), kwargs.pop("timeout"), *args] + awaitable_target = asyncio.ensure_future( + self.run_executor(monitor, args, kwargs, var_args, var_kwargs), + loop=self.loop, + ) + awaitable = self.construct_awaitable_function(name, awaitable_target) + self.set_task_state_info(pk, name, "state", "RUNNING") + # save the awaitable to the temp, so that we can kill it if needed + self._temp["awaitables"][name] = awaitable_target + self.to_context(**{name: awaitable}) + elif task["metadata"]["node_type"].upper() in ["NORMAL"]: + # Normal task is created by decoratoring a function with @task() + if "context" in task["metadata"]["kwargs"]: + self.ctx.task_name = name + kwargs.update({"context": self.ctx}) + for key in self.ctx._workgraph[pk]["_tasks"][name]["metadata"]["args"]: + kwargs.pop(key, None) + try: + results = self.run_executor( + executor, args, kwargs, var_args, var_kwargs + ) + self.set_normal_task_results(pk, name, results) + except Exception as e: + self.logger.error(f"Error in task {name}: {e}") + self.set_task_state_info(pk, name, "state", "FAILED") + # set child tasks state to SKIPPED + self.set_tasks_state( + pk, + 
self.ctx._workgraph[pk]["_connectivity"]["child_node"][name], + "SKIPPED", + ) + self.report(f"Task: {name} failed.", pk) + self.run_error_handlers(pk, name) + if continue_workgraph: + self.continue_workgraph(pk) + else: + # self.report("Unknow task type {}".format(task["metadata"]["node_type"])) + return self.exit_codes.UNKNOWN_TASK_TYPE + + def construct_awaitable_function( + self, name: str, awaitable_target: Awaitable + ) -> None: + """Construct the awaitable function.""" + awaitable = Awaitable( + **{ + "pk": name, + "action": AwaitableAction.ASSIGN, + "target": "asyncio.tasks.Task", + "outputs": False, + } + ) + awaitable_target.key = name + awaitable_target.pk = name + awaitable_target.action = AwaitableAction.ASSIGN + awaitable_target.add_done_callback(self._on_awaitable_finished) + return awaitable + + def get_inputs( + self, pk: int, name: str + ) -> t.Tuple[ + t.List[t.Any], + t.Dict[str, t.Any], + t.Optional[t.List[t.Any]], + t.Optional[t.Dict[str, t.Any]], + t.Dict[str, t.Any], + ]: + """Get input based on the links.""" + + args = [] + args_dict = {} + kwargs = {} + var_args = None + var_kwargs = None + task = self.ctx._workgraph[pk]["_tasks"][name] + properties = task.get("properties", {}) + inputs = {} + for input in task["inputs"]: + # print(f"input: {input['name']}") + if len(input["links"]) == 0: + inputs[input["name"]] = self.update_context_variable( + properties[input["name"]]["value"] + ) + elif len(input["links"]) == 1: + link = input["links"][0] + if ( + self.ctx._workgraph[pk]["_tasks"][link["from_node"]]["results"] + is None + ): + inputs[input["name"]] = None + else: + # handle the special socket _wait, _outputs + if link["from_socket"] == "_wait": + continue + elif link["from_socket"] == "_outputs": + inputs[input["name"]] = self.ctx._workgraph[pk]["_tasks"][ + link["from_node"] + ]["results"] + else: + inputs[input["name"]] = get_nested_dict( + self.ctx._workgraph[pk]["_tasks"][link["from_node"]][ + "results" + ], + link["from_socket"], 
+ ) + # handle the case of multiple outputs + elif len(input["links"]) > 1: + value = {} + for link in input["links"]: + name = f'{link["from_node"]}_{link["from_socket"]}' + # handle the special socket _wait, _outputs + if link["from_socket"] == "_wait": + continue + if ( + self.ctx._workgraph[pk]["_tasks"][link["from_node"]]["results"] + is None + ): + value[name] = None + else: + value[name] = self.ctx._workgraph[pk]["_tasks"][ + link["from_node"] + ]["results"][link["from_socket"]] + inputs[input["name"]] = value + for name in task["metadata"].get("args", []): + if name in inputs: + args.append(inputs[name]) + args_dict[name] = inputs[name] + else: + value = self.update_context_variable(properties[name]["value"]) + args.append(value) + args_dict[name] = value + for name in task["metadata"].get("kwargs", []): + if name in inputs: + kwargs[name] = inputs[name] + else: + value = self.update_context_variable(properties[name]["value"]) + kwargs[name] = value + if task["metadata"]["var_args"] is not None: + name = task["metadata"]["var_args"] + if name in inputs: + var_args = inputs[name] + else: + value = self.update_context_variable(properties[name]["value"]) + var_args = value + if task["metadata"]["var_kwargs"] is not None: + name = task["metadata"]["var_kwargs"] + if name in inputs: + var_kwargs = inputs[name] + else: + value = self.update_context_variable(properties[name]["value"]) + var_kwargs = value + return args, kwargs, var_args, var_kwargs, args_dict + + def update_context_variable(self, value: t.Any) -> t.Any: + # replace context variables + + """Get value from context.""" + if isinstance(value, dict): + for key, sub_value in value.items(): + value[key] = self.update_context_variable(sub_value) + elif ( + isinstance(value, str) + and value.strip().startswith("{{") + and value.strip().endswith("}}") + ): + name = value[2:-2].strip() + return get_nested_dict(self.ctx, name) + return value + + def task_set_context(self, pk, name: str) -> None: + """Export 
task result to context.""" + from aiida_workgraph.utils import update_nested_dict + + items = self.ctx._workgraph[pk]["_tasks"][name]["context_mapping"] + for key, value in items.items(): + result = self.ctx._workgraph[pk]["_tasks"][name]["results"][key] + update_nested_dict(self.ctx, value, result) + + def is_task_ready_to_run(self, pk, name: str) -> t.Tuple[bool, t.Optional[str]]: + """Check if the task ready to run. + For normal task and a zone task, we need to check its input tasks in the connectivity["zone"]. + For task inside a zone, we need to check if the zone (parent task) is ready. + """ + parent_task = self.ctx._workgraph[pk]["_tasks"][name]["parent_task"] + # input_tasks, parent_task, conditions + parent_states = [True, True] + # if the task belongs to a parent zone + if parent_task[0]: + state = self.get_task_state_info(pk, parent_task[0], "state") + if state not in ["RUNNING"]: + parent_states[1] = False + # check the input tasks of the zone + # check if the zone input tasks are ready + for child_task_name in self.ctx._workgraph[pk]["_connectivity"]["zone"][name][ + "input_tasks" + ]: + if self.get_task_state_info(pk, child_task_name, "state") not in [ + "FINISHED", + "SKIPPED", + "FAILED", + ]: + parent_states[0] = False + break + + return all(parent_states), parent_states + + def reset_workgraph(self, pk) -> None: + self.ctx._workgraph[pk]["_execution_count"] += 1 + self.set_tasks_state(pk, self.ctx._workgraph[pk]["_tasks"].keys(), "PLANNED") + self.ctx._workgraph[pk]["_executed_tasks"] = [] + + def set_tasks_state( + self, pk: int, tasks: t.Union[t.List[str], t.Sequence[str]], value: str + ) -> None: + """Set tasks state""" + for name in tasks: + self.set_task_state_info(pk, name, "state", value) + if "children" in self.ctx._workgraph[pk]["_tasks"][name]: + self.set_tasks_state( + pk, self.ctx._workgraph[pk]["_tasks"][name]["children"], value + ) + + def run_executor( + self, + executor: t.Callable, + args: t.List[t.Any], + kwargs: t.Dict[str, 
t.Any], + var_args: t.Optional[t.List[t.Any]], + var_kwargs: t.Optional[t.Dict[str, t.Any]], + ) -> t.Any: + if var_kwargs is None: + return executor(*args, **kwargs) + else: + return executor(*args, **kwargs, **var_kwargs) + + def save_results_to_extras(self, name: str) -> None: + """Save the results to the base.extras. + For the outputs of a Normal task, they are not saved to the database like the calcjob or workchain. + One temporary solution is to save the results to the base.extras. In order to do this, we need to + serialize the results + """ + from aiida_workgraph.utils import get_executor + + results = self.ctx._workgraph[pk]["_tasks"][name]["results"] + if results is None: + return + datas = {} + for key, value in results.items(): + # find outptus sockets with the name as key + output = [ + output + for output in self.ctx._workgraph[pk]["_tasks"][name]["outputs"] + if output["name"] == key + ] + if len(output) == 0: + continue + output = output[0] + Executor, _ = get_executor(output["serialize"]) + datas[key] = Executor(value) + self.node.set_extra(f"nodes__results__{name}", datas) + + def message_receive( + self, _comm: kiwipy.Communicator, msg: t.Dict[str, t.Any] + ) -> t.Any: + """ + Coroutine called when the process receives a message from the communicator + + :param _comm: the communicator that sent the message + :param msg: the message + :return: the outcome of processing the message, the return value will be sent back as a response to the sender + """ + self.logger.debug( + "Process<%s>: received RPC message with communicator '%s': %r", + self.pid, + _comm, + msg, + ) + + intent = msg[process_comms.INTENT_KEY] + + if intent == process_comms.Intent.PLAY: + return self._schedule_rpc(self.play) + if intent == process_comms.Intent.PAUSE: + return self._schedule_rpc( + self.pause, msg=msg.get(process_comms.MESSAGE_KEY, None) + ) + if intent == process_comms.Intent.KILL: + return self._schedule_rpc( + self.kill, msg=msg.get(process_comms.MESSAGE_KEY, 
None) + ) + if intent == process_comms.Intent.STATUS: + status_info: t.Dict[str, t.Any] = {} + self.get_status_info(status_info) + return status_info + if intent == "custom": + return self._schedule_rpc(self.apply_action, msg=msg) + + # Didn't match any known intents + raise RuntimeError("Unknown intent") + + def call_on_receive_workgraph_message(self, _comm, msg): + """Call on receive workgraph message.""" + # self.report(f"Received workgraph message: {msg}") + pk = msg["args"]["pid"] + # To avoid "DbNode is not persistent", we need to schedule the call + self._schedule_rpc(self.launch_workgraph, pk=pk) + return True + + def add_workgraph_subsriber(self) -> None: + """Add workgraph subscriber.""" + queue_name = "workgraph_queue" + self.report(f"Add workgraph subscriber on queue: {queue_name}") + comm = self.runner.communicator._communicator + queue = comm.task_queue(queue_name, prefetch_count=1000) + queue.add_task_subscriber(self.call_on_receive_workgraph_message) + + def finalize_workgraph(self, pk: int) -> t.Optional[ExitCode]: + """""" + # expose outputs of the workgraph + group_outputs = {} + for output in self.ctx._workgraph[pk]["_metadata"]["group_outputs"]: + names = output["from"].split(".", 1) + if names[0] == "context": + if len(names) == 1: + raise ValueError("The output name should be context.key") + update_nested_dict( + group_outputs, + output["name"], + get_nested_dict(self.ctx, names[1]), + ) + else: + # expose the whole outputs of the tasks + if len(names) == 1: + update_nested_dict( + group_outputs, + output["name"], + self.ctx._workgraph[pk]["_tasks"][names[0]]["results"], + ) + else: + # expose one output of the task + # note, the output may not exist + if ( + names[1] + in self.ctx._workgraph[pk]["_tasks"][names[0]]["results"] + ): + update_nested_dict( + group_outputs, + output["name"], + self.ctx._workgraph[pk]["_tasks"][names[0]]["results"][ + names[1] + ], + ) + # output the new data + if self.ctx._workgraph[pk]["_new_data"]: + 
group_outputs["new_data"] = self.ctx._workgraph[pk]["_new_data"] + group_outputs["execution_count"] = orm.Int( + self.ctx._workgraph[pk]["_execution_count"] + ).store() + self.update_workgraph_output(pk, group_outputs) + self.report("Finalize workgraph.", pk) + self.report(f"Finalize workgraph {pk}.") + for name, task in self.ctx._workgraph[pk]["_tasks"].items(): + if self.get_task_state_info(pk, name, "state") == "FAILED": + self.report(f"Task {name} failed.", pk) + self.ctx._workgraph[pk]["_node"].set_exit_status(302) + self.ctx._workgraph[pk]["_node"].set_exit_message( + "Some of the tasks failed." + ) + self.broadcast_workgraph_state(pk, "excepted") + self.ctx._workgraph[pk]["_node"].set_process_state(Finished.LABEL) + self.ctx._workgraph[pk]["_node"].seal() + return + # send a broadcast message to announce the workgraph is finished + self.broadcast_workgraph_state(pk, "finished") + self.ctx._workgraph[pk]["_node"].set_process_state(Finished.LABEL) + self.ctx._workgraph[pk]["_node"].seal() + + def update_workgraph_output(self, pk, outputs: dict) -> None: + """Update the workgraph output.""" + from aiida.common.links import LinkType + + node = self.ctx._workgraph[pk]["_node"] + + def add_link(node, outputs): + for link_label, output in outputs.items(): + if isinstance(output, dict): + add_link(node, output) + if isinstance(node, orm.CalculationNode): + output.base.links.add_incoming(node, LinkType.CREATE, link_label) + elif isinstance(node, orm.WorkflowNode): + output.base.links.add_incoming(node, LinkType.RETURN, link_label) + + add_link(node, outputs) + + def broadcast_workgraph_state(self, pk: int, state: str) -> None: + """Workgraph on entered.""" + from aio_pika.exceptions import ConnectionClosed + + from_label = None + subject = f"state_changed.{from_label}.{state}" + try: + self._communicator.broadcast_send(body=None, sender=pk, subject=subject) + except ConnectionClosed: + message = "Process<%s>: no connection available to broadcast state change from %s 
to %s" + self.logger.warning(message, pk, from_label, state) + except kiwipy.TimeoutError: + message = ( + "Process<%s>: sending broadcast of state change from %s to %s timed out" + ) + self.logger.warning(message, pk, from_label, state) diff --git a/aiida_workgraph/engine/utils.py b/aiida_workgraph/engine/utils.py index 82332a78..e35f4fc5 100644 --- a/aiida_workgraph/engine/utils.py +++ b/aiida_workgraph/engine/utils.py @@ -1,6 +1,15 @@ from aiida_workgraph.orm.serializer import serialize_to_aiida_nodes from aiida import orm from aiida.common.extendeddicts import AttributeDict +from aiida.engine.utils import is_process_function +from aiida.engine.processes.builder import ProcessBuilder +from aiida.engine.processes.process import Process + +import inspect +from typing import ( + Type, + Union, +) def prepare_for_workgraph_task(task: dict, kwargs: dict) -> tuple: @@ -139,3 +148,45 @@ def prepare_for_shell_task(task: dict, kwargs: dict) -> dict: "metadata": metadata or {}, } return inputs + + +def instantiate_process( + runner: "Runner", + process: Union["Process", Type["Process"], "ProcessBuilder"], + parent_pid=None, + **inputs, +) -> "Process": + """Return an instance of the process with the given inputs. 
The function can deal with various types + of the `process`: + + * Process instance: will simply return the instance + * ProcessBuilder instance: will instantiate the Process from the class and inputs defined within it + * Process class: will instantiate with the specified inputs + + If anything else is passed, a ValueError will be raised + + :param process: Process instance or class, CalcJobNode class or ProcessBuilder instance + :param inputs: the inputs for the process to be instantiated with + """ + + if isinstance(process, Process): + assert not inputs + assert runner is process.runner + return process + + if isinstance(process, ProcessBuilder): + builder = process + process_class = builder.process_class + inputs.update(**builder._inputs(prune=True)) + elif is_process_function(process): + process_class = process.process_class # type: ignore[attr-defined] + elif inspect.isclass(process) and issubclass(process, Process): + process_class = process + else: + raise ValueError( + f"invalid process {type(process)}, needs to be Process or ProcessBuilder" + ) + + process = process_class(runner=runner, inputs=inputs, parent_pid=parent_pid) + + return process diff --git a/aiida_workgraph/engine/workgraph.py b/aiida_workgraph/engine/workgraph.py index f2eacd23..46e803ba 100644 --- a/aiida_workgraph/engine/workgraph.py +++ b/aiida_workgraph/engine/workgraph.py @@ -60,6 +60,7 @@ def __init__( logger: logging.Logger | None = None, runner: "Runner" | None = None, enable_persistence: bool = True, + **kwargs: t.Any, ) -> None: """Construct a WorkGraph instance. 
@@ -70,7 +71,9 @@ def __init__( """ - super().__init__(inputs, logger, runner, enable_persistence=enable_persistence) + super().__init__( + inputs, logger, runner, enable_persistence=enable_persistence, **kwargs + ) self._awaitables: list[Awaitable] = [] self._context = AttributeDict() diff --git a/aiida_workgraph/tasks/test.py b/aiida_workgraph/tasks/test.py index e0c2bfa4..78ab58f7 100644 --- a/aiida_workgraph/tasks/test.py +++ b/aiida_workgraph/tasks/test.py @@ -1,5 +1,6 @@ from typing import Dict from aiida_workgraph.task import Task +from aiida_workgraph.decorator import task class TestAdd(Task): @@ -114,3 +115,19 @@ def get_executor(self) -> Dict[str, str]: "name": "core.arithmetic.multiply_add", "type": "WorkflowFactory", } + + +@task.graph_builder() +def create_workgraph_recrusively(n): + from aiida_workgraph import WorkGraph + from aiida_workgraph.tasks.test import create_workgraph_recrusively + from aiida.calculations.arithmetic.add import ArithmeticAddCalculation + from aiida import orm + + code = orm.load_code("add@localhost") + wg = WorkGraph(f"n-{n}") + if n > 0: + wg.add_task(create_workgraph_recrusively, name=f"n_{n}", n=n - 1) + else: + wg.add_task(ArithmeticAddCalculation, code=code, x=1, y=2) + return wg diff --git a/aiida_workgraph/utils/control.py b/aiida_workgraph/utils/control.py index 376f8fc3..4fa50be6 100644 --- a/aiida_workgraph/utils/control.py +++ b/aiida_workgraph/utils/control.py @@ -1,6 +1,7 @@ from aiida.manage import get_manager from aiida import orm from aiida.engine.processes import control +from aiida_workgraph.engine.override import ControllerWithQueueName def create_task_action( @@ -17,6 +18,30 @@ def create_task_action( controller._communicator.rpc_send(pk, message) +def create_scheduler_action( + pk: int, +): + """Send workgraph task to scheduler.""" + + manager = get_manager() + controller = ControllerWithQueueName( + queue_name="scheduler_queue", communicator=manager.get_communicator() + ) + controller.continue_process(pk, 
nowait=False) + + +def create_workgraph_action( + pk: int, +): + """Send workgraph task to scheduler.""" + + manager = get_manager() + controller = ControllerWithQueueName( + queue_name="workgraph_queue", communicator=manager.get_communicator() + ) + controller.continue_process(pk, nowait=False) + + def get_task_state_info(node, name: str, key: str) -> str: """Get task state info from base.extras.""" from aiida.orm.utils.serialize import deserialize_unsafe diff --git a/aiida_workgraph/web/backend/app/api.py b/aiida_workgraph/web/backend/app/api.py index d05b527b..4bf69c91 100644 --- a/aiida_workgraph/web/backend/app/api.py +++ b/aiida_workgraph/web/backend/app/api.py @@ -2,6 +2,7 @@ from fastapi.middleware.cors import CORSMiddleware from aiida.manage import manager from aiida_workgraph.web.backend.app.daemon import router as daemon_router +from aiida_workgraph.web.backend.app.scheduler import router as scheduler_router from aiida_workgraph.web.backend.app.workgraph import router as workgraph_router from aiida_workgraph.web.backend.app.datanode import router as datanode_router from fastapi.staticfiles import StaticFiles @@ -47,6 +48,7 @@ async def read_root() -> dict: app.include_router(workgraph_router) app.include_router(datanode_router) app.include_router(daemon_router) +app.include_router(scheduler_router) @app.get("/debug") diff --git a/aiida_workgraph/web/backend/app/daemon.py b/aiida_workgraph/web/backend/app/daemon.py index af069d7c..caa22cf4 100644 --- a/aiida_workgraph/web/backend/app/daemon.py +++ b/aiida_workgraph/web/backend/app/daemon.py @@ -22,7 +22,7 @@ class DaemonStatusModel(BaseModel): ) -@router.get("/api/daemon/status", response_model=DaemonStatusModel) +@router.get("/api/daemon/task/status", response_model=DaemonStatusModel) @with_dbenv() async def get_daemon_status() -> DaemonStatusModel: """Return the daemon status.""" @@ -36,7 +36,7 @@ async def get_daemon_status() -> DaemonStatusModel: return DaemonStatusModel(running=True, 
num_workers=response["numprocesses"]) -@router.get("/api/daemon/worker") +@router.get("/api/daemon/task/worker") @with_dbenv() async def get_daemon_worker(): """Return the daemon status.""" @@ -50,7 +50,7 @@ async def get_daemon_worker(): return response["info"] -@router.post("/api/daemon/start", response_model=DaemonStatusModel) +@router.post("/api/daemon/task/start", response_model=DaemonStatusModel) @with_dbenv() async def get_daemon_start() -> DaemonStatusModel: """Start the daemon.""" @@ -69,7 +69,7 @@ async def get_daemon_start() -> DaemonStatusModel: return DaemonStatusModel(running=True, num_workers=response["numprocesses"]) -@router.post("/api/daemon/stop", response_model=DaemonStatusModel) +@router.post("/api/daemon/task/stop", response_model=DaemonStatusModel) @with_dbenv() async def get_daemon_stop() -> DaemonStatusModel: """Stop the daemon.""" @@ -86,7 +86,7 @@ async def get_daemon_stop() -> DaemonStatusModel: return DaemonStatusModel(running=False, num_workers=None) -@router.post("/api/daemon/increase", response_model=DaemonStatusModel) +@router.post("/api/daemon/task/increase", response_model=DaemonStatusModel) @with_dbenv() async def increase_daemon_worker() -> DaemonStatusModel: """increase the daemon worker.""" @@ -103,7 +103,7 @@ async def increase_daemon_worker() -> DaemonStatusModel: return DaemonStatusModel(running=False, num_workers=None) -@router.post("/api/daemon/decrease", response_model=DaemonStatusModel) +@router.post("/api/daemon/task/decrease", response_model=DaemonStatusModel) @with_dbenv() async def decrease_daemon_worker() -> DaemonStatusModel: """decrease the daemon worker.""" diff --git a/aiida_workgraph/web/backend/app/scheduler.py b/aiida_workgraph/web/backend/app/scheduler.py new file mode 100644 index 00000000..c4110402 --- /dev/null +++ b/aiida_workgraph/web/backend/app/scheduler.py @@ -0,0 +1,127 @@ +# -*- coding: utf-8 -*- +"""Declaration of FastAPI router for daemon endpoints.""" +from __future__ import annotations + 
+import typing as t + +from aiida.cmdline.utils.decorators import with_dbenv +from aiida.engine.daemon.client import DaemonException +from aiida_workgraph.engine.scheduler.client import get_scheduler_client +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel, Field +from aiida_workgraph.engine.scheduler.client import start_scheduler_process + + +router = APIRouter() + + +class DaemonStatusModel(BaseModel): + """Response model for daemon status.""" + + running: bool = Field(description="Whether the daemon is running or not.") + num_workers: t.Optional[int] = Field( + description="The number of workers if the daemon is running." + ) + + +@router.get("/api/daemon/scheduler/status", response_model=DaemonStatusModel) +@with_dbenv() +async def get_daemon_status() -> DaemonStatusModel: + """Return the daemon status.""" + client = get_scheduler_client() + + if not client.is_daemon_running: + return DaemonStatusModel(running=False, num_workers=None) + + response = client.get_numprocesses() + + return DaemonStatusModel(running=True, num_workers=response["numprocesses"]) + + +@router.get("/api/daemon/scheduler/worker") +@with_dbenv() +async def get_daemon_worker(): + """Return the daemon status.""" + client = get_scheduler_client() + + if not client.is_daemon_running: + return {} + + response = client.get_worker_info() + + return response["info"] + + +@router.post("/api/daemon/scheduler/start", response_model=DaemonStatusModel) +@with_dbenv() +async def get_daemon_start() -> DaemonStatusModel: + """Start the daemon.""" + client = get_scheduler_client() + + if client.is_daemon_running: + raise HTTPException(status_code=400, detail="The daemon is already running.") + + try: + client.start_daemon() + except DaemonException as exception: + raise HTTPException(status_code=500, detail=str(exception)) from exception + + response = client.get_numprocesses() + start_scheduler_process() + + return DaemonStatusModel(running=True, 
num_workers=response["numprocesses"]) + + +@router.post("/api/daemon/scheduler/stop", response_model=DaemonStatusModel) +@with_dbenv() +async def get_daemon_stop() -> DaemonStatusModel: + """Stop the daemon.""" + client = get_scheduler_client() + + if not client.is_daemon_running: + raise HTTPException(status_code=400, detail="The daemon is not running.") + + try: + client.stop_daemon() + except DaemonException as exception: + raise HTTPException(status_code=500, detail=str(exception)) from exception + + return DaemonStatusModel(running=False, num_workers=None) + + +@router.post("/api/daemon/scheduler/increase", response_model=DaemonStatusModel) +@with_dbenv() +async def increase_daemon_worker() -> DaemonStatusModel: + """increase the daemon worker.""" + client = get_scheduler_client() + + if not client.is_daemon_running: + raise HTTPException(status_code=400, detail="The daemon is not running.") + + try: + client.increase_workers(1) + except DaemonException as exception: + raise HTTPException(status_code=500, detail=str(exception)) from exception + + response = client.get_numprocesses() + print(response) + start_scheduler_process(response["numprocesses"]) + + return DaemonStatusModel(running=False, num_workers=None) + + +@router.post("/api/daemon/scheduler/decrease", response_model=DaemonStatusModel) +@with_dbenv() +async def decrease_daemon_worker() -> DaemonStatusModel: + """decrease the daemon worker.""" + client = get_scheduler_client() + + if not client.is_daemon_running: + raise HTTPException(status_code=400, detail="The daemon is not running.") + + try: + client.decrease_workers(1) + except DaemonException as exception: + raise HTTPException(status_code=500, detail=str(exception)) from exception + + return DaemonStatusModel(running=False, num_workers=None) diff --git a/aiida_workgraph/web/frontend/src/components/Settings.js b/aiida_workgraph/web/frontend/src/components/Settings.js index 23310618..50a89fee 100644 --- 
a/aiida_workgraph/web/frontend/src/components/Settings.js +++ b/aiida_workgraph/web/frontend/src/components/Settings.js @@ -3,79 +3,131 @@ import { ToastContainer, toast } from 'react-toastify'; import 'react-toastify/dist/ReactToastify.css'; function Settings() { - const [workers, setWorkers] = useState([]); + const [taskWorkers, setTaskWorkers] = useState([]); + const [schedulerWorkers, setSchedulerWorkers] = useState([]); - const fetchWorkers = () => { - fetch('http://localhost:8000/api/daemon/worker') + // Fetching task workers + const fetchTaskWorkers = () => { + fetch('http://localhost:8000/api/daemon/task/worker') .then(response => response.json()) - .then(data => setWorkers(Object.values(data))) - .catch(error => console.error('Failed to fetch workers:', error)); + .then(data => setTaskWorkers(Object.values(data))) + .catch(error => console.error('Failed to fetch task workers:', error)); + }; + + // Fetching scheduler workers + const fetchSchedulerWorkers = () => { + fetch('http://localhost:8000/api/daemon/scheduler/worker') + .then(response => response.json()) + .then(data => setSchedulerWorkers(Object.values(data))) + .catch(error => console.error('Failed to fetch scheduler workers:', error)); }; useEffect(() => { - fetchWorkers(); - const interval = setInterval(fetchWorkers, 1000); // Poll every 5 seconds - return () => clearInterval(interval); // Clear interval on component unmount + fetchTaskWorkers(); + fetchSchedulerWorkers(); + const taskInterval = setInterval(fetchTaskWorkers, 1000); + const schedulerInterval = setInterval(fetchSchedulerWorkers, 1000); + return () => { + clearInterval(taskInterval); + clearInterval(schedulerInterval); + }; // Clear intervals on component unmount }, []); - const handleDaemonControl = (action) => { - fetch(`http://localhost:8000/api/daemon/${action}`, { method: 'POST' }) + const handleDaemonControl = (daemonType, action) => { + fetch(`http://localhost:8000/api/daemon/${daemonType}/${action}`, { method: 'POST' }) 
.then(response => { if (!response.ok) { - throw new Error(`Daemon operation failed: ${response.statusText}`); + throw new Error(`${daemonType} daemon operation failed: ${response.statusText}`); } return response.json(); }) - .then(data => { - toast.success(`Daemon ${action}ed successfully`); - fetchWorkers(); + .then(() => { + toast.success(`${daemonType} daemon ${action}ed successfully`); + if (daemonType === 'task') { + fetchTaskWorkers(); + } else { + fetchSchedulerWorkers(); + } }) .catch(error => toast.error(error.message)); }; - const adjustWorkers = (action) => { - fetch(`http://localhost:8000/api/daemon/${action}`, { method: 'POST' }) + const adjustWorkers = (daemonType, action) => { + fetch(`http://localhost:8000/api/daemon/${daemonType}/${action}`, { method: 'POST' }) .then(response => { if (!response.ok) { - throw new Error(`Failed to ${action} workers: ${response.statusText}`); + throw new Error(`Failed to ${action} workers for ${daemonType}: ${response.statusText}`); } return response.json(); }) - .then(data => { - toast.success(`Workers ${action}ed successfully`); - fetchWorkers(); // Refetch workers after adjusting + .then(() => { + toast.success(`${daemonType} Workers ${action}ed successfully`); + if (daemonType === 'task') { + fetchTaskWorkers(); + } else { + fetchSchedulerWorkers(); + } }) .catch(error => toast.error(error.message)); }; return (
-

Daemon Control

- - - - - - - - - - - {workers.map(worker => ( - - - - - +
+

Task Daemon Control

+ + + + +
PIDMemory %CPU %Started
{worker.pid}{worker.mem}{worker.cpu}{new Date(worker.started * 1000).toLocaleString()}
+ + + + + + + + + + {taskWorkers.map(worker => ( + + + + + + + ))} + +
PIDMemory %CPU %Started
{worker.pid}{worker.mem}{worker.cpu}{new Date(worker.started * 1000).toLocaleString()}
+
+
+

Scheduler Daemon Control

+ + + + + + + + + + + - ))} - -
PIDMemory %CPU %Started
- - - - + + + {schedulerWorkers.map(worker => ( + + {worker.pid} + {worker.mem} + {worker.cpu} + {new Date(worker.started * 1000).toLocaleString()} + + ))} + + +
); } diff --git a/aiida_workgraph/workgraph.py b/aiida_workgraph/workgraph.py index 339435a0..e99a34f6 100644 --- a/aiida_workgraph/workgraph.py +++ b/aiida_workgraph/workgraph.py @@ -115,6 +115,7 @@ def submit( wait: bool = False, timeout: int = 60, metadata: Optional[Dict[str, Any]] = None, + to_scheduler: bool = False, ) -> aiida.orm.ProcessNode: """Submit the AiiDA workgraph process and optionally wait for it to finish. Args: @@ -134,27 +135,34 @@ def submit( self.save(metadata=metadata) if self.process.process_state.value.upper() not in ["CREATED"]: raise ValueError(f"Process {self.process.pk} has already been submitted.") - self.continue_process() + if to_scheduler: + self.continue_process_in_scheduler(to_scheduler) + else: + self.continue_process() # as long as we submit the process, it is a new submission, we should set restart_process to None self.restart_process = None if wait: self.wait(timeout=timeout) return self.process - def save(self, metadata: Optional[Dict[str, Any]] = None) -> None: + def save( + self, metadata: Optional[Dict[str, Any]] = None, parent_pid: int = None + ) -> None: """Save the udpated workgraph to the process This is only used for a running workgraph. Save the AiiDA workgraph process and update the process status. 
""" from aiida.manage import manager - from aiida.engine.utils import instantiate_process + from aiida_workgraph.engine.utils import instantiate_process from aiida_workgraph.engine.workgraph import WorkGraphEngine inputs = self.prepare_inputs(metadata) if self.process is None: runner = manager.get_manager().get_runner() # init a process node - process_inited = instantiate_process(runner, WorkGraphEngine, **inputs) + process_inited = instantiate_process( + runner, WorkGraphEngine, parent_pid=parent_pid, **inputs + ) process_inited.runner.persister.save_checkpoint(process_inited) self.process = process_inited.node self.process_inited = process_inited @@ -415,6 +423,24 @@ def continue_process(self): process_controller = get_manager().get_process_controller() process_controller.continue_process(self.pk) + def continue_process_in_scheduler(self, to_scheduler: Union[int, bool]) -> None: + """Ask the scheduler to pick up the process from the database and run it. + If to_scheduler is an integer, it will be used as the scheduler pk. + Otherwise, it will send the message to the queue, and the scheduler will pick it up. + """ + from aiida_workgraph.utils.control import ( + create_task_action, + create_workgraph_action, + ) + + try: + if isinstance(to_scheduler, int) and not isinstance(to_scheduler, bool): + create_task_action(to_scheduler, [self.pk], action="launch_workgraph") + else: + create_workgraph_action(self.pk) + except Exception as e: + print("""An unexpected error occurred:""", e) + def play(self): import os diff --git a/docs/source/howto/index.rst b/docs/source/howto/index.rst index 274f74be..3991edfe 100644 --- a/docs/source/howto/index.rst +++ b/docs/source/howto/index.rst @@ -24,5 +24,6 @@ This section contains a collection of HowTos for various topics. 
queue cli control + scheduler transfer_workchain workchain_call_workgraph diff --git a/docs/source/howto/scheduler.ipynb b/docs/source/howto/scheduler.ipynb new file mode 100644 index 00000000..e88f2f50 --- /dev/null +++ b/docs/source/howto/scheduler.ipynb @@ -0,0 +1,358 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Scheduler\n", + "\n", + "## Overview\n", + "\n", + "This documentation provides a guide on using the `aiida-workgraph` Scheduler to manage `WorkGraph` processes efficiently.\n", + "\n", + "### Background\n", + "\n", + "Traditional workflow processes, particularly in nested structures like `PwBandsWorkChain`, tend to create multiple Workflow processes in a waiting state, while only a few `CalcJob` processes run actively. This results in inefficient resource usage. The `WorkChain` structure makes it challenging to eliminate these waiting processes due to its encapsulated logic.\n", + "\n", + "In contrast, the `WorkGraph` offers clearer task dependencies and allows other processes to run its tasks in a controllable way. 
With a scheduler, one only needs to create the `WorkGraph` process in the database rather than run it via a daemon worker.\n", + "\n", + "### Process Comparison: `PwBands` Case\n", + "\n", + "- **Old Approach**: 300 Workflow processes (Bands, Relax, Base) + 100 CalcJob processes.\n", + "- **New Approach**: 1 Scheduler process + 100 CalcJob processes.\n", + "\n", + "This new approach significantly reduces the number of active processes and mitigates the risk of deadlocks.\n", + "\n", + "## Getting Started with the Scheduler\n", + "\n", + "### Starting the Scheduler\n", + "\n", + "To launch a scheduler daemon:\n", + "\n", + "```console\n", + "workgraph scheduler start\n", + "```\n", + "\n", + "### Monitoring the Scheduler\n", + "\n", + "To check the current status of the scheduler:\n", + "\n", + "```console\n", + "workgraph scheduler status\n", + "```\n", + "\n", + "### Stopping the Scheduler\n", + "\n", + "To stop the scheduler daemon:\n", + "\n", + "```console\n", + "workgraph scheduler stop\n", + "```\n", + "\n", + "## Submitting WorkGraphs to the Scheduler\n", + "\n", + "To submit a WorkGraph to the scheduler, set the `to_scheduler` flag to `True`:\n", + "\n", + "```python\n", + "wg.submit(to_scheduler=True)\n", + "```\n", + "\n", + "\n", + "### Using Multiple Schedulers\n", + "\n", + "For environments with a high volume of WorkGraphs, starting multiple schedulers can enhance throughput:\n", + "\n", + "```console\n", + "workgraph scheduler start 2\n", + "```\n", + "\n", + "WorkGraphs will be automatically distributed among available schedulers.\n", + "\n", + "#### Specifying a Scheduler\n", + "\n", + "To submit a WorkGraph to a specific scheduler using its primary key (`pk`):\n", + "\n", + "```python\n", + "wg.submit(to_scheduler=pk_scheduler)\n", + "```\n", + "\n", + "### Best Practices for Scheduler Usage\n", + "\n", + "While a single scheduler suffices for most use cases, scaling up the number of schedulers may be beneficial when significantly increasing the number of task 
workers (created by `verdi daemon start`). A general rule is to maintain a ratio of less than 5 workers per scheduler.\n", + "\n", + "## Example" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "WorkGraph process created, PK: 142617\n", + "State of WorkGraph : FINISHED\n", + "Result of add2 : 4\n" + ] + } + ], + "source": [ + "from aiida_workgraph import WorkGraph, task\n", + "from aiida.calculations.arithmetic.add import ArithmeticAddCalculation\n", + "from aiida import load_profile, orm\n", + "\n", + "load_profile()\n", + "\n", + "@task.calcfunction()\n", + "def add(x, y):\n", + " return x + y\n", + "\n", + "code = orm.load_code(\"add@localhost\")\n", + "\n", + "wg = WorkGraph(f\"test_scheduler\")\n", + "add1 = wg.add_task(ArithmeticAddCalculation, name=\"add1\", x=1, y=2, code=code)\n", + "add2 = wg.add_task(ArithmeticAddCalculation, name=\"add2\", x=1, y=add1.outputs[\"sum\"], code=code)\n", + "wg.submit(to_scheduler=True,\n", + " wait=True)\n", + "print(\"State of WorkGraph : {}\".format(wg.state))\n", + "print('Result of add2 : {}'.format(wg.tasks[\"add2\"].node.outputs.sum.value))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Generate node graph from the AiiDA process,and we can see the provenance graph of the workgraph:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "N142617\n", + "\n", + "WorkGraph<test_scheduler> (142617)\n", + "State: finished\n", + "Exit Code: 0\n", + "\n", + "\n", + "\n", + "N142620\n", + "\n", + "ArithmeticAddCalculation (142620)\n", + "State: finished\n", + "Exit Code: 0\n", + "\n", + "\n", + "\n", + "N142617->N142620\n", + "\n", + "\n", + "CALL_CALC\n", + "add1\n", + "\n", + "\n", + "\n", + "N142625\n", + "\n", + "ArithmeticAddCalculation 
(142625)\n", + "State: finished\n", + "Exit Code: 0\n", + "\n", + "\n", + "\n", + "N142617->N142625\n", + "\n", + "\n", + "CALL_CALC\n", + "add2\n", + "\n", + "\n", + "\n", + "N37\n", + "\n", + "InstalledCode (37)\n", + "add@localhost\n", + "\n", + "\n", + "\n", + "N37->N142617\n", + "\n", + "\n", + "INPUT_WORK\n", + "wg__tasks__add1__properties__code__value\n", + "\n", + "\n", + "\n", + "N37->N142617\n", + "\n", + "\n", + "INPUT_WORK\n", + "wg__tasks__add2__properties__code__value\n", + "\n", + "\n", + "\n", + "N142621\n", + "\n", + "RemoteData (142621)\n", + "@localhost\n", + "\n", + "\n", + "\n", + "N142620->N142621\n", + "\n", + "\n", + "CREATE\n", + "remote_folder\n", + "\n", + "\n", + "\n", + "N142622\n", + "\n", + "FolderData (142622)\n", + "\n", + "\n", + "\n", + "N142620->N142622\n", + "\n", + "\n", + "CREATE\n", + "retrieved\n", + "\n", + "\n", + "\n", + "N142623\n", + "\n", + "Int (142623)\n", + "\n", + "\n", + "\n", + "N142620->N142623\n", + "\n", + "\n", + "CREATE\n", + "sum\n", + "\n", + "\n", + "\n", + "N142623->N142625\n", + "\n", + "\n", + "INPUT_CALC\n", + "y\n", + "\n", + "\n", + "\n", + "N142626\n", + "\n", + "RemoteData (142626)\n", + "@localhost\n", + "\n", + "\n", + "\n", + "N142625->N142626\n", + "\n", + "\n", + "CREATE\n", + "remote_folder\n", + "\n", + "\n", + "\n", + "N142627\n", + "\n", + "FolderData (142627)\n", + "\n", + "\n", + "\n", + "N142625->N142627\n", + "\n", + "\n", + "CREATE\n", + "retrieved\n", + "\n", + "\n", + "\n", + "N142628\n", + "\n", + "Int (142628)\n", + "\n", + "\n", + "\n", + "N142625->N142628\n", + "\n", + "\n", + "CREATE\n", + "sum\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from aiida_workgraph.utils import generate_node_graph\n", + "generate_node_graph(wg.pk)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "## Checkpointing\n", + "\n", + "The Scheduler 
checkpoints its status to the database whenever a WorkGraph is updated, ensuring that the Scheduler can recover its state in case of a crash or restart. This feature is particularly useful for long-running WorkGraphs.\n", + "\n", + "## Conclusion\n", + "\n", + "The Scheduler offers a streamlined approach to managing complex workflows, significantly reducing active process counts and improving resource efficiency." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "aiida", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.0" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pyproject.toml b/pyproject.toml index b67eee61..4f76b8aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,6 +83,7 @@ workgraph = "aiida_workgraph.cli.cmd_workgraph:workgraph" [project.entry-points.'aiida.workflows'] "workgraph.engine" = "aiida_workgraph.engine.workgraph:WorkGraphEngine" +"workgraph.scheduler" = "aiida_workgraph.engine.scheduler.scheduler:WorkGraphScheduler" [project.entry-points."aiida.data"] "workgraph.general" = "aiida_workgraph.orm.general_data:GeneralData" diff --git a/tests/conftest.py b/tests/conftest.py index 36b1fe61..efb2d1f0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -25,6 +25,46 @@ def fixture_localhost(aiida_localhost): return localhost +@pytest.fixture(scope="session") +def scheduler_client(aiida_profile): + """Return a daemon client for the configured test profile for the test session. + + The daemon will be automatically stopped at the end of the test session. 
+    """
+    from aiida_workgraph.engine.scheduler.client import SchedulerClient
+    from aiida.engine.daemon.client import (
+        DaemonNotRunningException,
+        DaemonTimeoutException,
+    )
+
+    scheduler_client = SchedulerClient(aiida_profile)
+
+    try:
+        yield scheduler_client
+    finally:
+        try:
+            scheduler_client.stop_daemon(wait=True)
+        except DaemonNotRunningException:
+            pass
+        # Give an additional grace period by manually waiting for the daemon to be stopped. In certain unit test
+        # scenarios, the built in wait time in ``scheduler_client.stop_daemon`` is not sufficient and even though the
+        # daemon is stopped, ``scheduler_client.is_daemon_running`` will return true for a little bit longer.
+        scheduler_client._await_condition(
+            lambda: not scheduler_client.is_daemon_running,
+            DaemonTimeoutException("The daemon failed to stop."),
+        )
+
+
+@pytest.fixture()
+def started_scheduler_client(scheduler_client):
+    """Ensure that the daemon is running for the test profile and return the associated client."""
+    if not scheduler_client.is_daemon_running:
+        scheduler_client.start_daemon()
+        assert scheduler_client.is_daemon_running
+
+    yield scheduler_client
+
+
 @pytest.fixture
 def add_code(fixture_localhost):
     from aiida.orm import InstalledCode
diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py
new file mode 100644
index 00000000..c6308551
--- /dev/null
+++ b/tests/test_scheduler.py
@@ -0,0 +1,21 @@
+import pytest
+from typing import Callable
+from aiida_workgraph import WorkGraph
+from aiida.cmdline.utils.common import get_workchain_report
+from aiida_workgraph.engine.scheduler.client import get_scheduler
+from aiida import orm
+
+
+@pytest.mark.skip("Skip for now")
+@pytest.mark.usefixtures("started_daemon_client")
+def test_scheduler(decorated_add: Callable, started_scheduler_client) -> None:
+    """Test that a WorkGraph submitted to the scheduler yields the expected result."""
+    wg = WorkGraph("test_scheduler")
+    add1 = wg.add_task(decorated_add, x=2, y=3)
+    add2 = wg.add_task(decorated_add, "add2", x=3, y=add1.outputs["result"])
+    # hand the WorkGraph to the scheduler and block until it has finished
+    wg.submit(to_scheduler=True, wait=True)
+    pk = get_scheduler()
+    report = get_workchain_report(orm.load_node(pk), "REPORT")
+    print("report: ", report)
+    assert add2.outputs["result"].value == 8