diff --git a/aiida_workgraph/cli/__init__.py b/aiida_workgraph/cli/__init__.py
index db1f1f89..437e58d7 100644
--- a/aiida_workgraph/cli/__init__.py
+++ b/aiida_workgraph/cli/__init__.py
@@ -5,6 +5,7 @@
from aiida_workgraph.cli import cmd_graph
from aiida_workgraph.cli import cmd_web
from aiida_workgraph.cli import cmd_task
+from aiida_workgraph.cli import cmd_scheduler
-__all__ = ["cmd_graph", "cmd_web", "cmd_task"]
+__all__ = ["cmd_graph", "cmd_web", "cmd_task", "cmd_scheduler"]
diff --git a/aiida_workgraph/cli/cmd_scheduler.py b/aiida_workgraph/cli/cmd_scheduler.py
new file mode 100644
index 00000000..52cb54d3
--- /dev/null
+++ b/aiida_workgraph/cli/cmd_scheduler.py
@@ -0,0 +1,163 @@
+from aiida_workgraph.cli.cmd_workgraph import workgraph
+import click
+from aiida.cmdline.utils import decorators, echo
+from aiida.cmdline.commands.cmd_daemon import validate_daemon_workers
+from aiida.cmdline.params import options
+from aiida_workgraph.engine.scheduler.client import get_scheduler_client
+import sys
+
+
+@workgraph.group("scheduler")
+def scheduler():
+ """Commands to manage the scheduler process."""
+
+
+@scheduler.command()
+def worker():
+ """Start the scheduler application."""
+ from aiida_workgraph.engine.scheduler.client import start_scheduler_worker
+
+ click.echo("Starting the scheduler worker...")
+
+ start_scheduler_worker()
+
+
+@scheduler.command()
+@click.option("--foreground", is_flag=True, help="Run in foreground.")
+@click.argument("number", required=False, type=int, callback=validate_daemon_workers)
+@options.TIMEOUT(default=None, required=False, type=int)
+@decorators.with_dbenv()
+@decorators.requires_broker
+@decorators.check_circus_zmq_version
+def start(foreground, number, timeout):
+ """Start the scheduler application."""
+ from aiida_workgraph.engine.scheduler.client import start_scheduler_process
+
+ click.echo("Starting the scheduler process...")
+
+ client = get_scheduler_client()
+ client.start_daemon(number_workers=number, foreground=foreground, timeout=timeout)
+ start_scheduler_process(number)
+
+
+@scheduler.command()
+@click.option("--no-wait", is_flag=True, help="Do not wait for confirmation.")
+@click.option("--all", "all_profiles", is_flag=True, help="Stop all daemons.")
+@options.TIMEOUT(default=None, required=False, type=int)
+@decorators.requires_broker
+@click.pass_context
+def stop(ctx, no_wait, all_profiles, timeout):
+ """Stop the scheduler daemon.
+
+ Returns exit code 0 if the daemon was shut down successfully (or was not running), non-zero if there was an error.
+ """
+ if all_profiles is True:
+ profiles = [
+ profile
+ for profile in ctx.obj.config.profiles
+ if not profile.is_test_profile
+ ]
+ else:
+ profiles = [ctx.obj.profile]
+
+ for profile in profiles:
+ echo.echo("Profile: ", fg=echo.COLORS["report"], bold=True, nl=False)
+ echo.echo(f"{profile.name}", bold=True)
+ echo.echo("Stopping the daemon... ", nl=False)
+ try:
+ client = get_scheduler_client()
+ client.stop_daemon(wait=not no_wait, timeout=timeout)
+ except Exception as exception:
+ echo.echo_error(f"Failed to stop the daemon: {exception}")
+
+
+@scheduler.command(hidden=True)
+@click.option("--foreground", is_flag=True, help="Run in foreground.")
+@click.argument("number", required=False, type=int, callback=validate_daemon_workers)
+@decorators.with_dbenv()
+@decorators.requires_broker
+@decorators.check_circus_zmq_version
+def start_circus(foreground, number):
+ """This will actually launch the circus daemon, either daemonized in the background or in the foreground.
+
+ If run in the foreground all logs are redirected to stdout.
+
+ .. note:: this should not be called directly from the commandline!
+ """
+
+ get_scheduler_client()._start_daemon(number_workers=number, foreground=foreground)
+
+
+@scheduler.command()
+@click.option("--all", "all_profiles", is_flag=True, help="Show status of all daemons.")
+@options.TIMEOUT(default=None, required=False, type=int)
+@click.pass_context
+@decorators.requires_loaded_profile()
+@decorators.requires_broker
+def status(ctx, all_profiles, timeout):
+ """Print the status of the scheduler daemon.
+
+ Returns exit code 0 if all requested daemons are running, else exit code 3.
+ """
+ from tabulate import tabulate
+
+ from aiida.cmdline.utils.common import format_local_time
+ from aiida.engine.daemon.client import DaemonException
+
+ if all_profiles is True:
+ profiles = [
+ profile
+ for profile in ctx.obj.config.profiles
+ if not profile.is_test_profile
+ ]
+ else:
+ profiles = [ctx.obj.profile]
+
+ daemons_running = []
+
+ for profile in profiles:
+ client = get_scheduler_client(profile.name)
+ echo.echo("Profile: ", fg=echo.COLORS["report"], bold=True, nl=False)
+ echo.echo(f"{profile.name}", bold=True)
+
+ try:
+ client.get_status(timeout=timeout)
+ except DaemonException as exception:
+ echo.echo_error(str(exception))
+ daemons_running.append(False)
+ continue
+
+ worker_response = client.get_worker_info()
+ daemon_response = client.get_daemon_info()
+
+ workers = []
+ for pid, info in worker_response["info"].items():
+ if isinstance(info, dict):
+ row = [
+ pid,
+ info["mem"],
+ info["cpu"],
+ format_local_time(info["create_time"]),
+ ]
+ else:
+ row = [pid, "-", "-", "-"]
+ workers.append(row)
+
+ if workers:
+ workers_info = tabulate(
+ workers, headers=["PID", "MEM %", "CPU %", "started"], tablefmt="simple"
+ )
+ else:
+ workers_info = (
+ "--> No workers are running. Use `verdi daemon incr` to start some!\n"
+ )
+
+ start_time = format_local_time(daemon_response["info"]["create_time"])
+ echo.echo(
+ f'Daemon is running as PID {daemon_response["info"]["pid"]} since {start_time}\n'
+ f"Active workers [{len(workers)}]:\n{workers_info}\n"
+ "Use `verdi daemon [incr | decr] [num]` to increase / decrease the number of workers"
+ )
+
+ if not all(daemons_running):
+ sys.exit(3)
diff --git a/aiida_workgraph/engine/launch.py b/aiida_workgraph/engine/launch.py
new file mode 100644
index 00000000..25b8c1e8
--- /dev/null
+++ b/aiida_workgraph/engine/launch.py
@@ -0,0 +1,179 @@
+from __future__ import annotations
+
+import time
+import typing as t
+
+from aiida.common import InvalidOperation
+from aiida.common.log import AIIDA_LOGGER
+from aiida.manage import manager
+from aiida.orm import ProcessNode
+
+from aiida.engine.processes.builder import ProcessBuilder
+from aiida.engine.processes.functions import get_stack_size
+from aiida.engine.processes.process import Process
+from aiida.engine.utils import prepare_inputs
+from .utils import instantiate_process
+
+import signal
+import sys
+
+from aiida.manage import get_manager
+
+__all__ = ("run_get_node", "submit")
+
+TYPE_RUN_PROCESS = t.Union[Process, t.Type[Process], ProcessBuilder]
+# run can also be process function, but it is not clear what type this should be
+TYPE_SUBMIT_PROCESS = t.Union[Process, t.Type[Process], ProcessBuilder]
+LOGGER = AIIDA_LOGGER.getChild("engine.launch")
+
+
+"""
+Note: I modified the run_get_node and submit functions to include the parent_pid argument.
+This is necessary for keeping track of the provenance of the processes.
+
+"""
+
+
+def run_get_node(
+ process_class, *args, **kwargs
+) -> tuple[dict[str, t.Any] | None, "ProcessNode"]:
+ """Run the FunctionProcess with the supplied inputs in a local runner.
+ :param args: input arguments to construct the FunctionProcess
+ :param kwargs: input keyword arguments to construct the FunctionProcess
+ :return: tuple of the outputs of the process and the process node
+ """
+ parent_pid = kwargs.pop("parent_pid", None)
+ frame_delta = 1000
+ frame_count = get_stack_size()
+ stack_limit = sys.getrecursionlimit()
+ LOGGER.info(
+ "Executing process function, current stack status: %d frames of %d",
+ frame_count,
+ stack_limit,
+ )
+ # If the current frame count is more than 80% of the stack limit, or comes within 200 frames, increase the
+ # stack limit by ``frame_delta``.
+ if frame_count > min(0.8 * stack_limit, stack_limit - 200):
+ LOGGER.warning(
+ "Current stack contains %d frames which is close to the limit of %d. Increasing the limit by %d",
+ frame_count,
+ stack_limit,
+ frame_delta,
+ )
+ sys.setrecursionlimit(stack_limit + frame_delta)
+ manager = get_manager()
+ runner = manager.get_runner()
+ inputs = process_class.create_inputs(*args, **kwargs)
+ # Remove all the known inputs from the kwargs
+ for port in process_class.spec().inputs:
+ kwargs.pop(port, None)
+ # If any kwargs remain, the spec should be dynamic, so we raise if it isn't
+ if kwargs and not process_class.spec().inputs.dynamic:
+ raise ValueError(
+ f"{function.__name__} does not support these kwargs: {kwargs.keys()}"
+ )
+ process = process_class(inputs=inputs, runner=runner, parent_pid=parent_pid)
+ # Only add handlers for interrupt signal to kill the process if we are in a local and not a daemon runner.
+ # Without this check, running process functions in a daemon worker would be killed if the daemon is shutdown
+ current_runner = manager.get_runner()
+ original_handler = None
+ kill_signal = signal.SIGINT
+ if not current_runner.is_daemon_runner:
+
+ def kill_process(_num, _frame):
+ """Send the kill signal to the process in the current scope."""
+ LOGGER.critical(
+ "runner received interrupt, killing process %s", process.pid
+ )
+ result = process.kill(
+ msg="Process was killed because the runner received an interrupt"
+ )
+ return result
+
+ # Store the current handler on the signal such that it can be restored after process has terminated
+ original_handler = signal.getsignal(kill_signal)
+ signal.signal(kill_signal, kill_process)
+ try:
+ result = process.execute()
+ finally:
+ # If the `original_handler` is set, that means the `kill_process` was bound, which needs to be reset
+ if original_handler:
+ signal.signal(signal.SIGINT, original_handler)
+ store_provenance = inputs.get("metadata", {}).get("store_provenance", True)
+ if not store_provenance:
+ process.node._storable = False
+ process.node._unstorable_message = (
+ "cannot store node because it was run with `store_provenance=False`"
+ )
+ return result, process.node
+
+
+def submit(
+ process: TYPE_SUBMIT_PROCESS,
+ inputs: dict[str, t.Any] | None = None,
+ *,
+ wait: bool = False,
+ wait_interval: int = 5,
+ parent_pid: int | None = None,
+ runner: "Runner" | None = None,
+ **kwargs: t.Any,
+) -> ProcessNode:
+ """Submit the process with the supplied inputs to the daemon immediately returning control to the interpreter.
+
+ .. warning: this should not be used within another process. Instead, there one should use the ``submit`` method of
+ the wrapping process itself, i.e. use ``self.submit``.
+
+ .. warning: submission of processes requires ``store_provenance=True``.
+
+ :param process: the process class, instance or builder to submit
+ :param inputs: the input dictionary to be passed to the process
+ :param wait: when set to ``True``, the submission will be blocking and wait for the process to complete at which
+ point the function returns the calculation node.
+ :param wait_interval: the number of seconds to wait between checking the state of the process when ``wait=True``.
+ :param kwargs: inputs to be passed to the process. This is an alternative to the positional ``inputs`` argument.
+ :return: the calculation node of the process
+ """
+ inputs = prepare_inputs(inputs, **kwargs)
+
+ # Submitting from within another process requires ``self.submit``` unless it is a work function, in which case the
+ # current process in the scope should be an instance of ``FunctionProcess``.
+ # if is_process_scoped() and not isinstance(Process.current(), FunctionProcess):
+ # raise InvalidOperation('Cannot use top-level `submit` from within another process, use `self.submit` instead')
+
+ if not runner:
+ runner = manager.get_manager().get_runner()
+ assert runner.persister is not None, "runner does not have a persister"
+ assert runner.controller is not None, "runner does not have a controller"
+
+ process_inited = instantiate_process(
+ runner, process, parent_pid=parent_pid, **inputs
+ )
+
+ # If a dry run is requested, simply forward to `run`, because it is not compatible with `submit`. We choose for this
+ # instead of raising, because in this way the user does not have to change the launcher when testing. The same goes
+ # for if `remote_folder` is present in the inputs, which means we are importing an already completed calculation.
+ if process_inited.metadata.get("dry_run", False) or "remote_folder" in inputs:
+ _, node = run_get_node(process_inited)
+ return node
+
+ if not process_inited.metadata.store_provenance:
+ raise InvalidOperation("cannot submit a process with `store_provenance=False`")
+
+ runner.persister.save_checkpoint(process_inited)
+ process_inited.close()
+
+ # Do not wait for the future's result, because in the case of a single worker this would cock-block itself
+ runner.controller.continue_process(process_inited.pid, nowait=False, no_reply=True)
+ node = process_inited.node
+
+ if not wait:
+ return node
+
+ while not node.is_terminated:
+ LOGGER.report(
+ f"Process<{node.pk}> has not yet terminated, current state is `{node.process_state}`. "
+ f"Waiting for {wait_interval} seconds."
+ )
+ time.sleep(wait_interval)
+
+ return node
diff --git a/aiida_workgraph/engine/override.py b/aiida_workgraph/engine/override.py
new file mode 100644
index 00000000..020be102
--- /dev/null
+++ b/aiida_workgraph/engine/override.py
@@ -0,0 +1,71 @@
+from plumpy.process_comms import RemoteProcessThreadController
+from typing import Any, Optional
+
+"""
+Note: I modified the the create_daemon_runner function and RemoteProcessThreadController
+to include the queue_name argument.
+
+"""
+
+
+def create_daemon_runner(
+ manager, queue_name: str = None, loop: Optional["asyncio.AbstractEventLoop"] = None
+) -> "Runner":
+ """Create and return a new daemon runner.
+ This is used by workers when the daemon is running and in testing.
+ :param loop: the (optional) asyncio event loop to use
+ :return: a runner configured to work in the daemon configuration
+ """
+ from plumpy.persistence import LoadSaveContext
+ from aiida.engine import persistence
+ from aiida.engine.processes.launcher import ProcessLauncher
+ from plumpy.communications import convert_to_comm
+
+ runner = manager.create_runner(broker_submit=True, loop=loop)
+ runner_loop = runner.loop
+ # Listen for incoming launch requests
+ task_receiver = ProcessLauncher(
+ loop=runner_loop,
+ persister=manager.get_persister(),
+ load_context=LoadSaveContext(runner=runner),
+ loader=persistence.get_object_loader(),
+ )
+
+ def callback(_comm, msg):
+ print("Received message: {}".format(msg))
+ import asyncio
+
+ asyncio.run(task_receiver(_comm, msg))
+ print("task_receiver._continue done")
+ return True
+
+ assert runner.communicator is not None, "communicator not set for runner"
+ if queue_name is not None:
+ print("queue_name: {}".format(queue_name))
+ queue = runner.communicator._communicator.task_queue(
+ queue_name, prefetch_count=1
+ )
+ # queue.add_task_subscriber(callback)
+ # important to convert the callback
+ converted = convert_to_comm(task_receiver, runner.communicator._loop)
+ queue.add_task_subscriber(converted)
+ else:
+ runner.communicator.add_task_subscriber(task_receiver)
+ return runner
+
+
+class ControllerWithQueueName(RemoteProcessThreadController):
+ def __init__(self, queue_name: str, **kwargs):
+ super().__init__(**kwargs)
+ self.queue_name = queue_name
+
+ def task_send(self, message: Any, no_reply: bool = False) -> Optional[Any]:
+ """
+ Send a task to be performed using the communicator
+
+ :param message: the task message
+ :param no_reply: if True, this call will be fire-and-forget, i.e. no return value
+ :return: the response from the remote side (if no_reply=False)
+ """
+ queue = self._communicator.task_queue(self.queue_name)
+ return queue.task_send(message, no_reply=no_reply)
diff --git a/aiida_workgraph/engine/scheduler/__init__.py b/aiida_workgraph/engine/scheduler/__init__.py
new file mode 100644
index 00000000..95a95abf
--- /dev/null
+++ b/aiida_workgraph/engine/scheduler/__init__.py
@@ -0,0 +1,3 @@
+from .scheduler import WorkGraphScheduler
+
+__all__ = ("WorkGraphScheduler",)
diff --git a/aiida_workgraph/engine/scheduler/client.py b/aiida_workgraph/engine/scheduler/client.py
new file mode 100644
index 00000000..4f5a8212
--- /dev/null
+++ b/aiida_workgraph/engine/scheduler/client.py
@@ -0,0 +1,324 @@
+from aiida.engine.daemon.client import DaemonClient
+import shutil
+from aiida.manage.manager import get_manager
+from aiida.common.exceptions import ConfigurationError
+import os
+from typing import Optional
+from aiida.common.log import AIIDA_LOGGER
+from typing import List
+
+WORKGRAPH_BIN = shutil.which("workgraph")
+LOGGER = AIIDA_LOGGER.getChild("engine.launch")
+
+
+class SchedulerClient(DaemonClient):
+ """Client for interacting with the scheduler daemon."""
+
+ _DAEMON_NAME = "scheduler-{name}"
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ @property
+ def _workgraph_bin(self) -> str:
+ """Return the absolute path to the ``verdi`` binary.
+
+ :raises ConfigurationError: If the path to ``verdi`` could not be found
+ """
+ if WORKGRAPH_BIN is None:
+ raise ConfigurationError(
+ "Unable to find 'verdi' in the path. Make sure that you are working "
+ "in a virtual environment, or that at least the 'verdi' executable is on the PATH"
+ )
+
+ return WORKGRAPH_BIN
+
+ @property
+ def filepaths(self):
+ """Return the filepaths used by this profile.
+
+ :return: a dictionary of filepaths
+ """
+ from aiida.manage.configuration.settings import DAEMON_DIR, DAEMON_LOG_DIR
+
+ return {
+ "circus": {
+ "log": str(
+ DAEMON_LOG_DIR / f"circus-scheduler-{self.profile.name}.log"
+ ),
+ "pid": str(DAEMON_DIR / f"circus-scheduler-{self.profile.name}.pid"),
+ "port": str(DAEMON_DIR / f"circus-scheduler-{self.profile.name}.port"),
+ "socket": {
+ "file": str(
+ DAEMON_DIR / f"circus-scheduler-{self.profile.name}.sockets"
+ ),
+ "controller": "circus.c.sock",
+ "pubsub": "circus.p.sock",
+ "stats": "circus.s.sock",
+ },
+ },
+ "daemon": {
+ "log": str(DAEMON_LOG_DIR / f"aiida-scheduler-{self.profile.name}.log"),
+ "pid": str(DAEMON_DIR / f"aiida-scheduler-{self.profile.name}.pid"),
+ },
+ }
+
+ @property
+ def circus_log_file(self) -> str:
+ return self.filepaths["circus"]["log"]
+
+ @property
+ def circus_pid_file(self) -> str:
+ return self.filepaths["circus"]["pid"]
+
+ @property
+ def circus_port_file(self) -> str:
+ return self.filepaths["circus"]["port"]
+
+ @property
+ def circus_socket_file(self) -> str:
+ return self.filepaths["circus"]["socket"]["file"]
+
+ @property
+ def circus_socket_endpoints(self) -> dict[str, str]:
+ return self.filepaths["circus"]["socket"]
+
+ @property
+ def daemon_log_file(self) -> str:
+ return self.filepaths["daemon"]["log"]
+
+ @property
+ def daemon_pid_file(self) -> str:
+ return self.filepaths["daemon"]["pid"]
+
+ def cmd_start_daemon(
+ self, number_workers: int = 1, foreground: bool = False
+ ) -> list[str]:
+ """Return the command to start the daemon.
+
+ :param number_workers: Number of daemon workers to start.
+ :param foreground: Whether to launch the subprocess in the background or not.
+ """
+ command = [
+ self._workgraph_bin,
+ "-p",
+ self.profile.name,
+ "scheduler",
+ "start-circus",
+ str(number_workers),
+ ]
+
+ if foreground:
+ command.append("--foreground")
+
+ return command
+
+ @property
+ def cmd_start_daemon_worker(self) -> list[str]:
+ """Return the command to start a daemon worker process."""
+ return [self._workgraph_bin, "-p", self.profile.name, "scheduler", "worker"]
+
+ def _start_daemon(self, number_workers: int = 1, foreground: bool = False) -> None:
+ """Start the daemon.
+
+ .. warning:: This will daemonize the current process and put it in the background. It is most likely not what
+ you want to call if you want to start the daemon from the Python API. Instead you probably will want to use
+ the :meth:`aiida.engine.daemon.client.DaemonClient.start_daemon` function instead.
+
+ :param number_workers: Number of daemon workers to start.
+ :param foreground: Whether to launch the subprocess in the background or not.
+ """
+ from circus import get_arbiter
+ from circus import logger as circus_logger
+ from circus.circusd import daemonize
+ from circus.pidfile import Pidfile
+ from circus.util import check_future_exception_and_log, configure_logger
+
+ loglevel = self.loglevel
+ logoutput = "-"
+
+ if not foreground:
+ logoutput = self.circus_log_file
+
+ arbiter_config = {
+ "controller": self.get_controller_endpoint(),
+ "pubsub_endpoint": self.get_pubsub_endpoint(),
+ "stats_endpoint": self.get_stats_endpoint(),
+ "logoutput": logoutput,
+ "loglevel": loglevel,
+ "debug": False,
+ "statsd": True,
+ "pidfile": self.circus_pid_file,
+ "watchers": [
+ {
+ "cmd": " ".join(self.cmd_start_daemon_worker),
+ "name": self.daemon_name,
+ "numprocesses": number_workers,
+ "virtualenv": self.virtualenv,
+ "copy_env": True,
+ "stdout_stream": {
+ "class": "FileStream",
+ "filename": self.daemon_log_file,
+ },
+ "stderr_stream": {
+ "class": "FileStream",
+ "filename": self.daemon_log_file,
+ },
+ "env": self.get_env(),
+ }
+ ],
+ }
+
+ if not foreground:
+ daemonize()
+
+ arbiter = get_arbiter(**arbiter_config)
+ pidfile = Pidfile(arbiter.pidfile)
+ pidfile.create(os.getpid())
+
+ # Configure the logger
+ loggerconfig = arbiter.loggerconfig or None
+ configure_logger(circus_logger, loglevel, logoutput, loggerconfig)
+
+ # Main loop
+ should_restart = True
+
+ while should_restart:
+ try:
+ future = arbiter.start()
+ should_restart = False
+ if check_future_exception_and_log(future) is None:
+ should_restart = arbiter._restarting
+ except Exception as exception:
+ # Emergency stop
+ arbiter.loop.run_sync(arbiter._emergency_stop)
+ raise exception
+ except KeyboardInterrupt:
+ pass
+ finally:
+ arbiter = None
+ if pidfile is not None:
+ pidfile.unlink()
+
+
+def get_scheduler_client(profile_name: Optional[str] = None) -> "SchedulerClient":
+ """Return the daemon client for the given profile or the current profile if not specified.
+
+ :param profile_name: Optional profile name to load.
+ :return: The daemon client.
+
+ :raises aiida.common.MissingConfigurationError: if the configuration file cannot be found.
+ :raises aiida.common.ProfileConfigurationError: if the given profile does not exist.
+ """
+ profile = get_manager().load_profile(profile_name)
+ return SchedulerClient(profile)
+
+
+def get_scheduler() -> List[int]:
+ from aiida.orm import QueryBuilder
+ from aiida_workgraph.engine.scheduler import WorkGraphScheduler
+
+ qb = QueryBuilder()
+ projections = ["id"]
+ filters = {
+ "or": [
+ {"attributes.sealed": False},
+ {"attributes": {"!has_key": "sealed"}},
+ ]
+ }
+ qb.append(WorkGraphScheduler, filters=filters, project=projections, tag="process")
+ results = qb.all()
+ pks = [r[0] for r in results]
+ return pks
+
+
+def start_scheduler_worker(foreground: bool = False) -> None:
+ """Start a scheduler worker for the currently configured profile.
+
+ :param foreground: If true, the logging will be configured to write to stdout, otherwise it will be configured to
+ write to the scheduler log file.
+ """
+ import asyncio
+ import signal
+ import sys
+ from aiida_workgraph.engine.scheduler.client import get_scheduler_client
+ from aiida_workgraph.engine.override import create_daemon_runner
+
+ from aiida.common.log import configure_logging
+ from aiida.manage import get_config_option
+ from aiida.engine.daemon.worker import shutdown_worker
+
+ daemon_client = get_scheduler_client()
+ configure_logging(
+ daemon=not foreground, daemon_log_file=daemon_client.daemon_log_file
+ )
+
+ LOGGER.debug(f"sys.executable: {sys.executable}")
+ LOGGER.debug(f"sys.path: {sys.path}")
+
+ try:
+ manager = get_manager()
+ runner = create_daemon_runner(manager, queue_name="scheduler_queue")
+ except Exception:
+ LOGGER.exception("daemon worker failed to start")
+ raise
+
+ if isinstance(rlimit := get_config_option("daemon.recursion_limit"), int):
+ LOGGER.info("Setting maximum recursion limit of daemon worker to %s", rlimit)
+ sys.setrecursionlimit(rlimit)
+
+ signals = (signal.SIGTERM, signal.SIGINT)
+ for s in signals:
+ # https://github.com/python/mypy/issues/12557
+ runner.loop.add_signal_handler(s, lambda s=s: asyncio.create_task(shutdown_worker(runner))) # type: ignore[misc]
+
+ try:
+ LOGGER.info("Starting a daemon worker")
+ runner.start()
+ except SystemError as exception:
+ LOGGER.info("Received a SystemError: %s", exception)
+ runner.close()
+
+ LOGGER.info("Daemon worker started")
+
+
+def start_scheduler_process(number: int = 1) -> None:
+ """Start or restart the specified number of scheduler processes."""
+ from aiida_workgraph.engine.scheduler import WorkGraphScheduler
+ from aiida_workgraph.engine.scheduler.client import get_scheduler
+ from aiida_workgraph.utils.control import create_scheduler_action
+ from aiida_workgraph.engine.utils import instantiate_process
+
+ try:
+ schedulers: List[int] = get_scheduler()
+ existing_schedulers_count = len(schedulers)
+ print(
+ "Found {} existing scheduler(s): {}".format(
+ existing_schedulers_count, " ".join([str(pk) for pk in schedulers])
+ )
+ )
+
+ count = 0
+
+ # Restart existing schedulers if they exceed the number to start
+ for pk in schedulers[:number]:
+ # When the runner stop, the runner does not ack back to rmq,
+ # so the msg is still in the queue, and the msg is not acked,
+ # we don't need send the msg to continue again
+ print(f"Scheduler with pk {pk} restart and running.")
+ count += 1
+ # not running
+ for pk in schedulers[number:]:
+ print(f"Scheduler with pk {pk} not running.")
+
+ # Start new schedulers if more are needed
+ runner = get_manager().get_runner()
+ for i in range(count, number):
+ process_inited = instantiate_process(runner, WorkGraphScheduler)
+ process_inited.runner.persister.save_checkpoint(process_inited)
+ process_inited.close()
+ create_scheduler_action(process_inited.node.pk)
+ print(f"Scheduler with pk {process_inited.node.pk} running.")
+
+ except Exception as e:
+ raise (f"An error occurred while starting schedulers: {e}")
diff --git a/aiida_workgraph/engine/scheduler/scheduler.py b/aiida_workgraph/engine/scheduler/scheduler.py
new file mode 100644
index 00000000..77ba9ca9
--- /dev/null
+++ b/aiida_workgraph/engine/scheduler/scheduler.py
@@ -0,0 +1,1743 @@
+"""AiiDA workflow components: WorkGraph."""
+from __future__ import annotations
+
+import asyncio
+import collections.abc
+import functools
+import logging
+import typing as t
+
+from plumpy import process_comms
+from plumpy.persistence import auto_persist
+from plumpy.process_states import Continue, Wait, Finished, Running
+import kiwipy
+
+from aiida.common import exceptions
+from aiida.common.extendeddicts import AttributeDict
+from aiida.common.lang import override
+from aiida import orm
+from aiida.orm import load_node, Node, ProcessNode, WorkChainNode
+from aiida.orm.utils.serialize import deserialize_unsafe, serialize
+
+from aiida.engine.processes.exit_code import ExitCode
+from aiida.engine.processes.process import Process
+
+from aiida.engine.processes.workchains.awaitable import (
+ Awaitable,
+ AwaitableAction,
+ AwaitableTarget,
+ construct_awaitable,
+)
+from aiida.engine.processes.workchains.workchain import Protect, WorkChainSpec
+from aiida_workgraph.utils import create_and_pause_process
+from aiida_workgraph.task import Task
+from aiida_workgraph.utils import get_nested_dict, update_nested_dict
+from aiida_workgraph.executors.monitors import monitor
+from aiida.common.log import LOG_LEVEL_REPORT
+
+if t.TYPE_CHECKING:
+ from aiida.engine.runners import Runner # pylint: disable=unused-import
+
+__all__ = "WorkGraph"
+
+
+MAX_NUMBER_AWAITABLES_MSG = "The maximum number of subprocesses has been reached: {}. Cannot launch the job: {}."
+
+
+@auto_persist("_awaitables")
+class WorkGraphScheduler(Process, metaclass=Protect):
+ """The `WorkGraph` class is used to construct workflows in AiiDA."""
+
+ # used to create a process node that represents what happened in this process.
+ _node_class = WorkChainNode
+ _spec_class = WorkChainSpec
+ _CONTEXT = "CONTEXT"
+
+ def __init__(
+ self,
+ inputs: dict | None = None,
+ logger: logging.Logger | None = None,
+ runner: "Runner" | None = None,
+ enable_persistence: bool = True,
+ **kwargs: t.Any,
+ ) -> None:
+ """Construct a WorkGraph instance.
+
+ :param inputs: work graph inputs
+ :param logger: aiida logger
+ :param runner: work graph runner
+ :param enable_persistence: whether to persist this work graph
+
+ """
+
+ super().__init__(
+ inputs, logger, runner, enable_persistence=enable_persistence, **kwargs
+ )
+
+ self._awaitables: list[Awaitable] = []
+ self._context = AttributeDict()
+
+ @classmethod
+ def define(cls, spec: WorkChainSpec) -> None:
+ super().define(spec)
+ spec.input("input_file", valid_type=orm.SinglefileData, required=False)
+ spec.input_namespace(
+ "wg", dynamic=True, required=False, help="WorkGraph inputs"
+ )
+ spec.input_namespace("input_tasks", dynamic=True, required=False)
+ spec.exit_code(2, "ERROR_SUBPROCESS", message="A subprocess has failed.")
+
+ spec.outputs.dynamic = True
+
+ spec.output_namespace("new_data", dynamic=True)
+ spec.output(
+ "execution_count",
+ valid_type=orm.Int,
+ required=False,
+ help="The number of time the WorkGraph runs.",
+ )
+ #
+ spec.exit_code(
+ 201, "UNKNOWN_MESSAGE_TYPE", message="The message type is unknown."
+ )
+ spec.exit_code(202, "UNKNOWN_TASK_TYPE", message="The task type is unknown.")
+ #
+ spec.exit_code(
+ 301,
+ "OUTPUS_NOT_MATCH_RESULTS",
+ message="The outputs of the process do not match the results.",
+ )
+ spec.exit_code(
+ 302,
+ "TASK_FAILED",
+ message="Some of the tasks failed.",
+ )
+ spec.exit_code(
+ 303,
+ "TASK_NON_ZERO_EXIT_STATUS",
+ message="Some of the tasks exited with non-zero status.",
+ )
+
+ @property
+ def ctx(self) -> AttributeDict:
+ """Get the context."""
+ return self._context
+
+ @override
+ def save_instance_state(
+ self, out_state: t.Dict[str, t.Any], save_context: t.Any
+ ) -> None:
+ """Save instance state.
+
+ :param out_state: state to save in
+
+ :param save_context:
+ :type save_context: :class:`!plumpy.persistence.LoadSaveContext`
+
+ """
+ super().save_instance_state(out_state, save_context)
+ # Save the context
+ out_state[self._CONTEXT] = self.ctx
+
+ @override
+ def load_instance_state(
+ self, saved_state: t.Dict[str, t.Any], load_context: t.Any
+ ) -> None:
+ super().load_instance_state(saved_state, load_context)
+ # Load the context
+ self._context = saved_state[self._CONTEXT]
+ self._temp = {"awaitables": {}}
+
+ self.set_logger(self.node.logger)
+ self.add_workgraph_subsriber()
+
+ if self._awaitables:
+ # For the "ascyncio.tasks.Task" awaitable, because there are only in-memory,
+ # we need to reset the tasks and so that they can be re-run again.
+ should_resume = False
+ for awaitable in self._awaitables:
+ if awaitable.target == "asyncio.tasks.Task":
+ self._resolve_awaitable(awaitable, None)
+ self.report(f"reset awaitable task: {awaitable.key}")
+ self.reset_task(awaitable.key)
+ should_resume = True
+ if should_resume:
+ self._update_process_status()
+ self.resume()
+ # For other awaitables, because they exist in the db, we only need to re-register the callbacks
+ self._action_awaitables()
+ # load checkpoint
+ launched_workgraphs = self.node.base.extras.get("_launched_workgraphs", [])
+ for pk in launched_workgraphs:
+ print("load workgraph: ", pk)
+ node = load_node(pk)
+ wgdata = node.base.extras.get("_checkpoint", None)
+ if wgdata is None:
+ self.launch_workgraph(pk)
+ else:
+ self.ctx._workgraph[pk] = deserialize_unsafe(wgdata)
+ print("continue workgraph: ", pk)
+ self.continue_workgraph(pk)
+
+ def _resolve_nested_context(self, key: str) -> tuple[AttributeDict, str]:
+ """
+ Returns a reference to a sub-dictionary of the context and the last key,
+ after resolving a potentially segmented key where required sub-dictionaries are created as needed.
+
+ :param key: A key into the context, where words before a dot are interpreted as a key for a sub-dictionary
+ """
+ ctx = self.ctx
+ ctx_path = key.split(".")
+
+ for index, path in enumerate(ctx_path[:-1]):
+ try:
+ ctx = ctx[path]
+ except KeyError: # see below why this is the only exception we have to catch here
+ ctx[
+ path
+ ] = AttributeDict() # create the sub-dict and update the context
+ ctx = ctx[path]
+ continue
+
+ # Notes:
+ # * the first ctx (self.ctx) is guaranteed to be an AttributeDict, hence the post-"dereference" checking
+ # * the values can be many different things: on insertion they are either AtrributeDict, List or Awaitables
+ # (subclasses of AttributeDict) but after resolution of an Awaitable this will be the value itself
+ # * assumption: a resolved value is never a plain AttributeDict, on the other hand if a resolved Awaitable
+ # would be an AttributeDict we can append things to it since the order of tasks is maintained.
+ if type(ctx) != AttributeDict: # pylint: disable=C0123
+ raise ValueError(
+ f"Can not update the context for key `{key}`: "
+ f' found instance of `{type(ctx)}` at `{".".join(ctx_path[:index + 1])}`, expected AttributeDict'
+ )
+
+ return ctx, ctx_path[-1]
+
+ def _insert_awaitable(self, awaitable: Awaitable) -> None:
+ """Insert an awaitable that should be terminated before before continuing to the next step.
+
+ :param awaitable: the thing to await
+ """
+ ctx, key = self._resolve_nested_context(awaitable.key)
+
+ # Already assign the awaitable itself to the location in the context container where it is supposed to end up
+ # once it is resolved. This is especially important for the `APPEND` action, since it needs to maintain the
+ # order, but the awaitables will not necessarily be resolved in the order in which they are added. By using the
+ # awaitable as a placeholder, in the `_resolve_awaitable`, it can be found and replaced by the resolved value.
+ if awaitable.action == AwaitableAction.ASSIGN:
+ ctx[key] = awaitable
+ elif awaitable.action == AwaitableAction.APPEND:
+ ctx.setdefault(key, []).append(awaitable)
+ else:
+ raise AssertionError(f"Unsupported awaitable action: {awaitable.action}")
+
+ # Register the callback to be called when the awaitable is resolved
+ self._add_callback_to_awaitable(awaitable)
+ self._awaitables.append(
+ awaitable
+ ) # add only if everything went ok, otherwise we end up in an inconsistent state
+ self._update_process_status()
+
+ def _resolve_awaitable(self, awaitable: Awaitable, value: t.Any) -> None:
+ """Resolve an awaitable.
+
+ Precondition: must be an awaitable that was previously inserted.
+
+ :param awaitable: the awaitable to resolve
+ :param value: the value to assign to the awaitable
+ """
+ ctx, key = self._resolve_nested_context(awaitable.key)
+
+ if awaitable.action == AwaitableAction.ASSIGN:
+ ctx[key] = value
+ elif awaitable.action == AwaitableAction.APPEND:
+ # Find the same awaitable inserted in the context
+ container = ctx[key]
+ for index, placeholder in enumerate(container):
+ if (
+ isinstance(placeholder, Awaitable)
+ and placeholder.pk == awaitable.pk
+ ):
+ container[index] = value
+ break
+ else:
+ raise AssertionError(
+ f"Awaitable `{awaitable.pk} was not in `ctx.{awaitable.key}`"
+ )
+ else:
+ raise AssertionError(f"Unsupported awaitable action: {awaitable.action}")
+
+ awaitable.resolved = True
+ # remove awaitabble from the list
+ self._awaitables = [a for a in self._awaitables if a.pk != awaitable.pk]
+
+ if not self.has_terminated():
+ # the process may be terminated, for example, if the process was killed or excepted
+ # then we should not try to update it
+ self._update_process_status()
+
+ @Protect.final
+ def to_context(self, **kwargs: Awaitable | ProcessNode) -> None:
+ """Add a dictionary of awaitables to the context.
+
+ This is a convenience method that provides syntactic sugar, for a user to add multiple intersteps that will
+ assign a certain value to the corresponding key in the context of the work graph.
+ """
+ for key, value in kwargs.items():
+ awaitable = construct_awaitable(value)
+ awaitable.key = key
+ awaitable.workgraph_pk = value.workgraph_pk
+ self._insert_awaitable(awaitable)
+
+ def _update_process_status(self) -> None:
+ """Set the process status with a message accounting the current sub processes that we are waiting for."""
+ if self._awaitables:
+ status = f"Waiting for child processes: {', '.join([str(_.pk) for _ in self._awaitables])}"
+ self.node.set_process_status(status)
+ else:
+ self.node.set_process_status(None)
+
+ @override
+ def run(self) -> t.Any:
+ self.setup()
+ return self._do_step()
+
+ def _do_step(self) -> t.Any:
+ """Execute the next step in the workgraph and return the result.
+
+ If any awaitables were created, the process will enter in the Wait state,
+ otherwise it will go to Continue.
+ """
+ # we will not remove the awaitables here,
+ # we resume the workgraph in the callback function even
+ # there are some awaitables left
+ # self._awaitables = []
+
+ if self._awaitables:
+ return Wait(self._do_step, "Waiting before next step")
+
+ return Continue(self._do_step)
+
+ def _store_nodes(self, data: t.Any) -> None:
+ """Recurse through a data structure and store any unstored nodes that are found along the way
+
+ :param data: a data structure potentially containing unstored nodes
+ """
+ if isinstance(data, Node) and not data.is_stored:
+ data.store()
+ elif isinstance(data, collections.abc.Mapping):
+ for _, value in data.items():
+ self._store_nodes(value)
+ elif isinstance(data, collections.abc.Sequence) and not isinstance(data, str):
+ for value in data:
+ self._store_nodes(value)
+
+ @override
+ @Protect.final
+ def on_exiting(self) -> None:
+ """Ensure that any unstored nodes in the context are stored, before the state is exited
+
+ After the state is exited the next state will be entered and if persistence is enabled, a checkpoint will
+ be saved. If the context contains unstored nodes, the serialization necessary for checkpointing will fail.
+ """
+ super().on_exiting()
+ try:
+ self._store_nodes(self.ctx)
+ except Exception: # pylint: disable=broad-except
+ # An uncaught exception here will have bizarre and disastrous consequences
+ self.logger.exception("exception in _store_nodes called in on_exiting")
+
+ @Protect.final
+ def on_wait(self, awaitables: t.Sequence[t.Awaitable]):
+ """Entering the WAITING state."""
+ super().on_wait(awaitables)
+ if self._awaitables:
+ self._action_awaitables()
+ self.report("Process status: {}".format(self.node.process_status))
+ else:
+ self.call_soon(self.resume)
+
+ def _action_awaitables(self) -> None:
+ """Handle the awaitables that are currently registered with the work chain.
+
+ Depending on the class type of the awaitable's target a different callback
+ function will be bound with the awaitable and the runner will be asked to
+ call it when the target is completed
+ """
+ for awaitable in self._awaitables:
+ pk = awaitable.workgraph_pk
+ # if the waitable already has a callback, skip
+ if awaitable.pk in self.ctx._workgraph[pk]["_awaitable_actions"]:
+ continue
+ self._add_callback_to_awaitable(awaitable)
+
+ def _add_callback_to_awaitable(self, awaitable: Awaitable) -> None:
+ """Add a callback to the awaitable."""
+ pk = awaitable.workgraph_pk
+ if awaitable.target == AwaitableTarget.PROCESS:
+ callback = functools.partial(
+ self.call_soon, self._on_awaitable_finished, awaitable
+ )
+ self.runner.call_on_process_finish(awaitable.pk, callback)
+ self.ctx._workgraph[pk]["_awaitable_actions"].append(awaitable.pk)
+ elif awaitable.target == "asyncio.tasks.Task":
+ # this is a awaitable task, the callback function is already set
+ self.ctx._workgraph[pk]["_awaitable_actions"].append(awaitable.pk)
+ else:
+ assert f"invalid awaitable target '{awaitable.target}'"
+
+ def _on_awaitable_finished(self, awaitable: Awaitable) -> None:
+ """Callback function, for when an awaitable process instance is completed.
+
+ The awaitable will be effectuated on the context of the work chain and removed from the internal list. If all
+ awaitables have been dealt with, the work chain process is resumed.
+
+ :param awaitable: an Awaitable instance
+ """
+ import time
+
+ print(f"Awaitable {awaitable.key} finished.")
+ self.logger.debug(f"Awaitable {awaitable.key} finished.")
+ pk = awaitable.workgraph_pk
+ node = load_node(awaitable.pk)
+ # there is a bug in aiida.engine.process, it send the msg before setting the process state and outputs
+ # so we need to wait for a while
+ # TODO make a PR to fix the aiida-core, the `super().on_entered()` should be
+ # called after setting the process state
+ tstart = time.time()
+ while not node.is_finished and time.time() - tstart < 5:
+ time.sleep(0.1)
+
+ if isinstance(awaitable.pk, int):
+ self.logger.info(
+ "received callback that awaitable with key {} and pk {} has terminated".format(
+ awaitable.key, awaitable.pk
+ )
+ )
+ try:
+ node = load_node(awaitable.pk)
+ except (exceptions.MultipleObjectsError, exceptions.NotExistent):
+ raise ValueError(
+ f"provided pk<{awaitable.pk}> could not be resolved to a valid Node instance"
+ )
+
+ if awaitable.outputs:
+ value = {
+ entry.link_label: entry.node
+ for entry in node.base.links.get_outgoing()
+ }
+ else:
+ value = node # type: ignore
+ else:
+ # In this case, the pk and key are the same.
+ self.logger.info(
+ "received callback that awaitable {} has terminated".format(
+ awaitable.key
+ )
+ )
+ try:
+ # if awaitable is cancelled, the result is None
+ if awaitable.cancelled():
+ self.set_task_state_info(pk, awaitable.key, "state", "KILLED")
+ # set child tasks state to SKIPPED
+ self.set_tasks_state(
+ pk,
+ self.ctx._workgraph[pk]["_connectivity"]["child_node"][
+ awaitable.key
+ ],
+ "SKIPPED",
+ )
+ self.report(f"Task: {awaitable.key} cancelled.")
+ else:
+ results = awaitable.result()
+ self.set_normal_task_results(
+ awaitable.workgraph_pk, awaitable.key, results
+ )
+ except Exception as e:
+ self.logger.error(f"Error in awaitable {awaitable.key}: {e}")
+ self.set_task_state_info(pk, awaitable.key, "state", "FAILED")
+ # set child tasks state to SKIPPED
+ self.set_tasks_state(
+ pk,
+ self.ctx._workgraph[pk]["_connectivity"]["child_node"][
+ awaitable.key
+ ],
+ "SKIPPED",
+ )
+ self.report(f"Task: {awaitable.key} failed.")
+ self.run_error_handlers(pk, awaitable.key)
+ value = None
+
+ self._resolve_awaitable(awaitable, value)
+
+ # node finished, update the task state and result
+ # udpate the task state
+ print(f"Update task state: {awaitable.key}")
+ self.update_task_state(awaitable.workgraph_pk, awaitable.key)
+ # try to resume the workgraph, if the workgraph is already resumed
+ # by other awaitable, this will not work
+ # try:
+ # self.resume()
+ # except Exception as e:
+ # print(e)
+
+ def _build_process_label(self) -> str:
+ """Use the workgraph name as the process label."""
+ return "Scheduler"
+
+ def on_create(self) -> None:
+ """Called when a Process is created."""
+
+ super().on_create()
+ self.node.label = "Scheduler"
+
+ def setup(self) -> None:
+ """Setup the variables in the context."""
+ # track if the awaitable callback is added to the runner
+
+ self.ctx.launched_workgraphs = []
+ self.ctx._workgraph = {}
+ self.ctx._max_number_awaitables = 10000
+ awaitable = Awaitable(
+ **{
+ "workgraph_pk": self.node.pk,
+ "pk": "scheduler",
+ "action": AwaitableAction.ASSIGN,
+ "target": "scheduler",
+ "outputs": False,
+ }
+ )
+ self.ctx._workgraph[self.node.pk] = {"_awaitable_actions": []}
+ self.to_context(scheduler=awaitable)
+ # self.ctx._msgs = []
+ # self.ctx._workgraph[pk]["_execution_count"] = {}
+ # data not to be persisted, because they are not serializable
+ self._temp = {"awaitables": {}}
+
+ def launch_workgraph(self, pk: str) -> None:
+ """Launch the workgraph."""
+ # create the workgraph process
+ self.report(f"Launch workgraph: {pk}")
+ # append the pk to the self.node.base.extras
+ self.ctx.launched_workgraphs.append(pk)
+ self.node.base.extras.set("_launched_workgraphs", self.ctx.launched_workgraphs)
+ self.init_ctx_workgraph(pk)
+ self.ctx._workgraph[pk]["_node"].set_process_state(Running.LABEL)
+ self.init_task_results(pk)
+ self.continue_workgraph(pk)
+
+ def init_ctx_workgraph(self, pk: int) -> None:
+ """Init the context from the workgraph data."""
+ from aiida_workgraph.utils import update_nested_dict
+
+ # read the latest workgraph data
+ wgdata, node = self.read_wgdata_from_base(pk)
+ self.ctx._workgraph[pk] = {
+ "_awaitable_actions": {},
+ "_new_data": {},
+ "_execution_count": 1,
+ "_executed_tasks": [],
+ "_count": 0,
+ "_context": {},
+ "_node": node,
+ }
+ for key, value in wgdata["context"].items():
+ key = key.replace("__", ".")
+ update_nested_dict(self.ctx._workgraph[pk], key, value)
+ # set up the workgraph
+ self.setup_ctx_workgraph(pk, wgdata)
+
+ def setup_ctx_workgraph(self, pk: int, wgdata: t.Dict[str, t.Any]) -> None:
+ """setup the workgraph in the context."""
+ import cloudpickle as pickle
+
+ self.ctx._workgraph[pk]["_tasks"] = wgdata.pop("tasks")
+ self.ctx._workgraph[pk]["_links"] = wgdata.pop("links")
+ self.ctx._workgraph[pk]["_connectivity"] = wgdata.pop("connectivity")
+ self.ctx._workgraph[pk]["_ctrl_links"] = wgdata.pop("ctrl_links")
+ self.ctx._workgraph[pk]["_error_handlers"] = pickle.loads(
+ wgdata.pop("error_handlers")
+ )
+ self.ctx._workgraph[pk]["_metadata"] = wgdata.pop("metadata")
+ self.ctx._workgraph[pk]["_workgraph"] = wgdata
+ self.ctx._workgraph[pk]["_awaitable_actions"] = []
+
+ def read_wgdata_from_base(self, pk: int) -> t.Dict[str, t.Any]:
+ """Read workgraph data from base.extras."""
+ from aiida_workgraph.orm.function_data import PickledLocalFunction
+
+ node = load_node(pk)
+
+ wgdata = node.base.extras.get("_workgraph")
+ for name, task in wgdata["tasks"].items():
+ wgdata["tasks"][name] = deserialize_unsafe(task)
+ for _, prop in wgdata["tasks"][name]["properties"].items():
+ if isinstance(prop["value"], PickledLocalFunction):
+ prop["value"] = prop["value"].value
+ wgdata["error_handlers"] = deserialize_unsafe(wgdata["error_handlers"])
+ wgdata["context"] = deserialize_unsafe(wgdata["context"])
+
+ return wgdata, node
+
+ def update_workgraph_from_base(self, pk: int) -> None:
+ """Update the ctx from base.extras."""
+ wgdata, _ = self.read_wgdata_from_base()
+ for name, task in wgdata["tasks"].items():
+ task["results"] = self.ctx._workgraph[pk]["_tasks"][name].get("results")
+ self.setup_ctx_workgraph(pk, wgdata)
+
+ def get_task(self, name: str):
+ """Get task from the context."""
+ task = Task.from_dict(self.ctx._workgraph[pk]["_tasks"][name])
+ return task
+
+ def update_task(self, pk, task: Task):
+ """Update task in the context.
+ This is used in error handlers to update the task parameters."""
+ self.ctx._workgraph[pk]["_tasks"][task.name][
+ "properties"
+ ] = task.properties_to_dict()
+ self.reset_task(task.name)
+
+ def get_task_state_info(self, pk: int, name: str, key: str) -> str:
+ """Get task state info from ctx."""
+
+ value = self.ctx._workgraph[pk]["_tasks"][name].get(key, None)
+ if key == "process" and value is not None:
+ value = deserialize_unsafe(value)
+ return value
+
+ def set_task_state_info(self, pk: int, name: str, key: str, value: any) -> None:
+ """Set task state info to ctx and base.extras.
+ We task state to the base.extras, so that we can access outside the engine"""
+
+ if key == "process":
+ value = serialize(value)
+ self.ctx._workgraph[pk]["_node"].base.extras.set(
+ f"_task_{key}_{name}", value
+ )
+ else:
+ self.ctx._workgraph[pk]["_node"].base.extras.set(
+ f"_task_{key}_{name}", value
+ )
+ self.ctx._workgraph[pk]["_tasks"][name][key] = value
+
+ def init_task_results(self, pk) -> None:
+ """Init the task results."""
+ for name, task in self.ctx._workgraph[pk]["_tasks"].items():
+ if self.get_task_state_info(pk, name, "action").upper() == "RESET":
+ self.reset_task(pk, task["name"])
+ # only init the task results, and do not need to continue the workgraph
+ self.update_task_state(pk, name, continue_workgraph=False)
+
+ def apply_action(self, msg: dict) -> None:
+
+ if msg["catalog"] == "task":
+ self.apply_task_actions(msg)
+ else:
+ self.report(f"Unknow message type {msg}")
+
+ def apply_task_actions(self, msg: dict) -> None:
+ """Apply task actions to the workgraph."""
+ action = msg["action"]
+ tasks = msg["tasks"]
+ self.report(f"Action: {action}. {tasks}")
+ if action.upper() == "RESET":
+ for name in tasks:
+ self.reset_task(name)
+ elif action.upper() == "PAUSE":
+ for name in tasks:
+ self.pause_task(name)
+ elif action.upper() == "PLAY":
+ for name in tasks:
+ self.play_task(name)
+ elif action.upper() == "SKIP":
+ for name in tasks:
+ self.skip_task(name)
+ elif action.upper() == "KILL":
+ for name in tasks:
+ self.kill_task(name)
+ elif action.upper() == "LAUNCH_WORKGRAPH":
+ for pk in tasks:
+ self.launch_workgraph(pk)
+
+ def reset_task(
+ self,
+ pk: int,
+ name: str,
+ reset_process: bool = True,
+ recursive: bool = True,
+ reset_execution_count: bool = True,
+ ) -> None:
+ """Reset task state and remove it from the executed task.
+ If recursive is True, reset its child tasks."""
+
+ self.set_task_state_info(pk, name, "state", "PLANNED")
+ if reset_process:
+ self.set_task_state_info(pk, name, "process", None)
+ self.remove_executed_task(name)
+ # self.logger.debug(f"Task {name} action: RESET.")
+ # if the task is a while task, reset its child tasks
+ if (
+ self.ctx._workgraph[pk]["_tasks"][name]["metadata"]["node_type"].upper()
+ == "WHILE"
+ ):
+ if reset_execution_count:
+ self.ctx._workgraph[pk]["_tasks"][name]["execution_count"] = 0
+ for child_task in self.ctx._workgraph[pk]["_tasks"][name]["children"]:
+ self.reset_task(child_task, reset_process=False, recursive=False)
+ if recursive:
+ # reset its child tasks
+ names = self.ctx._workgraph[pk]["_connectivity"]["child_node"][name]
+ for name in names:
+ self.reset_task(name, recursive=False)
+
+ def pause_task(self, name: str) -> None:
+ """Pause task."""
+ self.set_task_state_info(pk, name, "action", "PAUSE")
+ self.report(f"Task {name} action: PAUSE.")
+
+ def play_task(self, name: str) -> None:
+ """Play task."""
+ self.set_task_state_info(pk, name, "action", "")
+ self.report(f"Task {name} action: PLAY.")
+
+ def skip_task(self, name: str) -> None:
+ """Skip task."""
+ self.set_task_state_info(pk, name, "state", "SKIPPED")
+ self.report(f"Task {name} action: SKIP.")
+
+ def kill_task(self, pk, name: str) -> None:
+ """Kill task.
+ This is used to kill the awaitable and monitor task.
+ """
+ if self.get_task_state_info(pk, name, "state") in ["RUNNING"]:
+ if self.ctx._workgraph[pk]["_tasks"][name]["metadata"][
+ "node_type"
+ ].upper() in [
+ "AWAITABLE",
+ "MONITOR",
+ ]:
+ try:
+ self._temp["awaitables"][name].cancel()
+ self.set_task_state_info(pk, name, "state", "KILLED")
+ self.report(f"Task {name} action: KILLED.")
+ except Exception as e:
+ self.logger.error(f"Error in killing task {name}: {e}")
+
+ def report(self, msg, pk=None):
+ """Report the message."""
+ if pk:
+ self.ctx._workgraph[pk]["_node"].logger.log(LOG_LEVEL_REPORT, msg)
+ else:
+ super().report(msg)
+
+ def continue_workgraph(self, pk: int) -> None:
+ # if the workgraph is finished, skip
+ if pk not in self.ctx._workgraph:
+ # the workgraph is finished
+ return
+ is_finished, _ = self.is_workgraph_finished(pk)
+ if is_finished:
+ self.finalize_workgraph(pk)
+ # remove the workgraph from the context
+ del self.ctx._workgraph[pk]
+ self.ctx.launched_workgraphs.remove(pk)
+ self.node.base.extras.set(
+ "_launched_workgraphs", self.ctx.launched_workgraphs
+ )
+ return
+ self.report("Continue.", pk)
+ task_to_run = []
+ for name, task in self.ctx._workgraph[pk]["_tasks"].items():
+ # update task state
+ if (
+ self.get_task_state_info(pk, task["name"], "state")
+ in [
+ "CREATED",
+ "RUNNING",
+ "FINISHED",
+ "FAILED",
+ "SKIPPED",
+ ]
+ or name in self.ctx._workgraph[pk]["_executed_tasks"]
+ ):
+ continue
+ ready, _ = self.is_task_ready_to_run(pk, name)
+ if ready:
+ task_to_run.append(name)
+ #
+ self.report("tasks ready to run: {}".format(",".join(task_to_run)), pk)
+ if len(task_to_run) > 0:
+ self.run_tasks(pk, task_to_run)
+
+ def update_task_state(
+ self, pk: int, name: str, continue_workgraph: bool = True
+ ) -> None:
+ """Update task state when the task is finished."""
+
+ task = self.ctx._workgraph[pk]["_tasks"][name]
+ # print(f"set task result: {name}")
+ node = self.get_task_state_info(pk, name, "process")
+ if isinstance(node, orm.ProcessNode):
+ # print(f"set task result: {name} process")
+ state = node.process_state.value.upper()
+ if node.is_finished_ok:
+ self.set_task_state_info(pk, task["name"], "state", state)
+ if task["metadata"]["node_type"].upper() == "WORKGRAPH":
+ # expose the outputs of all the tasks in the workgraph
+ task["results"] = {}
+ outgoing = node.base.links.get_outgoing()
+ for link in outgoing.all():
+ if isinstance(link.node, ProcessNode) and getattr(
+ link.node, "process_state", False
+ ):
+ task["results"][link.link_label] = link.node.outputs
+ else:
+ task["results"] = node.outputs
+ # self.ctx._workgraph[pk]["_new_data"][name] = task["results"]
+ self.set_task_state_info(pk, task["name"], "state", "FINISHED")
+ self.task_set_context(pk, name)
+ self.report(f"Task: {name} finished.", pk=pk)
+ # all other states are considered as failed
+ else:
+ print(f"set task result: {name} failed")
+ task["results"] = node.outputs
+ # self.ctx._workgraph[pk]["_new_data"][name] = task["results"]
+ self.set_task_state_info(pk, task["name"], "state", "FAILED")
+ # set child tasks state to SKIPPED
+ self.set_tasks_state(
+ pk,
+ self.ctx._workgraph[pk]["_connectivity"]["child_node"][name],
+ "SKIPPED",
+ )
+ self.report(f"Task: {name} failed.", pk)
+ self.run_error_handlers(pk, name)
+ elif isinstance(node, orm.Data):
+ task["results"] = {task["outputs"][0]["name"]: node}
+ self.set_task_state_info(pk, task["name"], "state", "FINISHED")
+ self.task_set_context(pk, name)
+ self.report(f"Task: {name} finished.", pk)
+ else:
+ task.setdefault("results", None)
+
+ self.update_parent_task_state(pk, name)
+ self.save_workgraph_checkpoint(pk)
+ if continue_workgraph:
+ try:
+ self.continue_workgraph(pk)
+ except Exception as e:
+ print(e)
+
+ def set_normal_task_results(self, pk, name, results):
+ """Set the results of a normal task.
+ A normal task is created by decorating a function with @task().
+ """
+ task = self.ctx._workgraph[pk]["_tasks"][name]
+ if isinstance(results, tuple):
+ if len(task["outputs"]) != len(results):
+ return self.exit_codes.OUTPUS_NOT_MATCH_RESULTS
+ for i in range(len(task["outputs"])):
+ task["results"][task["outputs"][i]["name"]] = results[i]
+ elif isinstance(results, dict):
+ task["results"] = results
+ else:
+ task["results"][task["outputs"][0]["name"]] = results
+ self.task_set_context(pk, name)
+ self.set_task_state_info(pk, name, "state", "FINISHED")
+ self.report(f"Task: {name} finished.", pk=pk)
+ self.update_parent_task_state(pk, name)
+
+ def save_workgraph_checkpoint(self, pk: int):
+ """Save the workgraph checkpoint."""
+ self.ctx._workgraph[pk]["_node"].base.extras.set(
+ "_checkpoint", serialize(self.ctx._workgraph[pk])
+ )
+
+ def update_parent_task_state(self, pk, name: str) -> None:
+ """Update parent task state."""
+ parent_task = self.ctx._workgraph[pk]["_tasks"][name]["parent_task"]
+ if parent_task[0]:
+ task_type = self.ctx._workgraph[pk]["_tasks"][parent_task[0]]["metadata"][
+ "node_type"
+ ].upper()
+ if task_type == "WHILE":
+ self.update_while_task_state(pk, parent_task[0])
+ elif task_type == "IF":
+ self.update_zone_task_state(pk, parent_task[0])
+ elif task_type == "ZONE":
+ self.update_zone_task_state(pk, parent_task[0])
+
+ def update_while_task_state(self, pk: int, name: str) -> None:
+ """Update while task state."""
+ finished, _ = self.are_childen_finished(name)
+
+ if finished:
+ self.report(
+ f"Wihle Task {name}: this iteration finished. Try to reset for the next iteration.",
+ pk,
+ )
+ # reset the condition tasks
+ for input in self.ctx._workgraph[pk]["_tasks"][name]["inputs"]:
+ if input["name"].upper() == "CONDITIONS":
+ for link in input["links"]:
+ self.reset_task(link["from_node"], recursive=False)
+ # reset the task and all its children, so that the task can run again
+ # do not reset the execution count
+ self.reset_task(name, reset_execution_count=False)
+
+ def update_zone_task_state(self, pk: int, name: str) -> None:
+ """Update zone task state."""
+ finished, _ = self.are_childen_finished(name)
+ if finished:
+ self.set_task_state_info(pk, name, "state", "FINISHED")
+ self.update_parent_task_state(pk, name)
+ self.report(f"Task: {name} finished.", pk)
+
+ def should_run_while_task(self, pk: int, name: str) -> tuple[bool, t.Any]:
+ """Check if the while task should run."""
+ # check the conditions of the while task
+ not_excess_max_iterations = (
+ self.ctx._workgraph[pk]["_tasks"][name]["execution_count"]
+ < self.ctx._workgraph[pk]["_tasks"][name]["properties"]["max_iterations"][
+ "value"
+ ]
+ )
+ conditions = [not_excess_max_iterations]
+ _, kwargs, _, _, _ = self.get_inputs(pk, name)
+ if isinstance(kwargs["conditions"], list):
+ for condition in kwargs["conditions"]:
+ value = get_nested_dict(self.ctx, condition)
+ conditions.append(value)
+ elif isinstance(kwargs["conditions"], dict):
+ for _, value in kwargs["conditions"].items():
+ conditions.append(value)
+ else:
+ conditions.append(kwargs["conditions"])
+ return False not in conditions
+
+ def should_run_if_task(self, name: str) -> tuple[bool, t.Any]:
+ """Check if the IF task should run."""
+ _, kwargs, _, _, _ = self.get_inputs(pk, name)
+ flag = kwargs["conditions"]
+ if kwargs["invert_condition"]:
+ return not flag
+ return flag
+
+ def are_childen_finished(self, pk, name: str) -> tuple[bool, t.Any]:
+ """Check if the child tasks are finished."""
+ task = self.ctx._workgraph[pk]["_tasks"][name]
+ finished = True
+ for name in task["children"]:
+ if self.get_task_state_info(pk, name, "state") not in [
+ "FINISHED",
+ "SKIPPED",
+ "FAILED",
+ ]:
+ finished = False
+ break
+ return finished, None
+
+ def run_error_handlers(self, pk: int, task_name: str) -> None:
+ """Run error handler."""
+ node = self.get_task_state_info(pk, task_name, "process")
+ if not node or not node.exit_status:
+ return
+ for _, data in self.ctx._workgraph[pk]["_error_handlers"].items():
+ if task_name in data["tasks"]:
+ handler = data["handler"]
+ metadata = data["tasks"][task_name]
+ if node.exit_code.status in metadata.get("exit_codes", []):
+ self.report(f"Run error handler: {metadata}", pk)
+ metadata.setdefault("retry", 0)
+ if metadata["retry"] < metadata["max_retries"]:
+ handler(self, task_name, **metadata.get("kwargs", {}))
+ metadata["retry"] += 1
+
+ def is_workgraph_finished(self, pk) -> bool:
+ """Check if the workgraph is finished.
+ For `while` workgraph, we need check its conditions"""
+ is_finished = True
+ failed_tasks = []
+ for name, task in self.ctx._workgraph[pk]["_tasks"].items():
+ # self.update_task_state(pk, name)
+ if self.get_task_state_info(pk, task["name"], "state") in [
+ "RUNNING",
+ "CREATED",
+ "PLANNED",
+ "READY",
+ ]:
+ is_finished = False
+ elif self.get_task_state_info(pk, task["name"], "state") == "FAILED":
+ failed_tasks.append(name)
+ if is_finished:
+ if (
+ self.ctx._workgraph[pk]["_workgraph"]["workgraph_type"].upper()
+ == "WHILE"
+ ):
+ should_run = self.check_while_conditions(pk)
+ is_finished = not should_run
+ if self.ctx._workgraph[pk]["_workgraph"]["workgraph_type"].upper() == "FOR":
+ should_run = self.check_for_conditions(pk)
+ is_finished = not should_run
+ if is_finished and len(failed_tasks) > 0:
+ message = f"WorkGraph finished, but tasks: {failed_tasks} failed. Thus all their child tasks are skipped."
+ self.report(message, pk)
+ result = ExitCode(302, message)
+ else:
+ result = None
+ return is_finished, result
+
+ def check_while_conditions(self, pk: int) -> bool:
+ """Check while conditions.
+ Run all condition tasks and check if all the conditions are True.
+ """
+ if (
+ self.ctx._workgraph[pk]["_execution_count"]
+ >= self.ctx._workgraph[pk]["_max_iteration"]
+ ):
+ self.report("Max iteration reached.", pk)
+ return False
+ condition_tasks = []
+ for c in self.ctx._workgraph[pk]["conditions"]:
+ task_name, socket_name = c.split(".")
+ if "task_name" != "context":
+ condition_tasks.append(task_name)
+ self.run_tasks(condition_tasks, continue_workgraph=False)
+ conditions = []
+ for c in self.ctx._workgraph[pk]["conditions"]:
+ task_name, socket_name = c.split(".")
+ if task_name == "context":
+ conditions.append(self.ctx[socket_name])
+ else:
+ conditions.append(
+ self.ctx._workgraph[pk]["_tasks"][task_name]["results"][socket_name]
+ )
+ should_run = False not in conditions
+ if should_run:
+ self.reset_workgraph(pk)
+ self.set_tasks_state(pk, condition_tasks, "SKIPPED")
+ return should_run
+
+ def check_for_conditions(self, pk: int) -> bool:
+ condition_tasks = [c[0] for c in self.ctx._workgraph[pk]["conditions"]]
+ self.run_tasks(condition_tasks)
+ conditions = [self.ctx._count < len(self.ctx._sequence)] + [
+ self.ctx._workgraph[pk]["_tasks"][c[0]]["results"][c[1]]
+ for c in self.ctx._workgraph[pk]["conditions"]
+ ]
+ should_run = False not in conditions
+ if should_run:
+ self.reset_workgraph(pk)
+ self.set_tasks_state(pk, condition_tasks, "SKIPPED")
+ self.ctx["i"] = self.ctx._sequence[self.ctx._count]
+ self.ctx._count += 1
+ return should_run
+
+ def remove_executed_task(self, pk, name: str) -> None:
+ """Remove labels with name from executed tasks."""
+ self.ctx._workgraph[pk]["_executed_tasks"] = [
+ label
+ for label in self.ctx._workgraph[pk]["_executed_tasks"]
+ if label.split(".")[0] != name
+ ]
+
+ def add_task_link(self, pk, node: ProcessNode) -> None:
+ from aiida.common.links import LinkType
+
+ parent_calc = self.ctx._workgraph[pk]["_node"]
+ if isinstance(node, orm.CalculationNode):
+ node.base.links.add_incoming(
+ parent_calc, LinkType.CALL_CALC, "CALL"
+ ) # TODO, self.metadata.call_link_label)
+ elif isinstance(node, orm.WorkflowNode):
+ node.base.links.add_incoming(
+ parent_calc, LinkType.CALL_WORK, "CALL"
+ ) # TODO, self.metadata.call_link_label)
+
+ def run_tasks(
+ self, pk: int, names: t.List[str], continue_workgraph: bool = True
+ ) -> None:
+ """Run tasks.
+ Task type includes: Node, Data, CalcFunction, WorkFunction, CalcJob, WorkChain, GraphBuilder,
+ WorkGraph, PythonJob, ShellJob, While, If, Zone, FromContext, ToContext, Normal.
+
+ Here we use ToContext to pass the results of the run to the next step.
+ This will force the engine to wait for all the submitted processes to
+ finish before continuing to the next step.
+ """
+ from aiida_workgraph.utils import (
+ get_executor,
+ create_data_node,
+ update_nested_dict_with_special_keys,
+ )
+ from aiida_workgraph.engine.workgraph import WorkGraphEngine
+ from aiida_workgraph.engine import launch
+
+ for name in names:
+ # skip if the max number of awaitables is reached
+ task = self.ctx._workgraph[pk]["_tasks"][name]
+ if task["metadata"]["node_type"].upper() in [
+ "CALCJOB",
+ "WORKCHAIN",
+ "GRAPH_BUILDER",
+ "WORKGRAPH",
+ "PYTHONJOB",
+ "SHELLJOB",
+ ]:
+ if len(self._awaitables) >= self.ctx._max_number_awaitables:
+ print(
+ MAX_NUMBER_AWAITABLES_MSG.format(
+ self.ctx._max_number_awaitables, name
+ )
+ )
+ continue
+ # skip if the task is already executed
+ if name in self.ctx._workgraph[pk]["_executed_tasks"]:
+ continue
+ self.ctx._workgraph[pk]["_executed_tasks"].append(name)
+ print("-" * 60)
+
+ self.report(f"Run task: {name}, type: {task['metadata']['node_type']}", pk)
+ executor, _ = get_executor(task["executor"])
+ # print("executor: ", executor)
+ args, kwargs, var_args, var_kwargs, args_dict = self.get_inputs(pk, name)
+ for i, key in enumerate(
+ self.ctx._workgraph[pk]["_tasks"][name]["metadata"]["args"]
+ ):
+ kwargs[key] = args[i]
+ # update the port namespace
+ kwargs = update_nested_dict_with_special_keys(kwargs)
+ # print("args: ", args)
+ # print("kwargs: ", kwargs)
+ # print("var_kwargs: ", var_kwargs)
+ # kwargs["meta.label"] = name
+ # output must be a Data type or a mapping of {string: Data}
+ task["results"] = {}
+ if task["metadata"]["node_type"].upper() == "NODE":
+ results = self.run_executor(executor, [], kwargs, var_args, var_kwargs)
+ self.set_task_state_info(pk, name, "process", results)
+ self.update_task_state(pk, name)
+ if continue_workgraph:
+ self.continue_workgraph(pk)
+ elif task["metadata"]["node_type"].upper() == "DATA":
+ for key in self.ctx._workgraph[pk]["_tasks"][name]["metadata"]["args"]:
+ kwargs.pop(key, None)
+ results = create_data_node(executor, args, kwargs)
+ self.set_task_state_info(pk, name, "process", results)
+ self.update_task_state(pk, name)
+ self.ctx._workgraph[pk]["_new_data"][name] = results
+ if continue_workgraph:
+ self.continue_workgraph(pk)
+ elif task["metadata"]["node_type"].upper() in [
+ "CALCFUNCTION",
+ "WORKFUNCTION",
+ ]:
+ kwargs.setdefault("metadata", {})
+ kwargs["metadata"].update({"call_link_label": name})
+ kwargs["parent_pid"] = pk
+ try:
+ # since aiida 2.5.0, we need to use args_dict to pass the args to the run_get_node
+ if var_kwargs is None:
+ results, process = launch.run_get_node(
+ executor.process_class, **kwargs
+ )
+ else:
+ results, process = launch.run_get_node(
+ executor.process_class, **kwargs, **var_kwargs
+ )
+ process.label = name
+ # print("results: ", results)
+ self.set_task_state_info(pk, name, "process", process)
+ self.update_task_state(pk, name)
+ except Exception as e:
+ self.logger.error(f"Error in task {name}: {e}")
+ self.set_task_state_info(pk, name, "state", "FAILED")
+ # set child state to FAILED
+ self.set_tasks_state(
+ pk,
+ self.ctx._workgraph[pk]["_connectivity"]["child_node"][name],
+ "SKIPPED",
+ )
+ self.report(f"Task: {name} failed.", pk)
+ # exclude the current tasks from the next run
+ if continue_workgraph:
+ self.continue_workgraph(pk)
+ elif task["metadata"]["node_type"].upper() in ["CALCJOB", "WORKCHAIN"]:
+ # process = run_get_node(executor, *args, **kwargs)
+ kwargs.setdefault("metadata", {})
+ kwargs["metadata"].update({"call_link_label": name})
+ kwargs["parent_pid"] = pk
+ try:
+ # transfer the args to kwargs
+ if self.get_task_state_info(pk, name, "action").upper() == "PAUSE":
+ self.set_task_state_info(pk, name, "action", "")
+ self.report(f"Task {name} is created and paused.", pk)
+ process = create_and_pause_process(
+ self.runner,
+ executor,
+ kwargs,
+ state_msg="Paused through WorkGraph",
+ )
+ self.set_task_state_info(pk, name, "state", "CREATED")
+ process = process.node
+ else:
+ process = launch.submit(executor, **kwargs)
+ self.set_task_state_info(pk, name, "state", "RUNNING")
+ process.label = name
+ process.workgraph_pk = pk
+ self.set_task_state_info(pk, name, "process", process)
+ self.to_context(**{name: process})
+ except Exception as e:
+ self.set_task_state_info(pk, name, "state", "FAILED")
+ # set child state to FAILED
+ self.set_tasks_state(
+ pk,
+ self.ctx._workgraph[pk]["_connectivity"]["child_node"][name],
+ "SKIPPED",
+ )
+ self.report(f"Error in task {name}: {e}", pk)
+ self.continue_workgraph(pk)
+ elif task["metadata"]["node_type"].upper() in ["GRAPH_BUILDER"]:
+ from aiida_workgraph.utils.control import create_workgraph_action
+
+ wg = self.run_executor(executor, [], kwargs, var_args, var_kwargs)
+ wg.name = name
+ wg.group_outputs = self.ctx._workgraph[pk]["_tasks"][name]["metadata"][
+ "group_outputs"
+ ]
+ metadata = {"call_link_label": name}
+ try:
+ wg.save(metadata=metadata, parent_pid=pk)
+ process = wg.process
+ create_workgraph_action(process.pk)
+ process.workgraph_pk = pk
+ self.set_task_state_info(pk, name, "process", process)
+ self.set_task_state_info(pk, name, "state", "RUNNING")
+ self.report(f"GraphBuilder {name} is created.", pk)
+ self.to_context(**{name: process})
+ except Exception as e:
+ print("Error in launching the workgraph: ", e)
+ elif task["metadata"]["node_type"].upper() in ["WORKGRAPH"]:
+ from .utils import prepare_for_workgraph_task
+
+ inputs, _ = prepare_for_workgraph_task(task, kwargs)
+ process = launch.submit(WorkGraphEngine, inputs=inputs, parent_pid=pk)
+ process.workgraph_pk = pk
+ self.set_task_state_info(pk, name, "process", process)
+ self.set_task_state_info(pk, name, "state", "RUNNING")
+ self.to_context(**{name: process})
+ elif task["metadata"]["node_type"].upper() in ["PYTHONJOB"]:
+ from aiida_workgraph.calculations.python import PythonJob
+ from .utils import prepare_for_python_task
+
+ inputs = prepare_for_python_task(task, kwargs, var_kwargs)
+ # since aiida 2.5.0, we can pass inputs directly to the submit, no need to use **inputs
+ if self.get_task_state_info(pk, name, "action").upper() == "PAUSE":
+ self.set_task_state_info(pk, name, "action", "")
+ self.report(f"Task {name} is created and paused.", pk)
+ process = create_and_pause_process(
+ self.runner,
+ PythonJob,
+ inputs,
+ state_msg="Paused through WorkGraph",
+ )
+ self.set_task_state_info(pk, name, "state", "CREATED")
+ process = process.node
+ else:
+ process = launch.submit(PythonJob, inputs=inputs, parent_pid=pk)
+ self.set_task_state_info(pk, name, "state", "RUNNING")
+ process.label = name
+ process.workgraph_pk = pk
+ self.set_task_state_info(pk, name, "process", process)
+ self.to_context(**{name: process})
+ elif task["metadata"]["node_type"].upper() in ["SHELLJOB"]:
+ from aiida_shell.calculations.shell import ShellJob
+ from .utils import prepare_for_shell_task
+
+ inputs = prepare_for_shell_task(task, kwargs)
+ if self.get_task_state_info(pk, name, "action").upper() == "PAUSE":
+ self.set_task_state_info(pk, name, "action", "")
+ self.report(f"Task {name} is created and paused.", pk)
+ process = create_and_pause_process(
+ self.runner,
+ ShellJob,
+ inputs,
+ state_msg="Paused through WorkGraph",
+ )
+ self.set_task_state_info(pk, name, "state", "CREATED")
+ process = process.node
+ else:
+ process = launch.submit(ShellJob, inputs=inputs, parent_pid=pk)
+ self.set_task_state_info(pk, name, "state", "RUNNING")
+ process.label = name
+ process.workgraph_pk = pk
+ self.set_task_state_info(pk, name, "process", process)
+ self.to_context(**{name: process})
+ elif task["metadata"]["node_type"].upper() in ["WHILE"]:
+ # check the conditions of the while task
+ should_run = self.should_run_while_task(name)
+ if not should_run:
+ self.set_task_state_info(pk, name, "state", "FINISHED")
+ self.set_tasks_state(
+ pk,
+ self.ctx._workgraph[pk]["_tasks"][name]["children"],
+ "SKIPPED",
+ )
+ self.update_parent_task_state(pk, name)
+ self.report(
+ f"While Task {name}: Condition not fullilled, task finished. Skip all its children.",
+ pk,
+ )
+ else:
+ task["execution_count"] += 1
+ self.set_task_state_info(pk, name, "state", "RUNNING")
+ self.continue_workgraph(pk)
+ elif task["metadata"]["node_type"].upper() in ["IF"]:
+ should_run = self.should_run_if_task(name)
+ if should_run:
+ self.set_task_state_info(pk, name, "state", "RUNNING")
+ else:
+ self.set_tasks_state(pk, task["children"], "SKIPPED")
+ self.update_zone_task_state(pk, name)
+ self.continue_workgraph(pk)
+ elif task["metadata"]["node_type"].upper() in ["ZONE"]:
+ self.set_task_state_info(pk, name, "state", "RUNNING")
+ self.continue_workgraph(pk)
+ elif task["metadata"]["node_type"].upper() in ["FROM_CONTEXT"]:
+ # get the results from the context
+ results = {"result": getattr(self.ctx, kwargs["key"])}
+ task["results"] = results
+ self.set_task_state_info(pk, name, "state", "FINISHED")
+ self.update_parent_task_state(pk, name)
+ self.continue_workgraph(pk)
+ elif task["metadata"]["node_type"].upper() in ["TO_CONTEXT"]:
+ # get the results from the context
+ setattr(self.ctx, kwargs["key"], kwargs["value"])
+ self.set_task_state_info(pk, name, "state", "FINISHED")
+ self.update_parent_task_state(pk, name)
+ self.continue_workgraph(pk)
+ elif task["metadata"]["node_type"].upper() in ["AWAITABLE"]:
+ for key in self.ctx._workgraph[pk]["_tasks"][name]["metadata"]["args"]:
+ kwargs.pop(key, None)
+ awaitable_target = asyncio.ensure_future(
+ self.run_executor(executor, args, kwargs, var_args, var_kwargs),
+ loop=self.loop,
+ )
+ awaitable = self.construct_awaitable_function(name, awaitable_target)
+ self.set_task_state_info(pk, name, "state", "RUNNING")
+ self.to_context(**{name: awaitable})
+ elif task["metadata"]["node_type"].upper() in ["MONITOR"]:
+
+ for key in self.ctx._workgraph[pk]["_tasks"][name]["metadata"]["args"]:
+ kwargs.pop(key, None)
+ # add function and interval to the args
+ args = [executor, kwargs.pop("interval"), kwargs.pop("timeout"), *args]
+ awaitable_target = asyncio.ensure_future(
+ self.run_executor(monitor, args, kwargs, var_args, var_kwargs),
+ loop=self.loop,
+ )
+ awaitable = self.construct_awaitable_function(name, awaitable_target)
+ self.set_task_state_info(pk, name, "state", "RUNNING")
+ # save the awaitable to the temp, so that we can kill it if needed
+ self._temp["awaitables"][name] = awaitable_target
+ self.to_context(**{name: awaitable})
+ elif task["metadata"]["node_type"].upper() in ["NORMAL"]:
+ # Normal task is created by decoratoring a function with @task()
+ if "context" in task["metadata"]["kwargs"]:
+ self.ctx.task_name = name
+ kwargs.update({"context": self.ctx})
+ for key in self.ctx._workgraph[pk]["_tasks"][name]["metadata"]["args"]:
+ kwargs.pop(key, None)
+ try:
+ results = self.run_executor(
+ executor, args, kwargs, var_args, var_kwargs
+ )
+ self.set_normal_task_results(pk, name, results)
+ except Exception as e:
+ self.logger.error(f"Error in task {name}: {e}")
+ self.set_task_state_info(pk, name, "state", "FAILED")
+ # set child tasks state to SKIPPED
+ self.set_tasks_state(
+ pk,
+ self.ctx._workgraph[pk]["_connectivity"]["child_node"][name],
+ "SKIPPED",
+ )
+ self.report(f"Task: {name} failed.", pk)
+ self.run_error_handlers(pk, name)
+ if continue_workgraph:
+ self.continue_workgraph(pk)
+ else:
+ # self.report("Unknow task type {}".format(task["metadata"]["node_type"]))
+ return self.exit_codes.UNKNOWN_TASK_TYPE
+
+ def construct_awaitable_function(
+ self, name: str, awaitable_target: Awaitable
+ ) -> None:
+ """Construct the awaitable function."""
+ awaitable = Awaitable(
+ **{
+ "pk": name,
+ "action": AwaitableAction.ASSIGN,
+ "target": "asyncio.tasks.Task",
+ "outputs": False,
+ }
+ )
+ awaitable_target.key = name
+ awaitable_target.pk = name
+ awaitable_target.action = AwaitableAction.ASSIGN
+ awaitable_target.add_done_callback(self._on_awaitable_finished)
+ return awaitable
+
+ def get_inputs(
+ self, pk: int, name: str
+ ) -> t.Tuple[
+ t.List[t.Any],
+ t.Dict[str, t.Any],
+ t.Optional[t.List[t.Any]],
+ t.Optional[t.Dict[str, t.Any]],
+ t.Dict[str, t.Any],
+ ]:
+ """Get input based on the links."""
+
+ args = []
+ args_dict = {}
+ kwargs = {}
+ var_args = None
+ var_kwargs = None
+ task = self.ctx._workgraph[pk]["_tasks"][name]
+ properties = task.get("properties", {})
+ inputs = {}
+ for input in task["inputs"]:
+ # print(f"input: {input['name']}")
+ if len(input["links"]) == 0:
+ inputs[input["name"]] = self.update_context_variable(
+ properties[input["name"]]["value"]
+ )
+ elif len(input["links"]) == 1:
+ link = input["links"][0]
+ if (
+ self.ctx._workgraph[pk]["_tasks"][link["from_node"]]["results"]
+ is None
+ ):
+ inputs[input["name"]] = None
+ else:
+ # handle the special socket _wait, _outputs
+ if link["from_socket"] == "_wait":
+ continue
+ elif link["from_socket"] == "_outputs":
+ inputs[input["name"]] = self.ctx._workgraph[pk]["_tasks"][
+ link["from_node"]
+ ]["results"]
+ else:
+ inputs[input["name"]] = get_nested_dict(
+ self.ctx._workgraph[pk]["_tasks"][link["from_node"]][
+ "results"
+ ],
+ link["from_socket"],
+ )
+ # handle the case of multiple outputs
+ elif len(input["links"]) > 1:
+ value = {}
+ for link in input["links"]:
+ name = f'{link["from_node"]}_{link["from_socket"]}'
+ # handle the special socket _wait, _outputs
+ if link["from_socket"] == "_wait":
+ continue
+ if (
+ self.ctx._workgraph[pk]["_tasks"][link["from_node"]]["results"]
+ is None
+ ):
+ value[name] = None
+ else:
+ value[name] = self.ctx._workgraph[pk]["_tasks"][
+ link["from_node"]
+ ]["results"][link["from_socket"]]
+ inputs[input["name"]] = value
+ for name in task["metadata"].get("args", []):
+ if name in inputs:
+ args.append(inputs[name])
+ args_dict[name] = inputs[name]
+ else:
+ value = self.update_context_variable(properties[name]["value"])
+ args.append(value)
+ args_dict[name] = value
+ for name in task["metadata"].get("kwargs", []):
+ if name in inputs:
+ kwargs[name] = inputs[name]
+ else:
+ value = self.update_context_variable(properties[name]["value"])
+ kwargs[name] = value
+ if task["metadata"]["var_args"] is not None:
+ name = task["metadata"]["var_args"]
+ if name in inputs:
+ var_args = inputs[name]
+ else:
+ value = self.update_context_variable(properties[name]["value"])
+ var_args = value
+ if task["metadata"]["var_kwargs"] is not None:
+ name = task["metadata"]["var_kwargs"]
+ if name in inputs:
+ var_kwargs = inputs[name]
+ else:
+ value = self.update_context_variable(properties[name]["value"])
+ var_kwargs = value
+ return args, kwargs, var_args, var_kwargs, args_dict
+
+ def update_context_variable(self, value: t.Any) -> t.Any:
+ # replace context variables
+
+ """Get value from context."""
+ if isinstance(value, dict):
+ for key, sub_value in value.items():
+ value[key] = self.update_context_variable(sub_value)
+ elif (
+ isinstance(value, str)
+ and value.strip().startswith("{{")
+ and value.strip().endswith("}}")
+ ):
+ name = value[2:-2].strip()
+ return get_nested_dict(self.ctx, name)
+ return value
+
+ def task_set_context(self, pk, name: str) -> None:
+ """Export task result to context."""
+ from aiida_workgraph.utils import update_nested_dict
+
+ items = self.ctx._workgraph[pk]["_tasks"][name]["context_mapping"]
+ for key, value in items.items():
+ result = self.ctx._workgraph[pk]["_tasks"][name]["results"][key]
+ update_nested_dict(self.ctx, value, result)
+
+ def is_task_ready_to_run(self, pk, name: str) -> t.Tuple[bool, t.Optional[str]]:
+ """Check if the task ready to run.
+ For normal task and a zone task, we need to check its input tasks in the connectivity["zone"].
+ For task inside a zone, we need to check if the zone (parent task) is ready.
+ """
+ parent_task = self.ctx._workgraph[pk]["_tasks"][name]["parent_task"]
+ # input_tasks, parent_task, conditions
+ parent_states = [True, True]
+ # if the task belongs to a parent zone
+ if parent_task[0]:
+ state = self.get_task_state_info(pk, parent_task[0], "state")
+ if state not in ["RUNNING"]:
+ parent_states[1] = False
+ # check the input tasks of the zone
+ # check if the zone input tasks are ready
+ for child_task_name in self.ctx._workgraph[pk]["_connectivity"]["zone"][name][
+ "input_tasks"
+ ]:
+ if self.get_task_state_info(pk, child_task_name, "state") not in [
+ "FINISHED",
+ "SKIPPED",
+ "FAILED",
+ ]:
+ parent_states[0] = False
+ break
+
+ return all(parent_states), parent_states
+
+ def reset_workgraph(self, pk) -> None:
+ self.ctx._workgraph[pk]["_execution_count"] += 1
+ self.set_tasks_state(pk, self.ctx._workgraph[pk]["_tasks"].keys(), "PLANNED")
+ self.ctx._workgraph[pk]["_executed_tasks"] = []
+
+ def set_tasks_state(
+ self, pk: int, tasks: t.Union[t.List[str], t.Sequence[str]], value: str
+ ) -> None:
+ """Set tasks state"""
+ for name in tasks:
+ self.set_task_state_info(pk, name, "state", value)
+ if "children" in self.ctx._workgraph[pk]["_tasks"][name]:
+ self.set_tasks_state(
+ pk, self.ctx._workgraph[pk]["_tasks"][name]["children"], value
+ )
+
+ def run_executor(
+ self,
+ executor: t.Callable,
+ args: t.List[t.Any],
+ kwargs: t.Dict[str, t.Any],
+ var_args: t.Optional[t.List[t.Any]],
+ var_kwargs: t.Optional[t.Dict[str, t.Any]],
+ ) -> t.Any:
+ if var_kwargs is None:
+ return executor(*args, **kwargs)
+ else:
+ return executor(*args, **kwargs, **var_kwargs)
+
+ def save_results_to_extras(self, name: str) -> None:
+ """Save the results to the base.extras.
+ For the outputs of a Normal task, they are not saved to the database like the calcjob or workchain.
+ One temporary solution is to save the results to the base.extras. In order to do this, we need to
+ serialize the results
+ """
+ from aiida_workgraph.utils import get_executor
+
+ results = self.ctx._workgraph[pk]["_tasks"][name]["results"]
+ if results is None:
+ return
+ datas = {}
+ for key, value in results.items():
+ # find outptus sockets with the name as key
+ output = [
+ output
+ for output in self.ctx._workgraph[pk]["_tasks"][name]["outputs"]
+ if output["name"] == key
+ ]
+ if len(output) == 0:
+ continue
+ output = output[0]
+ Executor, _ = get_executor(output["serialize"])
+ datas[key] = Executor(value)
+ self.node.set_extra(f"nodes__results__{name}", datas)
+
+ def message_receive(
+ self, _comm: kiwipy.Communicator, msg: t.Dict[str, t.Any]
+ ) -> t.Any:
+ """
+ Coroutine called when the process receives a message from the communicator
+
+ :param _comm: the communicator that sent the message
+ :param msg: the message
+ :return: the outcome of processing the message, the return value will be sent back as a response to the sender
+ """
+ self.logger.debug(
+ "Process<%s>: received RPC message with communicator '%s': %r",
+ self.pid,
+ _comm,
+ msg,
+ )
+
+ intent = msg[process_comms.INTENT_KEY]
+
+ if intent == process_comms.Intent.PLAY:
+ return self._schedule_rpc(self.play)
+ if intent == process_comms.Intent.PAUSE:
+ return self._schedule_rpc(
+ self.pause, msg=msg.get(process_comms.MESSAGE_KEY, None)
+ )
+ if intent == process_comms.Intent.KILL:
+ return self._schedule_rpc(
+ self.kill, msg=msg.get(process_comms.MESSAGE_KEY, None)
+ )
+ if intent == process_comms.Intent.STATUS:
+ status_info: t.Dict[str, t.Any] = {}
+ self.get_status_info(status_info)
+ return status_info
+ if intent == "custom":
+ return self._schedule_rpc(self.apply_action, msg=msg)
+
+ # Didn't match any known intents
+ raise RuntimeError("Unknown intent")
+
+ def call_on_receive_workgraph_message(self, _comm, msg):
+ """Call on receive workgraph message."""
+ # self.report(f"Received workgraph message: {msg}")
+ pk = msg["args"]["pid"]
+ # To avoid "DbNode is not persistent", we need to schedule the call
+ self._schedule_rpc(self.launch_workgraph, pk=pk)
+ return True
+
+ def add_workgraph_subsriber(self) -> None:
+ """Add workgraph subscriber."""
+ queue_name = "workgraph_queue"
+ self.report(f"Add workgraph subscriber on queue: {queue_name}")
+ comm = self.runner.communicator._communicator
+ queue = comm.task_queue(queue_name, prefetch_count=1000)
+ queue.add_task_subscriber(self.call_on_receive_workgraph_message)
+
+ def finalize_workgraph(self, pk: int) -> t.Optional[ExitCode]:
+ """"""
+ # expose outputs of the workgraph
+ group_outputs = {}
+ for output in self.ctx._workgraph[pk]["_metadata"]["group_outputs"]:
+ names = output["from"].split(".", 1)
+ if names[0] == "context":
+ if len(names) == 1:
+ raise ValueError("The output name should be context.key")
+ update_nested_dict(
+ group_outputs,
+ output["name"],
+ get_nested_dict(self.ctx, names[1]),
+ )
+ else:
+ # expose the whole outputs of the tasks
+ if len(names) == 1:
+ update_nested_dict(
+ group_outputs,
+ output["name"],
+ self.ctx._workgraph[pk]["_tasks"][names[0]]["results"],
+ )
+ else:
+ # expose one output of the task
+ # note, the output may not exist
+ if (
+ names[1]
+ in self.ctx._workgraph[pk]["_tasks"][names[0]]["results"]
+ ):
+ update_nested_dict(
+ group_outputs,
+ output["name"],
+ self.ctx._workgraph[pk]["_tasks"][names[0]]["results"][
+ names[1]
+ ],
+ )
+ # output the new data
+ if self.ctx._workgraph[pk]["_new_data"]:
+ group_outputs["new_data"] = self.ctx._workgraph[pk]["_new_data"]
+ group_outputs["execution_count"] = orm.Int(
+ self.ctx._workgraph[pk]["_execution_count"]
+ ).store()
+ self.update_workgraph_output(pk, group_outputs)
+ self.report("Finalize workgraph.", pk)
+ self.report(f"Finalize workgraph {pk}.")
+ for name, task in self.ctx._workgraph[pk]["_tasks"].items():
+ if self.get_task_state_info(pk, name, "state") == "FAILED":
+ self.report(f"Task {name} failed.", pk)
+ self.ctx._workgraph[pk]["_node"].set_exit_status(302)
+ self.ctx._workgraph[pk]["_node"].set_exit_message(
+ "Some of the tasks failed."
+ )
+ self.broadcast_workgraph_state(pk, "excepted")
+ self.ctx._workgraph[pk]["_node"].set_process_state(Finished.LABEL)
+ self.ctx._workgraph[pk]["_node"].seal()
+ return
+ # send a broadcast message to announce the workgraph is finished
+ self.broadcast_workgraph_state(pk, "finished")
+ self.ctx._workgraph[pk]["_node"].set_process_state(Finished.LABEL)
+ self.ctx._workgraph[pk]["_node"].seal()
+
+ def update_workgraph_output(self, pk, outputs: dict) -> None:
+ """Update the workgraph output."""
+ from aiida.common.links import LinkType
+
+ node = self.ctx._workgraph[pk]["_node"]
+
+ def add_link(node, outputs):
+ for link_label, output in outputs.items():
+ if isinstance(output, dict):
+ add_link(node, output)
+ if isinstance(node, orm.CalculationNode):
+ output.base.links.add_incoming(node, LinkType.CREATE, link_label)
+ elif isinstance(node, orm.WorkflowNode):
+ output.base.links.add_incoming(node, LinkType.RETURN, link_label)
+
+ add_link(node, outputs)
+
+ def broadcast_workgraph_state(self, pk: int, state: str) -> None:
+ """Workgraph on entered."""
+ from aio_pika.exceptions import ConnectionClosed
+
+ from_label = None
+ subject = f"state_changed.{from_label}.{state}"
+ try:
+ self._communicator.broadcast_send(body=None, sender=pk, subject=subject)
+ except ConnectionClosed:
+ message = "Process<%s>: no connection available to broadcast state change from %s to %s"
+ self.logger.warning(message, pk, from_label, state)
+ except kiwipy.TimeoutError:
+ message = (
+ "Process<%s>: sending broadcast of state change from %s to %s timed out"
+ )
+ self.logger.warning(message, pk, from_label, state)
diff --git a/aiida_workgraph/engine/utils.py b/aiida_workgraph/engine/utils.py
index 82332a78..e35f4fc5 100644
--- a/aiida_workgraph/engine/utils.py
+++ b/aiida_workgraph/engine/utils.py
@@ -1,6 +1,15 @@
from aiida_workgraph.orm.serializer import serialize_to_aiida_nodes
from aiida import orm
from aiida.common.extendeddicts import AttributeDict
+from aiida.engine.utils import is_process_function
+from aiida.engine.processes.builder import ProcessBuilder
+from aiida.engine.processes.process import Process
+
+import inspect
+from typing import (
+ Type,
+ Union,
+)
def prepare_for_workgraph_task(task: dict, kwargs: dict) -> tuple:
@@ -139,3 +148,45 @@ def prepare_for_shell_task(task: dict, kwargs: dict) -> dict:
"metadata": metadata or {},
}
return inputs
+
+
+def instantiate_process(
+ runner: "Runner",
+ process: Union["Process", Type["Process"], "ProcessBuilder"],
+ parent_pid=None,
+ **inputs,
+) -> "Process":
+ """Return an instance of the process with the given inputs. The function can deal with various types
+ of the `process`:
+
+ * Process instance: will simply return the instance
+ * ProcessBuilder instance: will instantiate the Process from the class and inputs defined within it
+ * Process class: will instantiate with the specified inputs
+
+ If anything else is passed, a ValueError will be raised
+
+ :param process: Process instance or class, CalcJobNode class or ProcessBuilder instance
+ :param inputs: the inputs for the process to be instantiated with
+ """
+
+ if isinstance(process, Process):
+ assert not inputs
+ assert runner is process.runner
+ return process
+
+ if isinstance(process, ProcessBuilder):
+ builder = process
+ process_class = builder.process_class
+ inputs.update(**builder._inputs(prune=True))
+ elif is_process_function(process):
+ process_class = process.process_class # type: ignore[attr-defined]
+ elif inspect.isclass(process) and issubclass(process, Process):
+ process_class = process
+ else:
+ raise ValueError(
+ f"invalid process {type(process)}, needs to be Process or ProcessBuilder"
+ )
+
+ process = process_class(runner=runner, inputs=inputs, parent_pid=parent_pid)
+
+ return process
diff --git a/aiida_workgraph/engine/workgraph.py b/aiida_workgraph/engine/workgraph.py
index f2eacd23..46e803ba 100644
--- a/aiida_workgraph/engine/workgraph.py
+++ b/aiida_workgraph/engine/workgraph.py
@@ -60,6 +60,7 @@ def __init__(
logger: logging.Logger | None = None,
runner: "Runner" | None = None,
enable_persistence: bool = True,
+ **kwargs: t.Any,
) -> None:
"""Construct a WorkGraph instance.
@@ -70,7 +71,9 @@ def __init__(
"""
- super().__init__(inputs, logger, runner, enable_persistence=enable_persistence)
+ super().__init__(
+ inputs, logger, runner, enable_persistence=enable_persistence, **kwargs
+ )
self._awaitables: list[Awaitable] = []
self._context = AttributeDict()
diff --git a/aiida_workgraph/tasks/test.py b/aiida_workgraph/tasks/test.py
index e0c2bfa4..78ab58f7 100644
--- a/aiida_workgraph/tasks/test.py
+++ b/aiida_workgraph/tasks/test.py
@@ -1,5 +1,6 @@
from typing import Dict
from aiida_workgraph.task import Task
+from aiida_workgraph.decorator import task
class TestAdd(Task):
@@ -114,3 +115,19 @@ def get_executor(self) -> Dict[str, str]:
"name": "core.arithmetic.multiply_add",
"type": "WorkflowFactory",
}
+
+
+@task.graph_builder()
+def create_workgraph_recrusively(n):
+ from aiida_workgraph import WorkGraph
+ from aiida_workgraph.tasks.test import create_workgraph_recrusively
+ from aiida.calculations.arithmetic.add import ArithmeticAddCalculation
+ from aiida import orm
+
+ code = orm.load_code("add@localhost")
+ wg = WorkGraph(f"n-{n}")
+ if n > 0:
+ wg.add_task(create_workgraph_recrusively, name=f"n_{n}", n=n - 1)
+ else:
+ wg.add_task(ArithmeticAddCalculation, code=code, x=1, y=2)
+ return wg
diff --git a/aiida_workgraph/utils/control.py b/aiida_workgraph/utils/control.py
index 376f8fc3..4fa50be6 100644
--- a/aiida_workgraph/utils/control.py
+++ b/aiida_workgraph/utils/control.py
@@ -1,6 +1,7 @@
from aiida.manage import get_manager
from aiida import orm
from aiida.engine.processes import control
+from aiida_workgraph.engine.override import ControllerWithQueueName
def create_task_action(
@@ -17,6 +18,30 @@ def create_task_action(
controller._communicator.rpc_send(pk, message)
+def create_scheduler_action(
+ pk: int,
+):
+ """Send workgraph task to scheduler."""
+
+ manager = get_manager()
+ controller = ControllerWithQueueName(
+ queue_name="scheduler_queue", communicator=manager.get_communicator()
+ )
+ controller.continue_process(pk, nowait=False)
+
+
+def create_workgraph_action(
+ pk: int,
+):
+ """Send workgraph task to scheduler."""
+
+ manager = get_manager()
+ controller = ControllerWithQueueName(
+ queue_name="workgraph_queue", communicator=manager.get_communicator()
+ )
+ controller.continue_process(pk, nowait=False)
+
+
def get_task_state_info(node, name: str, key: str) -> str:
"""Get task state info from base.extras."""
from aiida.orm.utils.serialize import deserialize_unsafe
diff --git a/aiida_workgraph/web/backend/app/api.py b/aiida_workgraph/web/backend/app/api.py
index d05b527b..4bf69c91 100644
--- a/aiida_workgraph/web/backend/app/api.py
+++ b/aiida_workgraph/web/backend/app/api.py
@@ -2,6 +2,7 @@
from fastapi.middleware.cors import CORSMiddleware
from aiida.manage import manager
from aiida_workgraph.web.backend.app.daemon import router as daemon_router
+from aiida_workgraph.web.backend.app.scheduler import router as scheduler_router
from aiida_workgraph.web.backend.app.workgraph import router as workgraph_router
from aiida_workgraph.web.backend.app.datanode import router as datanode_router
from fastapi.staticfiles import StaticFiles
@@ -47,6 +48,7 @@ async def read_root() -> dict:
app.include_router(workgraph_router)
app.include_router(datanode_router)
app.include_router(daemon_router)
+app.include_router(scheduler_router)
@app.get("/debug")
diff --git a/aiida_workgraph/web/backend/app/daemon.py b/aiida_workgraph/web/backend/app/daemon.py
index af069d7c..caa22cf4 100644
--- a/aiida_workgraph/web/backend/app/daemon.py
+++ b/aiida_workgraph/web/backend/app/daemon.py
@@ -22,7 +22,7 @@ class DaemonStatusModel(BaseModel):
)
-@router.get("/api/daemon/status", response_model=DaemonStatusModel)
+@router.get("/api/daemon/task/status", response_model=DaemonStatusModel)
@with_dbenv()
async def get_daemon_status() -> DaemonStatusModel:
"""Return the daemon status."""
@@ -36,7 +36,7 @@ async def get_daemon_status() -> DaemonStatusModel:
return DaemonStatusModel(running=True, num_workers=response["numprocesses"])
-@router.get("/api/daemon/worker")
+@router.get("/api/daemon/task/worker")
@with_dbenv()
async def get_daemon_worker():
"""Return the daemon status."""
@@ -50,7 +50,7 @@ async def get_daemon_worker():
return response["info"]
-@router.post("/api/daemon/start", response_model=DaemonStatusModel)
+@router.post("/api/daemon/task/start", response_model=DaemonStatusModel)
@with_dbenv()
async def get_daemon_start() -> DaemonStatusModel:
"""Start the daemon."""
@@ -69,7 +69,7 @@ async def get_daemon_start() -> DaemonStatusModel:
return DaemonStatusModel(running=True, num_workers=response["numprocesses"])
-@router.post("/api/daemon/stop", response_model=DaemonStatusModel)
+@router.post("/api/daemon/task/stop", response_model=DaemonStatusModel)
@with_dbenv()
async def get_daemon_stop() -> DaemonStatusModel:
"""Stop the daemon."""
@@ -86,7 +86,7 @@ async def get_daemon_stop() -> DaemonStatusModel:
return DaemonStatusModel(running=False, num_workers=None)
-@router.post("/api/daemon/increase", response_model=DaemonStatusModel)
+@router.post("/api/daemon/task/increase", response_model=DaemonStatusModel)
@with_dbenv()
async def increase_daemon_worker() -> DaemonStatusModel:
"""increase the daemon worker."""
@@ -103,7 +103,7 @@ async def increase_daemon_worker() -> DaemonStatusModel:
return DaemonStatusModel(running=False, num_workers=None)
-@router.post("/api/daemon/decrease", response_model=DaemonStatusModel)
+@router.post("/api/daemon/task/decrease", response_model=DaemonStatusModel)
@with_dbenv()
async def decrease_daemon_worker() -> DaemonStatusModel:
"""decrease the daemon worker."""
diff --git a/aiida_workgraph/web/backend/app/scheduler.py b/aiida_workgraph/web/backend/app/scheduler.py
new file mode 100644
index 00000000..c4110402
--- /dev/null
+++ b/aiida_workgraph/web/backend/app/scheduler.py
@@ -0,0 +1,127 @@
+# -*- coding: utf-8 -*-
+"""Declaration of FastAPI router for daemon endpoints."""
+from __future__ import annotations
+
+import typing as t
+
+from aiida.cmdline.utils.decorators import with_dbenv
+from aiida.engine.daemon.client import DaemonException
+from aiida_workgraph.engine.scheduler.client import get_scheduler_client
+from fastapi import APIRouter, HTTPException
+from pydantic import BaseModel, Field
+from aiida_workgraph.engine.scheduler.client import start_scheduler_process
+
+
+router = APIRouter()
+
+
+class DaemonStatusModel(BaseModel):
+ """Response model for daemon status."""
+
+ running: bool = Field(description="Whether the daemon is running or not.")
+ num_workers: t.Optional[int] = Field(
+ description="The number of workers if the daemon is running."
+ )
+
+
+@router.get("/api/daemon/scheduler/status", response_model=DaemonStatusModel)
+@with_dbenv()
+async def get_daemon_status() -> DaemonStatusModel:
+ """Return the daemon status."""
+ client = get_scheduler_client()
+
+ if not client.is_daemon_running:
+ return DaemonStatusModel(running=False, num_workers=None)
+
+ response = client.get_numprocesses()
+
+ return DaemonStatusModel(running=True, num_workers=response["numprocesses"])
+
+
+@router.get("/api/daemon/scheduler/worker")
+@with_dbenv()
+async def get_daemon_worker():
+ """Return the daemon status."""
+ client = get_scheduler_client()
+
+ if not client.is_daemon_running:
+ return {}
+
+ response = client.get_worker_info()
+
+ return response["info"]
+
+
+@router.post("/api/daemon/scheduler/start", response_model=DaemonStatusModel)
+@with_dbenv()
+async def get_daemon_start() -> DaemonStatusModel:
+ """Start the daemon."""
+ client = get_scheduler_client()
+
+ if client.is_daemon_running:
+ raise HTTPException(status_code=400, detail="The daemon is already running.")
+
+ try:
+ client.start_daemon()
+ except DaemonException as exception:
+ raise HTTPException(status_code=500, detail=str(exception)) from exception
+
+ response = client.get_numprocesses()
+ start_scheduler_process()
+
+ return DaemonStatusModel(running=True, num_workers=response["numprocesses"])
+
+
+@router.post("/api/daemon/scheduler/stop", response_model=DaemonStatusModel)
+@with_dbenv()
+async def get_daemon_stop() -> DaemonStatusModel:
+ """Stop the daemon."""
+ client = get_scheduler_client()
+
+ if not client.is_daemon_running:
+ raise HTTPException(status_code=400, detail="The daemon is not running.")
+
+ try:
+ client.stop_daemon()
+ except DaemonException as exception:
+ raise HTTPException(status_code=500, detail=str(exception)) from exception
+
+ return DaemonStatusModel(running=False, num_workers=None)
+
+
+@router.post("/api/daemon/scheduler/increase", response_model=DaemonStatusModel)
+@with_dbenv()
+async def increase_daemon_worker() -> DaemonStatusModel:
+ """increase the daemon worker."""
+ client = get_scheduler_client()
+
+ if not client.is_daemon_running:
+ raise HTTPException(status_code=400, detail="The daemon is not running.")
+
+ try:
+ client.increase_workers(1)
+ except DaemonException as exception:
+ raise HTTPException(status_code=500, detail=str(exception)) from exception
+
+ response = client.get_numprocesses()
+ print(response)
+ start_scheduler_process(response["numprocesses"])
+
+ return DaemonStatusModel(running=False, num_workers=None)
+
+
+@router.post("/api/daemon/scheduler/decrease", response_model=DaemonStatusModel)
+@with_dbenv()
+async def decrease_daemon_worker() -> DaemonStatusModel:
+ """decrease the daemon worker."""
+ client = get_scheduler_client()
+
+ if not client.is_daemon_running:
+ raise HTTPException(status_code=400, detail="The daemon is not running.")
+
+ try:
+ client.decrease_workers(1)
+ except DaemonException as exception:
+ raise HTTPException(status_code=500, detail=str(exception)) from exception
+
+ return DaemonStatusModel(running=False, num_workers=None)
diff --git a/aiida_workgraph/web/frontend/src/components/Settings.js b/aiida_workgraph/web/frontend/src/components/Settings.js
index 23310618..50a89fee 100644
--- a/aiida_workgraph/web/frontend/src/components/Settings.js
+++ b/aiida_workgraph/web/frontend/src/components/Settings.js
@@ -3,79 +3,131 @@ import { ToastContainer, toast } from 'react-toastify';
import 'react-toastify/dist/ReactToastify.css';
function Settings() {
- const [workers, setWorkers] = useState([]);
+ const [taskWorkers, setTaskWorkers] = useState([]);
+ const [schedulerWorkers, setSchedulerWorkers] = useState([]);
- const fetchWorkers = () => {
- fetch('http://localhost:8000/api/daemon/worker')
+ // Fetching task workers
+ const fetchTaskWorkers = () => {
+ fetch('http://localhost:8000/api/daemon/task/worker')
.then(response => response.json())
- .then(data => setWorkers(Object.values(data)))
- .catch(error => console.error('Failed to fetch workers:', error));
+ .then(data => setTaskWorkers(Object.values(data)))
+ .catch(error => console.error('Failed to fetch task workers:', error));
+ };
+
+ // Fetching scheduler workers
+ const fetchSchedulerWorkers = () => {
+ fetch('http://localhost:8000/api/daemon/scheduler/worker')
+ .then(response => response.json())
+ .then(data => setSchedulerWorkers(Object.values(data)))
+ .catch(error => console.error('Failed to fetch scheduler workers:', error));
};
useEffect(() => {
- fetchWorkers();
- const interval = setInterval(fetchWorkers, 1000); // Poll every 5 seconds
- return () => clearInterval(interval); // Clear interval on component unmount
+ fetchTaskWorkers();
+ fetchSchedulerWorkers();
+ const taskInterval = setInterval(fetchTaskWorkers, 1000);
+ const schedulerInterval = setInterval(fetchSchedulerWorkers, 1000);
+ return () => {
+ clearInterval(taskInterval);
+ clearInterval(schedulerInterval);
+ }; // Clear intervals on component unmount
}, []);
- const handleDaemonControl = (action) => {
- fetch(`http://localhost:8000/api/daemon/${action}`, { method: 'POST' })
+ const handleDaemonControl = (daemonType, action) => {
+ fetch(`http://localhost:8000/api/daemon/${daemonType}/${action}`, { method: 'POST' })
.then(response => {
if (!response.ok) {
- throw new Error(`Daemon operation failed: ${response.statusText}`);
+ throw new Error(`${daemonType} daemon operation failed: ${response.statusText}`);
}
return response.json();
})
- .then(data => {
- toast.success(`Daemon ${action}ed successfully`);
- fetchWorkers();
+ .then(() => {
+ toast.success(`${daemonType} daemon ${action}ed successfully`);
+ if (daemonType === 'task') {
+ fetchTaskWorkers();
+ } else {
+ fetchSchedulerWorkers();
+ }
})
.catch(error => toast.error(error.message));
};
- const adjustWorkers = (action) => {
- fetch(`http://localhost:8000/api/daemon/${action}`, { method: 'POST' })
+ const adjustWorkers = (daemonType, action) => {
+ fetch(`http://localhost:8000/api/daemon/${daemonType}/${action}`, { method: 'POST' })
.then(response => {
if (!response.ok) {
- throw new Error(`Failed to ${action} workers: ${response.statusText}`);
+ throw new Error(`Failed to ${action} workers for ${daemonType}: ${response.statusText}`);
}
return response.json();
})
- .then(data => {
- toast.success(`Workers ${action}ed successfully`);
- fetchWorkers(); // Refetch workers after adjusting
+ .then(() => {
+ toast.success(`${daemonType} Workers ${action}ed successfully`);
+ if (daemonType === 'task') {
+ fetchTaskWorkers();
+ } else {
+ fetchSchedulerWorkers();
+ }
})
.catch(error => toast.error(error.message));
};
return (
-
Daemon Control
-
-
-
- PID
- Memory %
- CPU %
- Started
-
-
-
- {workers.map(worker => (
-
- {worker.pid}
- {worker.mem}
- {worker.cpu}
- {new Date(worker.started * 1000).toLocaleString()}
+
+
Task Daemon Control
+
handleDaemonControl('task', 'start')}>Start Task Daemon
+
handleDaemonControl('task', 'stop')}>Stop Task Daemon
+
adjustWorkers('task', 'increase')}>Increase Task Workers
+
adjustWorkers('task', 'decrease')}>Decrease Task Workers
+
+
+
+ PID
+ Memory %
+ CPU %
+ Started
+
+
+
+ {taskWorkers.map(worker => (
+
+ {worker.pid}
+ {worker.mem}
+ {worker.cpu}
+ {new Date(worker.started * 1000).toLocaleString()}
+
+ ))}
+
+
+
+
+
Scheduler Daemon Control
+
handleDaemonControl('scheduler', 'start')}>Start Scheduler Daemon
+
handleDaemonControl('scheduler', 'stop')}>Stop Scheduler Daemon
+
adjustWorkers('scheduler', 'increase')}>Increase Scheduler Workers
+
adjustWorkers('scheduler', 'decrease')}>Decrease Scheduler Workers
+
+
+
+ PID
+ Memory %
+ CPU %
+ Started
- ))}
-
-
-
handleDaemonControl('start')}>Start Daemon
-
handleDaemonControl('stop')}>Stop Daemon
-
adjustWorkers('increase')}>Increase Workers
-
adjustWorkers('decrease')}>Decrease Workers
+
+
+ {schedulerWorkers.map(worker => (
+
+ {worker.pid}
+ {worker.mem}
+ {worker.cpu}
+ {new Date(worker.started * 1000).toLocaleString()}
+
+ ))}
+
+
+
);
}
diff --git a/aiida_workgraph/workgraph.py b/aiida_workgraph/workgraph.py
index 339435a0..e99a34f6 100644
--- a/aiida_workgraph/workgraph.py
+++ b/aiida_workgraph/workgraph.py
@@ -115,6 +115,7 @@ def submit(
wait: bool = False,
timeout: int = 60,
metadata: Optional[Dict[str, Any]] = None,
+ to_scheduler: bool = False,
) -> aiida.orm.ProcessNode:
"""Submit the AiiDA workgraph process and optionally wait for it to finish.
Args:
@@ -134,27 +135,34 @@ def submit(
self.save(metadata=metadata)
if self.process.process_state.value.upper() not in ["CREATED"]:
raise ValueError(f"Process {self.process.pk} has already been submitted.")
- self.continue_process()
+ if to_scheduler:
+ self.continue_process_in_scheduler(to_scheduler)
+ else:
+ self.continue_process()
# as long as we submit the process, it is a new submission, we should set restart_process to None
self.restart_process = None
if wait:
self.wait(timeout=timeout)
return self.process
- def save(self, metadata: Optional[Dict[str, Any]] = None) -> None:
+ def save(
+ self, metadata: Optional[Dict[str, Any]] = None, parent_pid: int = None
+ ) -> None:
"""Save the udpated workgraph to the process
This is only used for a running workgraph.
Save the AiiDA workgraph process and update the process status.
"""
from aiida.manage import manager
- from aiida.engine.utils import instantiate_process
+ from aiida_workgraph.engine.utils import instantiate_process
from aiida_workgraph.engine.workgraph import WorkGraphEngine
inputs = self.prepare_inputs(metadata)
if self.process is None:
runner = manager.get_manager().get_runner()
# init a process node
- process_inited = instantiate_process(runner, WorkGraphEngine, **inputs)
+ process_inited = instantiate_process(
+ runner, WorkGraphEngine, parent_pid=parent_pid, **inputs
+ )
process_inited.runner.persister.save_checkpoint(process_inited)
self.process = process_inited.node
self.process_inited = process_inited
@@ -415,6 +423,24 @@ def continue_process(self):
process_controller = get_manager().get_process_controller()
process_controller.continue_process(self.pk)
+ def continue_process_in_scheduler(self, to_scheduler: Union[int, bool]) -> None:
+ """Ask the scheduler to pick up the process from the database and run it.
+ If to_scheduler is an integer, it will be used as the scheduler pk.
+ Otherwise, it will send the message to the queue, and the scheduler will pick it up.
+ """
+ from aiida_workgraph.utils.control import (
+ create_task_action,
+ create_workgraph_action,
+ )
+
+ try:
+ if isinstance(to_scheduler, int) and not isinstance(to_scheduler, bool):
+ create_task_action(to_scheduler, [self.pk], action="launch_workgraph")
+ else:
+ create_workgraph_action(self.pk)
+ except Exception as e:
+ print("""An unexpected error occurred:""", e)
+
def play(self):
import os
diff --git a/docs/source/howto/index.rst b/docs/source/howto/index.rst
index 274f74be..3991edfe 100644
--- a/docs/source/howto/index.rst
+++ b/docs/source/howto/index.rst
@@ -24,5 +24,6 @@ This section contains a collection of HowTos for various topics.
queue
cli
control
+ scheduler
transfer_workchain
workchain_call_workgraph
diff --git a/docs/source/howto/scheduler.ipynb b/docs/source/howto/scheduler.ipynb
new file mode 100644
index 00000000..e88f2f50
--- /dev/null
+++ b/docs/source/howto/scheduler.ipynb
@@ -0,0 +1,358 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Scheduler\n",
+ "\n",
+ "## Overview\n",
+ "\n",
+ "This documentation provides a guide on using the `aiida-workgraph` Scheduler to manage `WorkGraph` processes efficiently.\n",
+ "\n",
+ "### Background\n",
+ "\n",
+ "Traditional workflow processes, particularly in nested structures like `PwBandsWorkChain`, tend to create multiple Workflow processes in a waiting state, while only a few `CalcJob` processes run actively. This results in inefficient resource usage. The `WorkChain` structure makes it challenging to eliminate these waiting processes due to its encapsulated logic.\n",
+ "\n",
+ "In contrast, the `WorkGraph` offers a more clear task dependency and allow other process to run its tasks in a controllable way. In a scheduler, one only need create the `WorkGraph` process in the database, not run it via a daemon worker.\n",
+ "\n",
+ "### Process Comparison: `PwBands` Case\n",
+ "\n",
+ "- **Old Approach**: 300 Workflow processes (Bands, Relax, Base) + 100 CalcJob processes.\n",
+ "- **New Approach**: 1 Scheduler process + 100 CalcJob processes.\n",
+ "\n",
+ "This new approach significantly reduces the number of active processes and mitigates the risk of deadlocks.\n",
+ "\n",
+ "## Getting Started with the Scheduler\n",
+ "\n",
+ "### Starting the Scheduler\n",
+ "\n",
+ "To launch a scheduler daemon:\n",
+ "\n",
+ "```console\n",
+ "workgraph scheduler start\n",
+ "```\n",
+ "\n",
+ "### Monitoring the Scheduler\n",
+ "\n",
+ "To check the current status of the scheduler:\n",
+ "\n",
+ "```console\n",
+ "workgraph scheduler status\n",
+ "```\n",
+ "\n",
+ "### Stopping the Scheduler\n",
+ "\n",
+ "To stop the scheduler daemon:\n",
+ "\n",
+ "```console\n",
+ "workgraph scheduler stop\n",
+ "```\n",
+ "\n",
+ "## Submitting WorkGraphs to the Scheduler\n",
+ "\n",
+ "To submit a WorkGraph to the scheduler, set the `to_scheduler` flag to `True`:\n",
+ "\n",
+ "```python\n",
+ "wg.submit(to_scheduler=True)\n",
+ "```\n",
+ "\n",
+ "\n",
+ "### Using Multiple Schedulers\n",
+ "\n",
+ "For environments with a high volume of WorkGraphs, starting multiple schedulers can enhance throughput:\n",
+ "\n",
+ "```console\n",
+ "workgraph scheduler start 2\n",
+ "```\n",
+ "\n",
+ "WorkGraphs will be automatically distributed among available schedulers.\n",
+ "\n",
+ "#### Specifying a Scheduler\n",
+ "\n",
+ "To submit a WorkGraph to a specific scheduler using its primary key (`pk`):\n",
+ "\n",
+ "```python\n",
+ "wg.submit(to_scheduler=pk_scheduler)\n",
+ "```\n",
+ "\n",
+ "### Best Practices for Scheduler Usage\n",
+ "\n",
+ "While a single scheduler suffices for most use cases, scaling up the number of schedulers may be beneficial when significantly increasing the number of task workers (created by `verdi daemon start`). A general rule is to maintain a ratio of less than 5 workers per scheduler.\n",
+ "\n",
+ "## Example"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "WorkGraph process created, PK: 142617\n",
+ "State of WorkGraph : FINISHED\n",
+ "Result of add2 : 4\n"
+ ]
+ }
+ ],
+ "source": [
+ "from aiida_workgraph import WorkGraph, task\n",
+ "from aiida.calculations.arithmetic.add import ArithmeticAddCalculation\n",
+ "from aiida import load_profile, orm\n",
+ "\n",
+ "load_profile()\n",
+ "\n",
+ "@task.calcfunction()\n",
+ "def add(x, y):\n",
+ " return x + y\n",
+ "\n",
+ "code = orm.load_code(\"add@localhost\")\n",
+ "\n",
+ "wg = WorkGraph(f\"test_scheduler\")\n",
+ "add1 = wg.add_task(ArithmeticAddCalculation, name=\"add1\", x=1, y=2, code=code)\n",
+ "add2 = wg.add_task(ArithmeticAddCalculation, name=\"add2\", x=1, y=add1.outputs[\"sum\"], code=code)\n",
+ "wg.submit(to_scheduler=True,\n",
+ " wait=True)\n",
+ "print(\"State of WorkGraph : {}\".format(wg.state))\n",
+ "print('Result of add2 : {}'.format(wg.tasks[\"add2\"].node.outputs.sum.value))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Generate node graph from the AiiDA process,and we can see the provenance graph of the workgraph:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ " \n",
+ "\n",
+ "\n",
+ "N142617 \n",
+ " \n",
+ "WorkGraph<test_scheduler> (142617) \n",
+ "State: finished \n",
+ "Exit Code: 0 \n",
+ " \n",
+ "\n",
+ "\n",
+ "N142620 \n",
+ " \n",
+ "ArithmeticAddCalculation (142620) \n",
+ "State: finished \n",
+ "Exit Code: 0 \n",
+ " \n",
+ "\n",
+ "\n",
+ "N142617->N142620 \n",
+ " \n",
+ " \n",
+ "CALL_CALC \n",
+ "add1 \n",
+ " \n",
+ "\n",
+ "\n",
+ "N142625 \n",
+ " \n",
+ "ArithmeticAddCalculation (142625) \n",
+ "State: finished \n",
+ "Exit Code: 0 \n",
+ " \n",
+ "\n",
+ "\n",
+ "N142617->N142625 \n",
+ " \n",
+ " \n",
+ "CALL_CALC \n",
+ "add2 \n",
+ " \n",
+ "\n",
+ "\n",
+ "N37 \n",
+ "\n",
+ "InstalledCode (37) \n",
+ "add@localhost \n",
+ " \n",
+ "\n",
+ "\n",
+ "N37->N142617 \n",
+ " \n",
+ " \n",
+ "INPUT_WORK \n",
+ "wg__tasks__add1__properties__code__value \n",
+ " \n",
+ "\n",
+ "\n",
+ "N37->N142617 \n",
+ " \n",
+ " \n",
+ "INPUT_WORK \n",
+ "wg__tasks__add2__properties__code__value \n",
+ " \n",
+ "\n",
+ "\n",
+ "N142621 \n",
+ "\n",
+ "RemoteData (142621) \n",
+ "@localhost \n",
+ " \n",
+ "\n",
+ "\n",
+ "N142620->N142621 \n",
+ " \n",
+ " \n",
+ "CREATE \n",
+ "remote_folder \n",
+ " \n",
+ "\n",
+ "\n",
+ "N142622 \n",
+ "\n",
+ "FolderData (142622) \n",
+ " \n",
+ "\n",
+ "\n",
+ "N142620->N142622 \n",
+ " \n",
+ " \n",
+ "CREATE \n",
+ "retrieved \n",
+ " \n",
+ "\n",
+ "\n",
+ "N142623 \n",
+ "\n",
+ "Int (142623) \n",
+ " \n",
+ "\n",
+ "\n",
+ "N142620->N142623 \n",
+ " \n",
+ " \n",
+ "CREATE \n",
+ "sum \n",
+ " \n",
+ "\n",
+ "\n",
+ "N142623->N142625 \n",
+ " \n",
+ " \n",
+ "INPUT_CALC \n",
+ "y \n",
+ " \n",
+ "\n",
+ "\n",
+ "N142626 \n",
+ "\n",
+ "RemoteData (142626) \n",
+ "@localhost \n",
+ " \n",
+ "\n",
+ "\n",
+ "N142625->N142626 \n",
+ " \n",
+ " \n",
+ "CREATE \n",
+ "remote_folder \n",
+ " \n",
+ "\n",
+ "\n",
+ "N142627 \n",
+ "\n",
+ "FolderData (142627) \n",
+ " \n",
+ "\n",
+ "\n",
+ "N142625->N142627 \n",
+ " \n",
+ " \n",
+ "CREATE \n",
+ "retrieved \n",
+ " \n",
+ "\n",
+ "\n",
+ "N142628 \n",
+ "\n",
+ "Int (142628) \n",
+ " \n",
+ "\n",
+ "\n",
+ "N142625->N142628 \n",
+ " \n",
+ " \n",
+ "CREATE \n",
+ "sum \n",
+ " \n",
+ " \n",
+ " \n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from aiida_workgraph.utils import generate_node_graph\n",
+ "generate_node_graph(wg.pk)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n",
+ "## Checkpointing\n",
+ "\n",
+ "The Scheduler checkpoints its status to the database whenever a WorkGraph is updated, ensuring that the Scheduler can recover its state in case of a crash or restart. This feature is particularly useful for long-running WorkGraphs.\n",
+ "\n",
+ "## Conclusion\n",
+ "\n",
+ "The Scheduler offers a streamlined approach to managing complex workflows, significantly reducing active process counts and improving resource efficiency."
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "aiida",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/pyproject.toml b/pyproject.toml
index b67eee61..4f76b8aa 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -83,6 +83,7 @@ workgraph = "aiida_workgraph.cli.cmd_workgraph:workgraph"
[project.entry-points.'aiida.workflows']
"workgraph.engine" = "aiida_workgraph.engine.workgraph:WorkGraphEngine"
+"workgraph.scheduler" = "aiida_workgraph.engine.scheduler.scheduler:WorkGraphScheduler"
[project.entry-points."aiida.data"]
"workgraph.general" = "aiida_workgraph.orm.general_data:GeneralData"
diff --git a/tests/conftest.py b/tests/conftest.py
index 36b1fe61..efb2d1f0 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -25,6 +25,46 @@ def fixture_localhost(aiida_localhost):
return localhost
+@pytest.fixture(scope="session")
+def scheduler_client(aiida_profile):
+ """Return a daemon client for the configured test profile for the test session.
+
+ The daemon will be automatically stopped at the end of the test session.
+ """
+ from aiida_workgraph.engine.scheduler.client import SchedulerClient
+ from aiida.engine.daemon.client import (
+ DaemonNotRunningException,
+ DaemonTimeoutException,
+ )
+
+ scheduler_client = SchedulerClient(aiida_profile)
+
+ try:
+ yield scheduler_client
+ finally:
+ try:
+ scheduler_client.stop_daemon(wait=True)
+ except DaemonNotRunningException:
+ pass
+ # Give an additional grace period by manually waiting for the daemon to be stopped. In certain unit test
+ # scenarios, the built in wait time in ``scheduler_client.stop_daemon`` is not sufficient and even though the
+ # daemon is stopped, ``scheduler_client.is_daemon_running`` will return false for a little bit longer.
+ scheduler_client._await_condition(
+ lambda: not scheduler_client.is_daemon_running,
+ DaemonTimeoutException("The daemon failed to stop."),
+ )
+
+
+@pytest.fixture()
+def started_scheduler_client(scheduler_client):
+ """Ensure that the daemon is running for the test profile and return the associated client."""
+ if not scheduler_client.is_daemon_running:
+ scheduler_client.start_daemon()
+ assert scheduler_client.is_daemon_running
+
+ yield scheduler_client
+
+
@pytest.fixture
def add_code(fixture_localhost):
from aiida.orm import InstalledCode
diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py
new file mode 100644
index 00000000..c6308551
--- /dev/null
+++ b/tests/test_scheduler.py
@@ -0,0 +1,21 @@
+import pytest
+from typing import Callable
+from aiida_workgraph import WorkGraph
+from aiida.cmdline.utils.common import get_workchain_report
+from aiida_workgraph.engine.scheduler.client import get_scheduler
+from aiida import orm
+
+
+@pytest.mark.skip("Skip for now")
+@pytest.mark.usefixtures("started_daemon_client")
+def test_scheduler(decorated_add: Callable, started_scheduler_client) -> None:
+ """Test graph build."""
+ wg = WorkGraph("test_scheduler")
+ add1 = wg.add_task(decorated_add, x=2, y=3)
+ add2 = wg.add_task(decorated_add, "add2", x=3, y=add1.outputs["result"])
+ # use run to check if graph builder workgraph can be submit inside the engine
+ wg.submit(to_scheduler=True, wait=True)
+ pk = get_scheduler()
+ report = get_workchain_report(orm.load(pk), "REPORT")
+ print("report: ", report)
+ assert add2.outputs["result"].value == 8