From b298b912f3e7dbfbfe18a9394db3b95b1006ac4e Mon Sep 17 00:00:00 2001
From: superstar54
Date: Mon, 26 Aug 2024 10:31:13 +0200
Subject: [PATCH 01/14] add scheduler

---
 aiida_workgraph/engine/scheduler.py | 1523 +++++++++++++++++++++++++++
 aiida_workgraph/workgraph.py        |    8 +-
 pyproject.toml                      |    1 +
 3 files changed, 1531 insertions(+), 1 deletion(-)
 create mode 100644 aiida_workgraph/engine/scheduler.py

diff --git a/aiida_workgraph/engine/scheduler.py b/aiida_workgraph/engine/scheduler.py
new file mode 100644
index 00000000..04201860
--- /dev/null
+++ b/aiida_workgraph/engine/scheduler.py
@@ -0,0 +1,1523 @@
+"""AiiDA workflow components: the WorkGraph scheduler."""
+from __future__ import annotations
+
+import asyncio
+import collections.abc
+import functools
+import logging
+import typing as t
+
+from plumpy import process_comms
+from plumpy.persistence import auto_persist
+from plumpy.process_states import Continue, Wait
+import kiwipy
+
+from aiida.common import exceptions
+from aiida.common.extendeddicts import AttributeDict
+from aiida.common.lang import override
+from aiida import orm
+from aiida.orm import load_node, Node, ProcessNode, WorkChainNode
+from aiida.orm.utils.serialize import deserialize_unsafe, serialize
+
+from aiida.engine.processes.exit_code import ExitCode
+from aiida.engine.processes.process import Process
+
+from aiida.engine.processes.workchains.awaitable import (
+    Awaitable,
+    AwaitableAction,
+    AwaitableTarget,
+    construct_awaitable,
+)
+from aiida.engine.processes.workchains.workchain import Protect, WorkChainSpec
+from aiida.engine import run_get_node
+from aiida_workgraph.utils import create_and_pause_process
+from aiida_workgraph.task import Task
+from aiida_workgraph.utils import get_nested_dict, update_nested_dict
+from aiida_workgraph.executors.monitors import monitor
+
+if t.TYPE_CHECKING:
+    from aiida.engine.runners import Runner  # pylint: disable=unused-import
+
+__all__ = ("WorkGraphScheduler",)
+
+
+MAX_NUMBER_AWAITABLES_MSG = "The maximum number of subprocesses has been reached: {}. Cannot launch the job: {}."
+
+
+@auto_persist("_awaitables")
+class WorkGraphScheduler(Process, metaclass=Protect):
+    """A long-lived scheduler process that runs WorkGraph workflows in AiiDA."""
+
+    # used to create a process node that represents what happened in this process.
+    _node_class = WorkChainNode
+    _spec_class = WorkChainSpec
+    _CONTEXT = "CONTEXT"
+
+    def __init__(
+        self,
+        inputs: dict | None = None,
+        logger: logging.Logger | None = None,
+        runner: "Runner" | None = None,
+        enable_persistence: bool = True,
+    ) -> None:
+        """Construct a WorkGraphScheduler instance.
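+
+        Unlike the ``WorkGraphEngine``, which runs a single workgraph, a
+        scheduler process stays alive and serves many workgraphs: it receives
+        ``launch_workgraph`` task actions, loads the workgraph data of the
+        target process from the database and drives its tasks.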
+ + :param inputs: work graph inputs + :param logger: aiida logger + :param runner: work graph runner + :param enable_persistence: whether to persist this work graph + + """ + + super().__init__(inputs, logger, runner, enable_persistence=enable_persistence) + + self._awaitables: list[Awaitable] = [] + self._context = AttributeDict() + + @classmethod + def define(cls, spec: WorkChainSpec) -> None: + super().define(spec) + spec.input("input_file", valid_type=orm.SinglefileData, required=False) + spec.input_namespace( + "wg", dynamic=True, required=False, help="WorkGraph inputs" + ) + spec.input_namespace("input_tasks", dynamic=True, required=False) + spec.exit_code(2, "ERROR_SUBPROCESS", message="A subprocess has failed.") + + spec.outputs.dynamic = True + + spec.output_namespace("new_data", dynamic=True) + spec.output( + "execution_count", + valid_type=orm.Int, + required=False, + help="The number of time the WorkGraph runs.", + ) + # + spec.exit_code( + 201, "UNKNOWN_MESSAGE_TYPE", message="The message type is unknown." + ) + spec.exit_code(202, "UNKNOWN_TASK_TYPE", message="The task type is unknown.") + # + spec.exit_code( + 301, + "OUTPUS_NOT_MATCH_RESULTS", + message="The outputs of the process do not match the results.", + ) + spec.exit_code( + 302, + "TASK_FAILED", + message="Some of the tasks failed.", + ) + spec.exit_code( + 303, + "TASK_NON_ZERO_EXIT_STATUS", + message="Some of the tasks exited with non-zero status.", + ) + + @property + def ctx(self) -> AttributeDict: + """Get the context.""" + return self._context + + @override + def save_instance_state( + self, out_state: t.Dict[str, t.Any], save_context: t.Any + ) -> None: + """Save instance state. + + :param out_state: state to save in + + :param save_context: + :type save_context: :class:`!plumpy.persistence.LoadSaveContext` + + """ + super().save_instance_state(out_state, save_context) + # Save the context + out_state[self._CONTEXT] = self.ctx + + @override + def load_instance_state( + self, saved_state: t.Dict[str, t.Any], load_context: t.Any + ) -> None: + super().load_instance_state(saved_state, load_context) + # Load the context + self._context = saved_state[self._CONTEXT] + self._temp = {"awaitables": {}} + + self.set_logger(self.node.logger) + + if self._awaitables: + # For the "ascyncio.tasks.Task" awaitable, because there are only in-memory, + # we need to reset the tasks and so that they can be re-run again. + should_resume = False + for awaitable in self._awaitables: + if awaitable.target == "asyncio.tasks.Task": + self._resolve_awaitable(awaitable, None) + self.report(f"reset awaitable task: {awaitable.key}") + self.reset_task(awaitable.key) + should_resume = True + if should_resume: + self._update_process_status() + self.resume() + # For other awaitables, because they exist in the db, we only need to re-register the callbacks + self.ctx._workgraph[pk]["_awaitable_actions"] = [] + self._action_awaitables() + + def _resolve_nested_context(self, key: str) -> tuple[AttributeDict, str]: + """ + Returns a reference to a sub-dictionary of the context and the last key, + after resolving a potentially segmented key where required sub-dictionaries are created as needed. 
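+
+        For example (a sketch, assuming an empty context), the key
+        ``workgraph.tasks.add1`` creates ``ctx.workgraph`` and
+        ``ctx.workgraph.tasks`` as nested ``AttributeDict`` instances and
+        returns the innermost dict together with the last key::
+
+            ctx, key = self._resolve_nested_context("workgraph.tasks.add1")
+            assert key == "add1"
+            ctx[key] = 1  # equivalent to self.ctx.workgraph.tasks.add1 = 1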
+ + :param key: A key into the context, where words before a dot are interpreted as a key for a sub-dictionary + """ + ctx = self.ctx + ctx_path = key.split(".") + + for index, path in enumerate(ctx_path[:-1]): + try: + ctx = ctx[path] + except KeyError: # see below why this is the only exception we have to catch here + ctx[ + path + ] = AttributeDict() # create the sub-dict and update the context + ctx = ctx[path] + continue + + # Notes: + # * the first ctx (self.ctx) is guaranteed to be an AttributeDict, hence the post-"dereference" checking + # * the values can be many different things: on insertion they are either AtrributeDict, List or Awaitables + # (subclasses of AttributeDict) but after resolution of an Awaitable this will be the value itself + # * assumption: a resolved value is never a plain AttributeDict, on the other hand if a resolved Awaitable + # would be an AttributeDict we can append things to it since the order of tasks is maintained. + if type(ctx) != AttributeDict: # pylint: disable=C0123 + raise ValueError( + f"Can not update the context for key `{key}`: " + f' found instance of `{type(ctx)}` at `{".".join(ctx_path[:index + 1])}`, expected AttributeDict' + ) + + return ctx, ctx_path[-1] + + def _insert_awaitable(self, awaitable: Awaitable) -> None: + """Insert an awaitable that should be terminated before before continuing to the next step. + + :param awaitable: the thing to await + """ + ctx, key = self._resolve_nested_context(awaitable.key) + + # Already assign the awaitable itself to the location in the context container where it is supposed to end up + # once it is resolved. This is especially important for the `APPEND` action, since it needs to maintain the + # order, but the awaitables will not necessarily be resolved in the order in which they are added. By using the + # awaitable as a placeholder, in the `_resolve_awaitable`, it can be found and replaced by the resolved value. + if awaitable.action == AwaitableAction.ASSIGN: + ctx[key] = awaitable + elif awaitable.action == AwaitableAction.APPEND: + ctx.setdefault(key, []).append(awaitable) + else: + raise AssertionError(f"Unsupported awaitable action: {awaitable.action}") + + self._awaitables.append( + awaitable + ) # add only if everything went ok, otherwise we end up in an inconsistent state + self._update_process_status() + + def _resolve_awaitable(self, awaitable: Awaitable, value: t.Any) -> None: + """Resolve an awaitable. + + Precondition: must be an awaitable that was previously inserted. 
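+
+        For ``ASSIGN`` the resolved value simply replaces the placeholder in
+        the context; for ``APPEND`` the placeholder is looked up by its ``pk``
+        in the list, so the original insertion order is preserved even when
+        awaitables resolve out of order.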
+ + :param awaitable: the awaitable to resolve + :param value: the value to assign to the awaitable + """ + ctx, key = self._resolve_nested_context(awaitable.key) + + if awaitable.action == AwaitableAction.ASSIGN: + ctx[key] = value + elif awaitable.action == AwaitableAction.APPEND: + # Find the same awaitable inserted in the context + container = ctx[key] + for index, placeholder in enumerate(container): + if ( + isinstance(placeholder, Awaitable) + and placeholder.pk == awaitable.pk + ): + container[index] = value + break + else: + raise AssertionError( + f"Awaitable `{awaitable.pk} was not in `ctx.{awaitable.key}`" + ) + else: + raise AssertionError(f"Unsupported awaitable action: {awaitable.action}") + + awaitable.resolved = True + # remove awaitabble from the list + self._awaitables = [a for a in self._awaitables if a.pk != awaitable.pk] + + if not self.has_terminated(): + # the process may be terminated, for example, if the process was killed or excepted + # then we should not try to update it + self._update_process_status() + + @Protect.final + def to_context(self, **kwargs: Awaitable | ProcessNode) -> None: + """Add a dictionary of awaitables to the context. + + This is a convenience method that provides syntactic sugar, for a user to add multiple intersteps that will + assign a certain value to the corresponding key in the context of the work graph. + """ + for key, value in kwargs.items(): + awaitable = construct_awaitable(value) + awaitable.key = key + self._insert_awaitable(awaitable) + + def _update_process_status(self) -> None: + """Set the process status with a message accounting the current sub processes that we are waiting for.""" + if self._awaitables: + status = f"Waiting for child processes: {', '.join([str(_.pk) for _ in self._awaitables])}" + self.node.set_process_status(status) + else: + self.node.set_process_status(None) + + @override + def run(self) -> t.Any: + self.setup() + return self._do_step() + + def _do_step(self) -> t.Any: + """Execute the next step in the workgraph and return the result. + + If any awaitables were created, the process will enter in the Wait state, + otherwise it will go to Continue. + """ + # we will not remove the awaitables here, + # we resume the workgraph in the callback function even + # there are some awaitables left + # self._awaitables = [] + + if self._awaitables: + return Wait(self._do_step, "Waiting before next step") + + return Continue(self._do_step) + + def _store_nodes(self, data: t.Any) -> None: + """Recurse through a data structure and store any unstored nodes that are found along the way + + :param data: a data structure potentially containing unstored nodes + """ + if isinstance(data, Node) and not data.is_stored: + data.store() + elif isinstance(data, collections.abc.Mapping): + for _, value in data.items(): + self._store_nodes(value) + elif isinstance(data, collections.abc.Sequence) and not isinstance(data, str): + for value in data: + self._store_nodes(value) + + @override + @Protect.final + def on_exiting(self) -> None: + """Ensure that any unstored nodes in the context are stored, before the state is exited + + After the state is exited the next state will be entered and if persistence is enabled, a checkpoint will + be saved. If the context contains unstored nodes, the serialization necessary for checkpointing will fail. 
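+        (The checkpoint serializer references nodes by their UUID, which
+        requires them to be stored.)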
+ """ + super().on_exiting() + try: + self._store_nodes(self.ctx) + except Exception: # pylint: disable=broad-except + # An uncaught exception here will have bizarre and disastrous consequences + self.logger.exception("exception in _store_nodes called in on_exiting") + + @Protect.final + def on_wait(self, awaitables: t.Sequence[t.Awaitable]): + """Entering the WAITING state.""" + super().on_wait(awaitables) + if self._awaitables: + self._action_awaitables() + self.report("Process status: {}".format(self.node.process_status)) + else: + self.call_soon(self.resume) + + def _action_awaitables(self) -> None: + """Handle the awaitables that are currently registered with the work chain. + + Depending on the class type of the awaitable's target a different callback + function will be bound with the awaitable and the runner will be asked to + call it when the target is completed + """ + for awaitable in self._awaitables: + pk = awaitable.workgraph_pk + # if the waitable already has a callback, skip + if awaitable.pk in self.ctx._workgraph[pk]["_awaitable_actions"]: + continue + if awaitable.target == AwaitableTarget.PROCESS: + callback = functools.partial( + self.call_soon, self._on_awaitable_finished, awaitable + ) + self.runner.call_on_process_finish(awaitable.pk, callback) + self.ctx._workgraph[pk]["_awaitable_actions"].append(awaitable.pk) + elif awaitable.target == "asyncio.tasks.Task": + # this is a awaitable task, the callback function is already set + self.ctx._workgraph[pk]["_awaitable_actions"].append(awaitable.pk) + else: + assert f"invalid awaitable target '{awaitable.target}'" + + def _on_awaitable_finished(self, awaitable: Awaitable) -> None: + """Callback function, for when an awaitable process instance is completed. + + The awaitable will be effectuated on the context of the work chain and removed from the internal list. If all + awaitables have been dealt with, the work chain process is resumed. + + :param awaitable: an Awaitable instance + """ + self.logger.debug(f"Awaitable {awaitable.key} finished.") + pk = awaitable.workgraph_pk + + if isinstance(awaitable.pk, int): + self.logger.info( + "received callback that awaitable with key {} and pk {} has terminated".format( + awaitable.key, awaitable.pk + ) + ) + try: + node = load_node(awaitable.pk) + except (exceptions.MultipleObjectsError, exceptions.NotExistent): + raise ValueError( + f"provided pk<{awaitable.pk}> could not be resolved to a valid Node instance" + ) + + if awaitable.outputs: + value = { + entry.link_label: entry.node + for entry in node.base.links.get_outgoing() + } + else: + value = node # type: ignore + else: + # In this case, the pk and key are the same. 
+ self.logger.info( + "received callback that awaitable {} has terminated".format( + awaitable.key + ) + ) + try: + # if awaitable is cancelled, the result is None + if awaitable.cancelled(): + self.set_task_state_info(pk, awaitable.key, "state", "KILLED") + # set child tasks state to SKIPPED + self.set_tasks_state( + self.ctx._connectivity["child_node"][awaitable.key], "SKIPPED" + ) + self.report(f"Task: {awaitable.key} cancelled.") + else: + results = awaitable.result() + self.set_normal_task_results(awaitable.key, results) + except Exception as e: + self.logger.error(f"Error in awaitable {awaitable.key}: {e}") + self.set_task_state_info(pk, awaitable.key, "state", "FAILED") + # set child tasks state to SKIPPED + self.set_tasks_state( + self.ctx._connectivity["child_node"][awaitable.key], "SKIPPED" + ) + self.report(f"Task: {awaitable.key} failed.") + self.run_error_handlers(pk, awaitable.key) + value = None + + self._resolve_awaitable(awaitable, value) + + # node finished, update the task state and result + # udpate the task state + self.update_task_state(awaitable.workgraph_pk, awaitable.key) + # try to resume the workgraph, if the workgraph is already resumed + # by other awaitable, this will not work + try: + self.resume() + except Exception as e: + print(e) + + def _build_process_label(self) -> str: + """Use the workgraph name as the process label.""" + return "Scheduler" + + def on_create(self) -> None: + """Called when a Process is created.""" + + super().on_create() + self.node.label = "Scheduler" + + def setup(self) -> None: + """Setup the variables in the context.""" + # track if the awaitable callback is added to the runner + + self.ctx._workgraph = {} + awaitable = Awaitable( + **{ + "workgraph_pk": self.node.pk, + "pk": "scheduler", + "action": AwaitableAction.ASSIGN, + "target": "scheduler", + "outputs": False, + } + ) + self.ctx._workgraph[self.node.pk] = {"_awaitable_actions": []} + self.to_context(scheduler=awaitable) + # self.ctx._msgs = [] + # self.ctx._execution_count = {} + # data not to be persisted, because they are not serializable + self._temp = {"awaitables": {}} + + def launch_workgraph(self, pk: str) -> None: + """Launch the workgraph.""" + # create the workgraph process + self.init_ctx_workgraph(pk) + self.set_task_results(pk) + + def setup_ctx_workgraph(self, wgdata: t.Dict[str, t.Any]) -> None: + """setup the workgraph in the context.""" + import cloudpickle as pickle + + pk = wgdata["pk"] + self.ctx._workgraph[pk]["_tasks"] = wgdata["tasks"] + self.ctx._workgraph[pk]["_links"] = wgdata["links"] + self.ctx._workgraph[pk]["_connectivity"] = wgdata["connectivity"] + self.ctx._workgraph[pk]["_ctrl_links"] = wgdata["ctrl_links"] + self.ctx._workgraph[pk]["_workgraph"] = wgdata + self.ctx._workgraph[pk]["_error_handlers"] = pickle.loads( + wgdata["error_handlers"] + ) + self.ctx._workgraph[pk]["_awaitable_actions"] = [] + + def read_wgdata_from_base(self, pk: int) -> t.Dict[str, t.Any]: + """Read workgraph data from base.extras.""" + from aiida_workgraph.orm.function_data import PickledLocalFunction + + node = load_node(pk) + + wgdata = node.base.extras.get("_workgraph") + for name, task in wgdata["tasks"].items(): + wgdata["tasks"][name] = deserialize_unsafe(task) + for _, prop in wgdata["tasks"][name]["properties"].items(): + if isinstance(prop["value"], PickledLocalFunction): + prop["value"] = prop["value"].value + wgdata["error_handlers"] = deserialize_unsafe(wgdata["error_handlers"]) + return wgdata, node + + def update_workgraph_from_base(self) -> None: + 
"""Update the ctx from base.extras.""" + wgdata, _ = self.read_wgdata_from_base() + for name, task in wgdata["tasks"].items(): + task["results"] = self.ctx._workgraph[pk]["_tasks"][name].get("results") + self.setup_ctx_workgraph(wgdata) + + def get_task(self, name: str): + """Get task from the context.""" + task = Task.from_dict(self.ctx._workgraph[pk]["_tasks"][name]) + return task + + def update_task(self, task: Task): + """Update task in the context. + This is used in error handlers to update the task parameters.""" + self.ctx._workgraph[pk]["_tasks"][task.name][ + "properties" + ] = task.properties_to_dict() + self.reset_task(task.name) + + def get_task_state_info(self, pk: int, name: str, key: str) -> str: + """Get task state info from ctx.""" + + value = self.ctx._workgraph[pk]["_tasks"][name].get(key, None) + if key == "process" and value is not None: + value = deserialize_unsafe(value) + return value + + def set_task_state_info(self, pk: int, name: str, key: str, value: any) -> None: + """Set task state info to ctx and base.extras. + We task state to the base.extras, so that we can access outside the engine""" + + if key == "process": + value = serialize(value) + self.ctx._workgraph[pk]["_node"].base.extras.set( + f"_task_{key}_{name}", value + ) + else: + self.ctx._workgraph[pk]["_node"].base.extras.set( + f"_task_{key}_{name}", value + ) + self.ctx._workgraph[pk]["_tasks"][name][key] = value + + def init_ctx_workgraph(self, pk: int) -> None: + """Init the context from the workgraph data.""" + from aiida_workgraph.utils import update_nested_dict + + # read the latest workgraph data + wgdata, node = self.read_wgdata_from_base(pk) + self.ctx._workgraph[pk] = { + "_awaitable_actions": {}, + "_new_data": {}, + "_executed_tasks": {}, + "_count": 0, + "_context": {}, + "_node": node, + } + for key, value in wgdata["_context"].items(): + key = key.replace("__", ".") + update_nested_dict(self.ctx._workgraph[pk], key, value) + # set up the workgraph + self.setup_ctx_workgraph(wgdata) + + def set_task_results(self, pk) -> None: + for name, task in self.ctx._workgraph[pk]["_tasks"].items(): + if self.get_task_state_info(pk, name, "action").upper() == "RESET": + self.reset_task(pk, task["name"]) + self.update_task_state(pk, name) + + def apply_action(self, msg: dict) -> None: + + if msg["catalog"] == "task": + self.apply_task_actions(msg) + else: + self.report(f"Unknow message type {msg}") + + def apply_task_actions(self, msg: dict) -> None: + """Apply task actions to the workgraph.""" + action = msg["action"] + tasks = msg["tasks"] + self.report(f"Action: {action}. {tasks}") + if action.upper() == "RESET": + for name in tasks: + self.reset_task(name) + elif action.upper() == "PAUSE": + for name in tasks: + self.pause_task(name) + elif action.upper() == "PLAY": + for name in tasks: + self.play_task(name) + elif action.upper() == "SKIP": + for name in tasks: + self.skip_task(name) + elif action.upper() == "KILL": + for name in tasks: + self.kill_task(name) + elif action.upper() == "LAUNCH_WORKGRAPH": + for pk in tasks: + self.launch_workgraph(pk) + + def reset_task( + self, + name: str, + reset_process: bool = True, + recursive: bool = True, + reset_execution_count: bool = True, + ) -> None: + """Reset task state and remove it from the executed task. 
+ If recursive is True, reset its child tasks.""" + + self.set_task_state_info(pk, name, "state", "PLANNED") + if reset_process: + self.set_task_state_info(pk, name, "process", None) + self.remove_executed_task(name) + # self.logger.debug(f"Task {name} action: RESET.") + # if the task is a while task, reset its child tasks + if ( + self.ctx._workgraph[pk]["_tasks"][name]["metadata"]["node_type"].upper() + == "WHILE" + ): + if reset_execution_count: + self.ctx._workgraph[pk]["_tasks"][name]["execution_count"] = 0 + for child_task in self.ctx._workgraph[pk]["_tasks"][name]["children"]: + self.reset_task(child_task, reset_process=False, recursive=False) + if recursive: + # reset its child tasks + names = self.ctx._connectivity["child_node"][name] + for name in names: + self.reset_task(name, recursive=False) + + def pause_task(self, name: str) -> None: + """Pause task.""" + self.set_task_state_info(pk, name, "action", "PAUSE") + self.report(f"Task {name} action: PAUSE.") + + def play_task(self, name: str) -> None: + """Play task.""" + self.set_task_state_info(pk, name, "action", "") + self.report(f"Task {name} action: PLAY.") + + def skip_task(self, name: str) -> None: + """Skip task.""" + self.set_task_state_info(pk, name, "state", "SKIPPED") + self.report(f"Task {name} action: SKIP.") + + def kill_task(self, pk, name: str) -> None: + """Kill task. + This is used to kill the awaitable and monitor task. + """ + if self.get_task_state_info(pk, name, "state") in ["RUNNING"]: + if self.ctx._workgraph[pk]["_tasks"][name]["metadata"][ + "node_type" + ].upper() in [ + "AWAITABLE", + "MONITOR", + ]: + try: + self._temp["awaitables"][name].cancel() + self.set_task_state_info(pk, name, "state", "KILLED") + self.report(f"Task {name} action: KILLED.") + except Exception as e: + self.logger.error(f"Error in killing task {name}: {e}") + + def continue_workgraph(self, pk: int) -> None: + print("Continue workgraph.") + self.report("Continue workgraph.") + # self.update_workgraph_from_base() + task_to_run = [] + for name, task in self.ctx._workgraph[pk]["_tasks"].items(): + # update task state + if ( + self.get_task_state_info(pk, task["name"], "state") + in [ + "CREATED", + "RUNNING", + "FINISHED", + "FAILED", + "SKIPPED", + ] + or name in self.ctx._workgraph[pk]["_executed_tasks"] + ): + continue + ready, _ = self.is_task_ready_to_run(pk, name) + if ready: + task_to_run.append(name) + # + self.report("tasks ready to run: {}".format(",".join(task_to_run))) + self.run_tasks(pk, task_to_run) + + def update_task_state(self, pk: int, name: str) -> None: + """Update task state when the task is finished.""" + task = self.ctx._workgraph[pk]["_tasks"][name] + # print(f"set task result: {name}") + node = self.get_task_state_info(pk, name, "process") + if isinstance(node, orm.ProcessNode): + # print(f"set task result: {name} process") + state = node.process_state.value.upper() + if node.is_finished_ok: + self.set_task_state_info(pk, task["name"], "state", state) + if task["metadata"]["node_type"].upper() == "WORKGRAPH": + # expose the outputs of all the tasks in the workgraph + task["results"] = {} + outgoing = node.base.links.get_outgoing() + for link in outgoing.all(): + if isinstance(link.node, ProcessNode) and getattr( + link.node, "process_state", False + ): + task["results"][link.link_label] = link.node.outputs + else: + task["results"] = node.outputs + # self.ctx._new_data[name] = task["results"] + self.set_task_state_info(pk, task["name"], "state", "FINISHED") + self.task_set_context(pk, name) + 
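+                # exporting results to the context makes them available to
+                # downstream tasks through ``{{ ... }}`` context variables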
+                self.report(f"Workgraph: {pk}, Task: {name} finished.")
+            # all other states are considered failed
+            else:
+                task["results"] = node.outputs
+                # self.ctx._new_data[name] = task["results"]
+                self.set_task_state_info(pk, task["name"], "state", "FAILED")
+                # set child tasks state to SKIPPED
+                self.set_tasks_state(
+                    self.ctx._workgraph[pk]["_connectivity"]["child_node"][name],
+                    "SKIPPED",
+                )
+                self.report(f"Workgraph: {pk}, Task: {name} failed.")
+                self.run_error_handlers(pk, name)
+        elif isinstance(node, orm.Data):
+            task["results"] = {task["outputs"][0]["name"]: node}
+            self.set_task_state_info(pk, task["name"], "state", "FINISHED")
+            self.task_set_context(pk, name)
+            self.report(f"Workgraph: {pk}, Task: {name} finished.")
+        else:
+            task.setdefault("results", None)
+
+        self.update_parent_task_state(pk, name)
+
+    def set_normal_task_results(self, pk, name, results):
+        """Set the results of a normal task.
+
+        A normal task is created by decorating a Python function with ``@task()``.
+        """
+        task = self.ctx._workgraph[pk]["_tasks"][name]
+        if isinstance(results, tuple):
+            if len(task["outputs"]) != len(results):
+                return self.exit_codes.OUTPUS_NOT_MATCH_RESULTS
+            for i in range(len(task["outputs"])):
+                task["results"][task["outputs"][i]["name"]] = results[i]
+        elif isinstance(results, dict):
+            task["results"] = results
+        else:
+            task["results"][task["outputs"][0]["name"]] = results
+        self.task_set_context(pk, name)
+        self.set_task_state_info(pk, name, "state", "FINISHED")
+        self.report(f"Workgraph: {pk}, Task: {name} finished.")
+        self.update_parent_task_state(pk, name)
+
+    def update_parent_task_state(self, pk: int, name: str) -> None:
+        """Update the state of the parent task (while/if/zone), if any."""
+        parent_task = self.ctx._workgraph[pk]["_tasks"][name]["parent_task"]
+        if parent_task[0]:
+            task_type = self.ctx._workgraph[pk]["_tasks"][parent_task[0]]["metadata"][
+                "node_type"
+            ].upper()
+            if task_type == "WHILE":
+                self.update_while_task_state(pk, parent_task[0])
+            elif task_type in ["IF", "ZONE"]:
+                self.update_zone_task_state(pk, parent_task[0])
+
+    def update_while_task_state(self, pk: int, name: str) -> None:
+        """Update while task state."""
+        finished, _ = self.are_children_finished(pk, name)
+
+        if finished:
+            self.report(
+                f"While Task {name}: this iteration finished. Try to reset for the next iteration."
+            )
+            # reset the condition tasks
+            for input in self.ctx._workgraph[pk]["_tasks"][name]["inputs"]:
+                if input["name"].upper() == "CONDITIONS":
+                    for link in input["links"]:
+                        self.reset_task(link["from_node"], recursive=False)
+            # reset the task and all its children, so that the task can run again,
+            # but do not reset the execution count
+            self.reset_task(name, reset_execution_count=False)
+
+    def update_zone_task_state(self, pk: int, name: str) -> None:
+        """Update zone task state."""
+        finished, _ = self.are_children_finished(pk, name)
+        if finished:
+            self.set_task_state_info(pk, name, "state", "FINISHED")
+            self.update_parent_task_state(pk, name)
+            self.report(f"Task: {name} finished.")
+
+    def should_run_while_task(self, pk: int, name: str) -> bool:
+        """Check if the while task should run another iteration."""
+        # check the conditions of the while task
+        within_max_iterations = (
+            self.ctx._workgraph[pk]["_tasks"][name]["execution_count"]
+            < self.ctx._workgraph[pk]["_tasks"][name]["properties"]["max_iterations"][
+                "value"
+            ]
+        )
+        conditions = [within_max_iterations]
+        _, kwargs, _, _, _ = self.get_inputs(name)
+        if isinstance(kwargs["conditions"], list):
+            for condition in kwargs["conditions"]:
+                value = get_nested_dict(self.ctx, condition)
+                conditions.append(value)
+        elif isinstance(kwargs["conditions"], dict):
+            for _, value in kwargs["conditions"].items():
+                conditions.append(value)
+        else:
+            conditions.append(kwargs["conditions"])
+        return False not in conditions
+
+    def should_run_if_task(self, name: str) -> bool:
+        """Check if the IF task should run."""
+        _, kwargs, _, _, _ = self.get_inputs(name)
+        flag = kwargs["conditions"]
+        if kwargs["invert_condition"]:
+            return not flag
+        return flag
+
+    def are_children_finished(self, pk: int, name: str) -> tuple[bool, t.Any]:
+        """Check if the child tasks are finished."""
+        task = self.ctx._workgraph[pk]["_tasks"][name]
+        finished = True
+        for child_name in task["children"]:
+            if self.get_task_state_info(pk, child_name, "state") not in [
+                "FINISHED",
+                "SKIPPED",
+                "FAILED",
+            ]:
+                finished = False
+                break
+        return finished, None
+
+    def run_error_handlers(self, pk: int, task_name: str) -> None:
+        """Run the registered error handlers for a failed task."""
+        node = self.get_task_state_info(pk, task_name, "process")
+        if not node or not node.exit_status:
+            return
+        for _, data in self.ctx._workgraph[pk]["_error_handlers"].items():
+            if task_name in data["tasks"]:
+                handler = data["handler"]
+                metadata = data["tasks"][task_name]
+                if node.exit_code.status in metadata.get("exit_codes", []):
+                    self.report(f"Run error handler: {metadata}")
+                    metadata.setdefault("retry", 0)
+                    if metadata["retry"] < metadata["max_retries"]:
+                        handler(self, task_name, **metadata.get("kwargs", {}))
+                        metadata["retry"] += 1
+
+    def is_workgraph_finished(self, pk: int) -> tuple[bool, t.Optional[ExitCode]]:
+        """Check if the workgraph is finished.
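+        A workgraph is finished once no task is left in an active state
+        (``RUNNING``, ``CREATED``, ``PLANNED`` or ``READY``).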
+        For a ``while`` workgraph, we also need to check its conditions."""
+        is_finished = True
+        failed_tasks = []
+        for name, task in self.ctx._workgraph[pk]["_tasks"].items():
+            # self.update_task_state(name)
+            if self.get_task_state_info(pk, task["name"], "state") in [
+                "RUNNING",
+                "CREATED",
+                "PLANNED",
+                "READY",
+            ]:
+                is_finished = False
+            elif self.get_task_state_info(pk, task["name"], "state") == "FAILED":
+                failed_tasks.append(name)
+        if is_finished:
+            wg_type = self.ctx._workgraph[pk]["_workgraph"]["workgraph_type"].upper()
+            if wg_type == "WHILE":
+                should_run = self.check_while_conditions(pk)
+                is_finished = not should_run
+            if wg_type == "FOR":
+                should_run = self.check_for_conditions(pk)
+                is_finished = not should_run
+        if is_finished and len(failed_tasks) > 0:
+            message = f"WorkGraph finished, but tasks: {failed_tasks} failed. Thus all their child tasks are skipped."
+            self.report(message)
+            result = ExitCode(302, message)
+        else:
+            result = None
+        return is_finished, result
+
+    def check_while_conditions(self, pk: int) -> bool:
+        """Check while conditions.
+
+        Run all condition tasks and check that all the conditions are True.
+        """
+        self.report("Check while conditions.")
+        if self.ctx._execution_count >= self.ctx._max_iteration:
+            self.report("Max iteration reached.")
+            return False
+        condition_tasks = []
+        for c in self.ctx._workgraph[pk]["_workgraph"]["conditions"]:
+            task_name, socket_name = c.split(".")
+            if task_name != "context":
+                condition_tasks.append(task_name)
+        self.run_tasks(pk, condition_tasks, continue_workgraph=False)
+        conditions = []
+        for c in self.ctx._workgraph[pk]["_workgraph"]["conditions"]:
+            task_name, socket_name = c.split(".")
+            if task_name == "context":
+                conditions.append(self.ctx[socket_name])
+            else:
+                conditions.append(
+                    self.ctx._workgraph[pk]["_tasks"][task_name]["results"][socket_name]
+                )
+        should_run = False not in conditions
+        if should_run:
+            self.reset()
+            self.set_tasks_state(condition_tasks, "SKIPPED")
+        return should_run
+
+    def check_for_conditions(self, pk: int) -> bool:
+        condition_tasks = [
+            c[0] for c in self.ctx._workgraph[pk]["_workgraph"]["conditions"]
+        ]
+        self.run_tasks(pk, condition_tasks)
+        conditions = [self.ctx._count < len(self.ctx._sequence)] + [
+            self.ctx._workgraph[pk]["_tasks"][c[0]]["results"][c[1]]
+            for c in self.ctx._workgraph[pk]["_workgraph"]["conditions"]
+        ]
+        should_run = False not in conditions
+        if should_run:
+            self.reset()
+            self.set_tasks_state(condition_tasks, "SKIPPED")
+            self.ctx["i"] = self.ctx._sequence[self.ctx._count]
+            self.ctx._count += 1
+        return should_run
+
+    def remove_executed_task(self, pk: int, name: str) -> None:
+        """Remove all entries for ``name`` from the executed-task list."""
+        self.ctx._workgraph[pk]["_executed_tasks"] = [
+            label
+            for label in self.ctx._workgraph[pk]["_executed_tasks"]
+            if label.split(".")[0] != name
+        ]
+
+    def run_tasks(
+        self, pk: int, names: t.List[str], continue_workgraph: bool = True
+    ) -> None:
+        """Run tasks.
+
+        Supported task types: Node, Data, CalcFunction, WorkFunction, CalcJob,
+        WorkChain, GraphBuilder, WorkGraph, PythonJob, ShellJob, While, If,
+        Zone, FromContext, ToContext and Normal.
+
+        Here we use ToContext to pass the results of the run to the next step.
+        This will force the engine to wait for all the submitted processes to
+        finish before continuing to the next step.
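+
+        Only process-like tasks (CalcJob, WorkChain, GraphBuilder, WorkGraph,
+        PythonJob, ShellJob) create awaitables; the other task types run
+        synchronously and the workgraph continues immediately.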
+ """ + from aiida_workgraph.utils import ( + get_executor, + create_data_node, + update_nested_dict_with_special_keys, + ) + + for name in names: + # skip if the max number of awaitables is reached + task = self.ctx._workgraph[pk]["_tasks"][name] + if task["metadata"]["node_type"].upper() in [ + "CALCJOB", + "WORKCHAIN", + "GRAPH_BUILDER", + "WORKGRAPH", + "PYTHONJOB", + "SHELLJOB", + ]: + if len(self._awaitables) >= self.ctx._max_number_awaitables: + print( + MAX_NUMBER_AWAITABLES_MSG.format( + self.ctx._max_number_awaitables, name + ) + ) + continue + # skip if the task is already executed + if name in self.ctx._workgraph[pk]["_executed_tasks"]: + continue + self.ctx._workgraph[pk]["_executed_tasks"].append(name) + print("-" * 60) + + self.report(f"Run task: {name}, type: {task['metadata']['node_type']}") + executor, _ = get_executor(task["executor"]) + # print("executor: ", executor) + args, kwargs, var_args, var_kwargs, args_dict = self.get_inputs(name) + for i, key in enumerate( + self.ctx._workgraph[pk]["_tasks"][name]["metadata"]["args"] + ): + kwargs[key] = args[i] + # update the port namespace + kwargs = update_nested_dict_with_special_keys(kwargs) + # print("args: ", args) + # print("kwargs: ", kwargs) + # print("var_kwargs: ", var_kwargs) + # kwargs["meta.label"] = name + # output must be a Data type or a mapping of {string: Data} + task["results"] = {} + if task["metadata"]["node_type"].upper() == "NODE": + results = self.run_executor(executor, [], kwargs, var_args, var_kwargs) + self.set_task_state_info(pk, name, "process", results) + self.update_task_state(name) + if continue_workgraph: + self.continue_workgraph(pk) + elif task["metadata"]["node_type"].upper() == "DATA": + for key in self.ctx._workgraph[pk]["_tasks"][name]["metadata"]["args"]: + kwargs.pop(key, None) + results = create_data_node(executor, args, kwargs) + self.set_task_state_info(pk, name, "process", results) + self.update_task_state(name) + self.ctx._new_data[name] = results + if continue_workgraph: + self.continue_workgraph(pk) + elif task["metadata"]["node_type"].upper() in [ + "CALCFUNCTION", + "WORKFUNCTION", + ]: + kwargs.setdefault("metadata", {}) + kwargs["metadata"].update({"call_link_label": name}) + try: + # since aiida 2.5.0, we need to use args_dict to pass the args to the run_get_node + if var_kwargs is None: + results, process = run_get_node(executor, **kwargs) + else: + results, process = run_get_node( + executor, **kwargs, **var_kwargs + ) + process.label = name + # print("results: ", results) + self.set_task_state_info(pk, name, "process", process) + self.update_task_state(name) + except Exception as e: + self.logger.error(f"Error in task {name}: {e}") + self.set_task_state_info(pk, name, "state", "FAILED") + # set child state to FAILED + self.set_tasks_state( + self.ctx._connectivity["child_node"][name], "SKIPPED" + ) + self.report(f"Task: {name} failed.") + # exclude the current tasks from the next run + if continue_workgraph: + self.continue_workgraph(pk) + elif task["metadata"]["node_type"].upper() in ["CALCJOB", "WORKCHAIN"]: + # process = run_get_node(executor, *args, **kwargs) + kwargs.setdefault("metadata", {}) + kwargs["metadata"].update({"call_link_label": name}) + # transfer the args to kwargs + if self.get_task_state_info(pk, name, "action").upper() == "PAUSE": + self.set_task_state_info(pk, name, "action", "") + self.report(f"Task {name} is created and paused.") + process = create_and_pause_process( + self.runner, + executor, + kwargs, + state_msg="Paused through WorkGraph", + ) 
+ self.set_task_state_info(pk, name, "state", "CREATED") + process = process.node + else: + process = self.submit(executor, **kwargs) + self.set_task_state_info(pk, name, "state", "RUNNING") + process.label = name + self.set_task_state_info(pk, name, "process", process) + self.to_context(**{name: process}) + elif task["metadata"]["node_type"].upper() in ["GRAPH_BUILDER"]: + wg = self.run_executor(executor, [], kwargs, var_args, var_kwargs) + wg.name = name + wg.group_outputs = self.ctx._workgraph[pk]["_tasks"][name]["metadata"][ + "group_outputs" + ] + wg.parent_uuid = self.node.uuid + inputs = wg.prepare_inputs(metadata={"call_link_label": name}) + # process = self.submit(WorkGraphEngine, inputs=inputs) + self.set_task_state_info(pk, name, "process", process) + self.set_task_state_info(pk, name, "state", "RUNNING") + self.to_context(**{name: process}) + elif task["metadata"]["node_type"].upper() in ["WORKGRAPH"]: + from .utils import prepare_for_workgraph_task + + inputs, _ = prepare_for_workgraph_task(task, kwargs) + # process = self.submit(WorkGraphEngine, inputs=inputs) + self.set_task_state_info(pk, name, "process", process) + self.set_task_state_info(pk, name, "state", "RUNNING") + self.to_context(**{name: process}) + elif task["metadata"]["node_type"].upper() in ["PYTHONJOB"]: + from aiida_workgraph.calculations.python import PythonJob + from .utils import prepare_for_python_task + + inputs = prepare_for_python_task(task, kwargs, var_kwargs) + # since aiida 2.5.0, we can pass inputs directly to the submit, no need to use **inputs + if self.get_task_state_info(pk, name, "action").upper() == "PAUSE": + self.set_task_state_info(pk, name, "action", "") + self.report(f"Task {name} is created and paused.") + process = create_and_pause_process( + self.runner, + PythonJob, + inputs, + state_msg="Paused through WorkGraph", + ) + self.set_task_state_info(pk, name, "state", "CREATED") + process = process.node + else: + process = self.submit(PythonJob, **inputs) + self.set_task_state_info(pk, name, "state", "RUNNING") + process.label = name + self.set_task_state_info(pk, name, "process", process) + self.to_context(**{name: process}) + elif task["metadata"]["node_type"].upper() in ["SHELLJOB"]: + from aiida_shell.calculations.shell import ShellJob + from .utils import prepare_for_shell_task + + inputs = prepare_for_shell_task(task, kwargs) + if self.get_task_state_info(pk, name, "action").upper() == "PAUSE": + self.set_task_state_info(pk, name, "action", "") + self.report(f"Task {name} is created and paused.") + process = create_and_pause_process( + self.runner, + ShellJob, + inputs, + state_msg="Paused through WorkGraph", + ) + self.set_task_state_info(pk, name, "state", "CREATED") + process = process.node + else: + process = self.submit(ShellJob, **inputs) + self.set_task_state_info(pk, name, "state", "RUNNING") + process.label = name + self.set_task_state_info(pk, name, "process", process) + self.to_context(**{name: process}) + elif task["metadata"]["node_type"].upper() in ["WHILE"]: + # check the conditions of the while task + should_run = self.should_run_while_task(name) + if not should_run: + self.set_task_state_info(pk, name, "state", "FINISHED") + self.set_tasks_state( + self.ctx._workgraph[pk]["_tasks"][name]["children"], "SKIPPED" + ) + self.update_parent_task_state(pk, name) + self.report( + f"While Task {name}: Condition not fullilled, task finished. Skip all its children." 
+ ) + else: + task["execution_count"] += 1 + self.set_task_state_info(pk, name, "state", "RUNNING") + self.continue_workgraph(pk) + elif task["metadata"]["node_type"].upper() in ["IF"]: + should_run = self.should_run_if_task(name) + if should_run: + self.set_task_state_info(pk, name, "state", "RUNNING") + else: + self.set_tasks_state(task["children"], "SKIPPED") + self.update_zone_task_state(name) + self.continue_workgraph(pk) + elif task["metadata"]["node_type"].upper() in ["ZONE"]: + self.set_task_state_info(pk, name, "state", "RUNNING") + self.continue_workgraph(pk) + elif task["metadata"]["node_type"].upper() in ["FROM_CONTEXT"]: + # get the results from the context + results = {"result": getattr(self.ctx, kwargs["key"])} + task["results"] = results + self.set_task_state_info(pk, name, "state", "FINISHED") + self.update_parent_task_state(pk, name) + self.continue_workgraph(pk) + elif task["metadata"]["node_type"].upper() in ["TO_CONTEXT"]: + # get the results from the context + setattr(self.ctx, kwargs["key"], kwargs["value"]) + self.set_task_state_info(pk, name, "state", "FINISHED") + self.update_parent_task_state(pk, name) + self.continue_workgraph(pk) + elif task["metadata"]["node_type"].upper() in ["AWAITABLE"]: + for key in self.ctx._workgraph[pk]["_tasks"][name]["metadata"]["args"]: + kwargs.pop(key, None) + awaitable_target = asyncio.ensure_future( + self.run_executor(executor, args, kwargs, var_args, var_kwargs), + loop=self.loop, + ) + awaitable = self.construct_awaitable_function(name, awaitable_target) + self.set_task_state_info(pk, name, "state", "RUNNING") + self.to_context(**{name: awaitable}) + elif task["metadata"]["node_type"].upper() in ["MONITOR"]: + + for key in self.ctx._workgraph[pk]["_tasks"][name]["metadata"]["args"]: + kwargs.pop(key, None) + # add function and interval to the args + args = [executor, kwargs.pop("interval"), kwargs.pop("timeout"), *args] + awaitable_target = asyncio.ensure_future( + self.run_executor(monitor, args, kwargs, var_args, var_kwargs), + loop=self.loop, + ) + awaitable = self.construct_awaitable_function(name, awaitable_target) + self.set_task_state_info(pk, name, "state", "RUNNING") + # save the awaitable to the temp, so that we can kill it if needed + self._temp["awaitables"][name] = awaitable_target + self.to_context(**{name: awaitable}) + elif task["metadata"]["node_type"].upper() in ["NORMAL"]: + # Normal task is created by decoratoring a function with @task() + if "context" in task["metadata"]["kwargs"]: + self.ctx.task_name = name + kwargs.update({"context": self.ctx}) + for key in self.ctx._workgraph[pk]["_tasks"][name]["metadata"]["args"]: + kwargs.pop(key, None) + try: + results = self.run_executor( + executor, args, kwargs, var_args, var_kwargs + ) + self.set_normal_task_results(name, results) + except Exception as e: + self.logger.error(f"Error in task {name}: {e}") + self.set_task_state_info(pk, name, "state", "FAILED") + # set child tasks state to SKIPPED + self.set_tasks_state( + self.ctx._connectivity["child_node"][name], "SKIPPED" + ) + self.report(f"Task: {name} failed.") + self.run_error_handlers(pk, name) + if continue_workgraph: + self.continue_workgraph(pk) + else: + # self.report("Unknow task type {}".format(task["metadata"]["node_type"])) + return self.exit_codes.UNKNOWN_TASK_TYPE + + def construct_awaitable_function( + self, name: str, awaitable_target: Awaitable + ) -> None: + """Construct the awaitable function.""" + awaitable = Awaitable( + **{ + "pk": name, + "action": AwaitableAction.ASSIGN, + "target": 
"asyncio.tasks.Task", + "outputs": False, + } + ) + awaitable_target.key = name + awaitable_target.pk = name + awaitable_target.action = AwaitableAction.ASSIGN + awaitable_target.add_done_callback(self._on_awaitable_finished) + return awaitable + + def get_inputs( + self, name: str + ) -> t.Tuple[ + t.List[t.Any], + t.Dict[str, t.Any], + t.Optional[t.List[t.Any]], + t.Optional[t.Dict[str, t.Any]], + t.Dict[str, t.Any], + ]: + """Get input based on the links.""" + + args = [] + args_dict = {} + kwargs = {} + var_args = None + var_kwargs = None + task = self.ctx._workgraph[pk]["_tasks"][name] + properties = task.get("properties", {}) + inputs = {} + for input in task["inputs"]: + # print(f"input: {input['name']}") + if len(input["links"]) == 0: + inputs[input["name"]] = self.update_context_variable( + properties[input["name"]]["value"] + ) + elif len(input["links"]) == 1: + link = input["links"][0] + if ( + self.ctx._workgraph[pk]["_tasks"][link["from_node"]]["results"] + is None + ): + inputs[input["name"]] = None + else: + # handle the special socket _wait, _outputs + if link["from_socket"] == "_wait": + continue + elif link["from_socket"] == "_outputs": + inputs[input["name"]] = self.ctx._workgraph[pk]["_tasks"][ + link["from_node"] + ]["results"] + else: + inputs[input["name"]] = get_nested_dict( + self.ctx._workgraph[pk]["_tasks"][link["from_node"]][ + "results" + ], + link["from_socket"], + ) + # handle the case of multiple outputs + elif len(input["links"]) > 1: + value = {} + for link in input["links"]: + name = f'{link["from_node"]}_{link["from_socket"]}' + # handle the special socket _wait, _outputs + if link["from_socket"] == "_wait": + continue + if ( + self.ctx._workgraph[pk]["_tasks"][link["from_node"]]["results"] + is None + ): + value[name] = None + else: + value[name] = self.ctx._workgraph[pk]["_tasks"][ + link["from_node"] + ]["results"][link["from_socket"]] + inputs[input["name"]] = value + for name in task["metadata"].get("args", []): + if name in inputs: + args.append(inputs[name]) + args_dict[name] = inputs[name] + else: + value = self.update_context_variable(properties[name]["value"]) + args.append(value) + args_dict[name] = value + for name in task["metadata"].get("kwargs", []): + if name in inputs: + kwargs[name] = inputs[name] + else: + value = self.update_context_variable(properties[name]["value"]) + kwargs[name] = value + if task["metadata"]["var_args"] is not None: + name = task["metadata"]["var_args"] + if name in inputs: + var_args = inputs[name] + else: + value = self.update_context_variable(properties[name]["value"]) + var_args = value + if task["metadata"]["var_kwargs"] is not None: + name = task["metadata"]["var_kwargs"] + if name in inputs: + var_kwargs = inputs[name] + else: + value = self.update_context_variable(properties[name]["value"]) + var_kwargs = value + return args, kwargs, var_args, var_kwargs, args_dict + + def update_context_variable(self, value: t.Any) -> t.Any: + # replace context variables + + """Get value from context.""" + if isinstance(value, dict): + for key, sub_value in value.items(): + value[key] = self.update_context_variable(sub_value) + elif ( + isinstance(value, str) + and value.strip().startswith("{{") + and value.strip().endswith("}}") + ): + name = value[2:-2].strip() + return get_nested_dict(self.ctx, name) + return value + + def task_set_context(self, name: str) -> None: + """Export task result to context.""" + from aiida_workgraph.utils import update_nested_dict + + items = 
self.ctx._workgraph[pk]["_tasks"][name]["context_mapping"] + for key, value in items.items(): + result = self.ctx._workgraph[pk]["_tasks"][name]["results"][key] + update_nested_dict(self.ctx, value, result) + + def is_task_ready_to_run(self, name: str) -> t.Tuple[bool, t.Optional[str]]: + """Check if the task ready to run. + For normal task and a zone task, we need to check its input tasks in the connectivity["zone"]. + For task inside a zone, we need to check if the zone (parent task) is ready. + """ + parent_task = self.ctx._workgraph[pk]["_tasks"][name]["parent_task"] + # input_tasks, parent_task, conditions + parent_states = [True, True] + # if the task belongs to a parent zone + if parent_task[0]: + state = self.get_task_state_info(parent_task[0], "state") + if state not in ["RUNNING"]: + parent_states[1] = False + # check the input tasks of the zone + # check if the zone input tasks are ready + for child_task_name in self.ctx._connectivity["zone"][name]["input_tasks"]: + if self.get_task_state_info(child_task_name, "state") not in [ + "FINISHED", + "SKIPPED", + "FAILED", + ]: + parent_states[0] = False + break + + return all(parent_states), parent_states + + def reset(self) -> None: + self.ctx._execution_count += 1 + self.set_tasks_state(self.ctx._workgraph[pk]["_tasks"].keys(), "PLANNED") + self.ctx._workgraph[pk]["_executed_tasks"] = [] + + def set_tasks_state( + self, tasks: t.Union[t.List[str], t.Sequence[str]], value: str + ) -> None: + """Set tasks state""" + for name in tasks: + self.set_task_state_info(pk, name, "state", value) + if "children" in self.ctx._workgraph[pk]["_tasks"][name]: + self.set_tasks_state( + self.ctx._workgraph[pk]["_tasks"][name]["children"], value + ) + + def run_executor( + self, + executor: t.Callable, + args: t.List[t.Any], + kwargs: t.Dict[str, t.Any], + var_args: t.Optional[t.List[t.Any]], + var_kwargs: t.Optional[t.Dict[str, t.Any]], + ) -> t.Any: + if var_kwargs is None: + return executor(*args, **kwargs) + else: + return executor(*args, **kwargs, **var_kwargs) + + def save_results_to_extras(self, name: str) -> None: + """Save the results to the base.extras. + For the outputs of a Normal task, they are not saved to the database like the calcjob or workchain. + One temporary solution is to save the results to the base.extras. 
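+        The extras are set on the scheduler's own process node under the key
+        ``nodes__results__<task_name>``.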
In order to do this, we need to + serialize the results + """ + from aiida_workgraph.utils import get_executor + + results = self.ctx._workgraph[pk]["_tasks"][name]["results"] + if results is None: + return + datas = {} + for key, value in results.items(): + # find outptus sockets with the name as key + output = [ + output + for output in self.ctx._workgraph[pk]["_tasks"][name]["outputs"] + if output["name"] == key + ] + if len(output) == 0: + continue + output = output[0] + Executor, _ = get_executor(output["serialize"]) + datas[key] = Executor(value) + self.node.set_extra(f"nodes__results__{name}", datas) + + def message_receive( + self, _comm: kiwipy.Communicator, msg: t.Dict[str, t.Any] + ) -> t.Any: + """ + Coroutine called when the process receives a message from the communicator + + :param _comm: the communicator that sent the message + :param msg: the message + :return: the outcome of processing the message, the return value will be sent back as a response to the sender + """ + self.logger.debug( + "Process<%s>: received RPC message with communicator '%s': %r", + self.pid, + _comm, + msg, + ) + + intent = msg[process_comms.INTENT_KEY] + + if intent == process_comms.Intent.PLAY: + return self._schedule_rpc(self.play) + if intent == process_comms.Intent.PAUSE: + return self._schedule_rpc( + self.pause, msg=msg.get(process_comms.MESSAGE_KEY, None) + ) + if intent == process_comms.Intent.KILL: + return self._schedule_rpc( + self.kill, msg=msg.get(process_comms.MESSAGE_KEY, None) + ) + if intent == process_comms.Intent.STATUS: + status_info: t.Dict[str, t.Any] = {} + self.get_status_info(status_info) + return status_info + if intent == "custom": + return self._schedule_rpc(self.apply_action, msg=msg) + + # Didn't match any known intents + raise RuntimeError("Unknown intent") + + def finalize(self) -> t.Optional[ExitCode]: + """""" + # expose outputs of the workgraph + group_outputs = {} + for output in self.ctx._workgraph["metadata"]["group_outputs"]: + names = output["from"].split(".", 1) + if names[0] == "context": + if len(names) == 1: + raise ValueError("The output name should be context.key") + update_nested_dict( + group_outputs, + output["name"], + get_nested_dict(self.ctx, names[1]), + ) + else: + # expose the whole outputs of the tasks + if len(names) == 1: + update_nested_dict( + group_outputs, + output["name"], + self.ctx._workgraph[pk]["_tasks"][names[0]]["results"], + ) + else: + # expose one output of the task + # note, the output may not exist + if ( + names[1] + in self.ctx._workgraph[pk]["_tasks"][names[0]]["results"] + ): + update_nested_dict( + group_outputs, + output["name"], + self.ctx._workgraph[pk]["_tasks"][names[0]]["results"][ + names[1] + ], + ) + self.out_many(group_outputs) + # output the new data + self.out("new_data", self.ctx._new_data) + self.out("execution_count", orm.Int(self.ctx._execution_count).store()) + self.report("Finalize workgraph.") + for _, task in self.ctx._workgraph[pk]["_tasks"].items(): + if self.get_task_state_info(pk, task["name"], "state") == "FAILED": + return self.exit_codes.TASK_FAILED diff --git a/aiida_workgraph/workgraph.py b/aiida_workgraph/workgraph.py index 339435a0..21b132cf 100644 --- a/aiida_workgraph/workgraph.py +++ b/aiida_workgraph/workgraph.py @@ -134,7 +134,7 @@ def submit( self.save(metadata=metadata) if self.process.process_state.value.upper() not in ["CREATED"]: raise ValueError(f"Process {self.process.pk} has already been submitted.") - self.continue_process() + self.continue_process_in_scheduler() # as long as 
we submit the process, it is a new submission, so we should reset restart_process to None
         self.restart_process = None
         if wait:
@@ -415,6 +415,12 @@ def continue_process(self):
         process_controller = get_manager().get_process_controller()
         process_controller.continue_process(self.pk)
 
+    def continue_process_in_scheduler(self, scheduler_pk: int = 122006):
+        """Ask the scheduler to pick up the process from the database and run it."""
+        from aiida_workgraph.utils.control import create_task_action
+
+        create_task_action(scheduler_pk, [self.pk], action="launch_workgraph")
+
     def play(self):
         import os
 
diff --git a/pyproject.toml b/pyproject.toml
index b67eee61..134ca928 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -83,6 +83,7 @@ workgraph = "aiida_workgraph.cli.cmd_workgraph:workgraph"
 
 [project.entry-points.'aiida.workflows']
 "workgraph.engine" = "aiida_workgraph.engine.workgraph:WorkGraphEngine"
+"workgraph.scheduler" = "aiida_workgraph.engine.scheduler:WorkGraphScheduler"
 
 [project.entry-points."aiida.data"]
 "workgraph.general" = "aiida_workgraph.orm.general_data:GeneralData"

From f7ed561161cd7f23c2776396e519a5d830109dc4 Mon Sep 17 00:00:00 2001
From: superstar54
Date: Tue, 27 Aug 2024 00:35:17 +0200
Subject: [PATCH 02/14] First minimal working version

---
 aiida_workgraph/engine/launch.py    | 217 +++++++++++++++++++
 aiida_workgraph/engine/scheduler.py | 309 ++++++++++++++++++----------
 aiida_workgraph/workgraph.py        |   4 +-
 3 files changed, 423 insertions(+), 107 deletions(-)
 create mode 100644 aiida_workgraph/engine/launch.py

diff --git a/aiida_workgraph/engine/launch.py b/aiida_workgraph/engine/launch.py
new file mode 100644
index 00000000..dfd3052c
--- /dev/null
+++ b/aiida_workgraph/engine/launch.py
@@ -0,0 +1,217 @@
+from __future__ import annotations
+
+import time
+import typing as t
+
+from aiida.common import InvalidOperation
+from aiida.common.log import AIIDA_LOGGER
+from aiida.manage import manager
+from aiida.orm import ProcessNode
+
+from aiida.engine.processes.builder import ProcessBuilder
+from aiida.engine.processes.functions import get_stack_size
+from aiida.engine.processes.process import Process
+from aiida.engine.utils import prepare_inputs, is_process_function
+
+import signal
+import sys
+import inspect
+from typing import (
+    Type,
+    Union,
+)
+
+
+from aiida.manage import get_manager
+
+__all__ = ("run_get_node", "submit")
+
+TYPE_RUN_PROCESS = t.Union[Process, t.Type[Process], ProcessBuilder]
+# run can also be a process function, but it is not clear what type this should be
+TYPE_SUBMIT_PROCESS = t.Union[Process, t.Type[Process], ProcessBuilder]
+LOGGER = AIIDA_LOGGER.getChild("engine.launch")
+
+
+def run_get_node(
+    process_class, *args, **kwargs
+) -> tuple[dict[str, t.Any] | None, "ProcessNode"]:
+    """Run the process with the supplied inputs in a local runner.
+
+    :param process_class: the process class or process function to run
+    :param args: input arguments to construct the process
+    :param kwargs: input keyword arguments to construct the process
+    :return: tuple of the outputs of the process and the process node
+    """
+    parent_pid = kwargs.pop("_parent_pid", None)
+    frame_delta = 1000
+    frame_count = get_stack_size()
+    stack_limit = sys.getrecursionlimit()
+    LOGGER.info(
+        "Executing process function, current stack status: %d frames of %d",
+        frame_count,
+        stack_limit,
+    )
+    # If the current frame count is more than 80% of the stack limit, or comes within 200 frames, increase the
+    # stack limit by ``frame_delta``.
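+    # (with the default limit of 1000, this triggers once more than 800
+    # frames are already in use)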
+    if frame_count > min(0.8 * stack_limit, stack_limit - 200):
+        LOGGER.warning(
+            "Current stack contains %d frames which is close to the limit of %d. Increasing the limit by %d",
+            frame_count,
+            stack_limit,
+            frame_delta,
+        )
+        sys.setrecursionlimit(stack_limit + frame_delta)
+    manager = get_manager()
+    runner = manager.get_runner()
+    inputs = process_class.create_inputs(*args, **kwargs)
+    # Remove all the known inputs from the kwargs
+    for port in process_class.spec().inputs:
+        kwargs.pop(port, None)
+    # If any kwargs remain, the spec should be dynamic, so we raise if it isn't
+    if kwargs and not process_class.spec().inputs.dynamic:
+        raise ValueError(
+            f"{process_class.__name__} does not support these kwargs: {kwargs.keys()}"
+        )
+    process = process_class(inputs=inputs, runner=runner, parent_pid=parent_pid)
+    # Only add handlers for the interrupt signal to kill the process if we are in a local and not a daemon runner.
+    # Without this check, running process functions in a daemon worker would be killed if the daemon is shut down.
+    current_runner = manager.get_runner()
+    original_handler = None
+    kill_signal = signal.SIGINT
+    if not current_runner.is_daemon_runner:
+
+        def kill_process(_num, _frame):
+            """Send the kill signal to the process in the current scope."""
+            LOGGER.critical(
+                "runner received interrupt, killing process %s", process.pid
+            )
+            result = process.kill(
+                msg="Process was killed because the runner received an interrupt"
+            )
+            return result
+
+        # Store the current handler on the signal such that it can be restored after the process has terminated
+        original_handler = signal.getsignal(kill_signal)
+        signal.signal(kill_signal, kill_process)
+    try:
+        result = process.execute()
+    finally:
+        # If `original_handler` is set, that means `kill_process` was bound, which needs to be reset
+        if original_handler:
+            signal.signal(signal.SIGINT, original_handler)
+    store_provenance = inputs.get("metadata", {}).get("store_provenance", True)
+    if not store_provenance:
+        process.node._storable = False
+        process.node._unstorable_message = (
+            "cannot store node because it was run with `store_provenance=False`"
+        )
+    return result, process.node
+
+
+def instantiate_process(
+    runner: "Runner",
+    process: Union["Process", Type["Process"], "ProcessBuilder"],
+    _parent_pid,
+    **inputs,
+) -> "Process":
+    """Return an instance of the process with the given inputs.
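+    This mirrors AiiDA's own ``instantiate_process`` utility, with an extra
+    ``_parent_pid`` argument that is forwarded to the process.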
The function can deal with various types + of the `process`: + + * Process instance: will simply return the instance + * ProcessBuilder instance: will instantiate the Process from the class and inputs defined within it + * Process class: will instantiate with the specified inputs + + If anything else is passed, a ValueError will be raised + + :param process: Process instance or class, CalcJobNode class or ProcessBuilder instance + :param inputs: the inputs for the process to be instantiated with + """ + + if isinstance(process, Process): + assert not inputs + assert runner is process.runner + return process + + if isinstance(process, ProcessBuilder): + builder = process + process_class = builder.process_class + inputs.update(**builder._inputs(prune=True)) + elif is_process_function(process): + process_class = process.process_class # type: ignore[attr-defined] + elif inspect.isclass(process) and issubclass(process, Process): + process_class = process + else: + raise ValueError( + f"invalid process {type(process)}, needs to be Process or ProcessBuilder" + ) + + process = process_class(runner=runner, inputs=inputs, parent_pid=_parent_pid) + + return process + + +def submit( + process: TYPE_SUBMIT_PROCESS, + inputs: dict[str, t.Any] | None = None, + *, + wait: bool = False, + wait_interval: int = 5, + **kwargs: t.Any, +) -> ProcessNode: + """Submit the process with the supplied inputs to the daemon immediately returning control to the interpreter. + + .. warning: this should not be used within another process. Instead, there one should use the ``submit`` method of + the wrapping process itself, i.e. use ``self.submit``. + + .. warning: submission of processes requires ``store_provenance=True``. + + :param process: the process class, instance or builder to submit + :param inputs: the input dictionary to be passed to the process + :param wait: when set to ``True``, the submission will be blocking and wait for the process to complete at which + point the function returns the calculation node. + :param wait_interval: the number of seconds to wait between checking the state of the process when ``wait=True``. + :param kwargs: inputs to be passed to the process. This is an alternative to the positional ``inputs`` argument. + :return: the calculation node of the process + """ + _parent_pid = kwargs.pop("_parent_pid", None) + inputs = prepare_inputs(inputs, **kwargs) + + # Submitting from within another process requires ``self.submit``` unless it is a work function, in which case the + # current process in the scope should be an instance of ``FunctionProcess``. + # if is_process_scoped() and not isinstance(Process.current(), FunctionProcess): + # raise InvalidOperation('Cannot use top-level `submit` from within another process, use `self.submit` instead') + + runner = manager.get_manager().get_runner() + assert runner.persister is not None, "runner does not have a persister" + assert runner.controller is not None, "runner does not have a controller" + + process_inited = instantiate_process( + runner, process, _parent_pid=_parent_pid, **inputs + ) + + # If a dry run is requested, simply forward to `run`, because it is not compatible with `submit`. We choose for this + # instead of raising, because in this way the user does not have to change the launcher when testing. The same goes + # for if `remote_folder` is present in the inputs, which means we are importing an already completed calculation. 
+ if process_inited.metadata.get("dry_run", False) or "remote_folder" in inputs: + _, node = run_get_node(process_inited) + return node + + if not process_inited.metadata.store_provenance: + raise InvalidOperation("cannot submit a process with `store_provenance=False`") + + runner.persister.save_checkpoint(process_inited) + process_inited.close() + + # Do not wait for the future's result, because in the case of a single worker this would cock-block itself + runner.controller.continue_process(process_inited.pid, nowait=False, no_reply=True) + node = process_inited.node + + if not wait: + return node + + while not node.is_terminated: + LOGGER.report( + f"Process<{node.pk}> has not yet terminated, current state is `{node.process_state}`. " + f"Waiting for {wait_interval} seconds." + ) + time.sleep(wait_interval) + + return node diff --git a/aiida_workgraph/engine/scheduler.py b/aiida_workgraph/engine/scheduler.py index 04201860..ba7aa798 100644 --- a/aiida_workgraph/engine/scheduler.py +++ b/aiida_workgraph/engine/scheduler.py @@ -9,7 +9,7 @@ from plumpy import process_comms from plumpy.persistence import auto_persist -from plumpy.process_states import Continue, Wait +from plumpy.process_states import Continue, Wait, Finished, Running import kiwipy from aiida.common import exceptions @@ -29,7 +29,6 @@ construct_awaitable, ) from aiida.engine.processes.workchains.workchain import Protect, WorkChainSpec -from aiida.engine import run_get_node from aiida_workgraph.utils import create_and_pause_process from aiida_workgraph.task import Task from aiida_workgraph.utils import get_nested_dict, update_nested_dict @@ -216,6 +215,8 @@ def _insert_awaitable(self, awaitable: Awaitable) -> None: else: raise AssertionError(f"Unsupported awaitable action: {awaitable.action}") + # Register the callback to be called when the awaitable is resolved + self._add_callback_to_awaitable(awaitable) self._awaitables.append( awaitable ) # add only if everything went ok, otherwise we end up in an inconsistent state @@ -269,6 +270,7 @@ def to_context(self, **kwargs: Awaitable | ProcessNode) -> None: for key, value in kwargs.items(): awaitable = construct_awaitable(value) awaitable.key = key + awaitable.workgraph_pk = value.workgraph_pk self._insert_awaitable(awaitable) def _update_process_status(self) -> None: @@ -351,17 +353,22 @@ def _action_awaitables(self) -> None: # if the waitable already has a callback, skip if awaitable.pk in self.ctx._workgraph[pk]["_awaitable_actions"]: continue - if awaitable.target == AwaitableTarget.PROCESS: - callback = functools.partial( - self.call_soon, self._on_awaitable_finished, awaitable - ) - self.runner.call_on_process_finish(awaitable.pk, callback) - self.ctx._workgraph[pk]["_awaitable_actions"].append(awaitable.pk) - elif awaitable.target == "asyncio.tasks.Task": - # this is a awaitable task, the callback function is already set - self.ctx._workgraph[pk]["_awaitable_actions"].append(awaitable.pk) - else: - assert f"invalid awaitable target '{awaitable.target}'" + self._add_callback_to_awaitable(awaitable) + + def _add_callback_to_awaitable(self, awaitable: Awaitable) -> None: + """Add a callback to the awaitable.""" + pk = awaitable.workgraph_pk + if awaitable.target == AwaitableTarget.PROCESS: + callback = functools.partial( + self.call_soon, self._on_awaitable_finished, awaitable + ) + self.runner.call_on_process_finish(awaitable.pk, callback) + self.ctx._workgraph[pk]["_awaitable_actions"].append(awaitable.pk) + elif awaitable.target == "asyncio.tasks.Task": + # this is a 
awaitable task, the callback function is already set + self.ctx._workgraph[pk]["_awaitable_actions"].append(awaitable.pk) + else: + assert f"invalid awaitable target '{awaitable.target}'" def _on_awaitable_finished(self, awaitable: Awaitable) -> None: """Callback function, for when an awaitable process instance is completed. @@ -371,8 +378,12 @@ def _on_awaitable_finished(self, awaitable: Awaitable) -> None: :param awaitable: an Awaitable instance """ + print(f"Awaitable {awaitable.key} finished.") self.logger.debug(f"Awaitable {awaitable.key} finished.") pk = awaitable.workgraph_pk + node = load_node(awaitable.pk) + print("node: ", node) + print("state: ", node.process_state) if isinstance(awaitable.pk, int): self.logger.info( @@ -407,18 +418,28 @@ def _on_awaitable_finished(self, awaitable: Awaitable) -> None: self.set_task_state_info(pk, awaitable.key, "state", "KILLED") # set child tasks state to SKIPPED self.set_tasks_state( - self.ctx._connectivity["child_node"][awaitable.key], "SKIPPED" + pk, + self.ctx._workgraph[pk]["_connectivity"]["child_node"][ + awaitable.key + ], + "SKIPPED", ) self.report(f"Task: {awaitable.key} cancelled.") else: results = awaitable.result() - self.set_normal_task_results(awaitable.key, results) + self.set_normal_task_results( + awaitable.workgraph_pk, awaitable.key, results + ) except Exception as e: self.logger.error(f"Error in awaitable {awaitable.key}: {e}") self.set_task_state_info(pk, awaitable.key, "state", "FAILED") # set child tasks state to SKIPPED self.set_tasks_state( - self.ctx._connectivity["child_node"][awaitable.key], "SKIPPED" + pk, + self.ctx._workgraph[pk]["_connectivity"]["child_node"][ + awaitable.key + ], + "SKIPPED", ) self.report(f"Task: {awaitable.key} failed.") self.run_error_handlers(pk, awaitable.key) @@ -428,6 +449,7 @@ def _on_awaitable_finished(self, awaitable: Awaitable) -> None: # node finished, update the task state and result # udpate the task state + print(f"Update task state: {awaitable.key}") self.update_task_state(awaitable.workgraph_pk, awaitable.key) # try to resume the workgraph, if the workgraph is already resumed # by other awaitable, this will not work @@ -451,6 +473,7 @@ def setup(self) -> None: # track if the awaitable callback is added to the runner self.ctx._workgraph = {} + self.ctx._max_number_awaitables = 10000 awaitable = Awaitable( **{ "workgraph_pk": self.node.pk, @@ -463,29 +486,55 @@ def setup(self) -> None: self.ctx._workgraph[self.node.pk] = {"_awaitable_actions": []} self.to_context(scheduler=awaitable) # self.ctx._msgs = [] - # self.ctx._execution_count = {} + # self.ctx._workgraph[pk]["_execution_count"] = {} # data not to be persisted, because they are not serializable self._temp = {"awaitables": {}} + # self.launch_workgraph(122305) def launch_workgraph(self, pk: str) -> None: """Launch the workgraph.""" # create the workgraph process + self.report(f"Launch workgraph: {pk}") self.init_ctx_workgraph(pk) - self.set_task_results(pk) + self.ctx._workgraph[pk]["_node"].set_process_state(Running.LABEL) + self.init_task_results(pk) + self.continue_workgraph(pk) + + def init_ctx_workgraph(self, pk: int) -> None: + """Init the context from the workgraph data.""" + from aiida_workgraph.utils import update_nested_dict + + self.report(f"Init workgraph: {pk}") + # read the latest workgraph data + wgdata, node = self.read_wgdata_from_base(pk) + self.ctx._workgraph[pk] = { + "_awaitable_actions": {}, + "_new_data": {}, + "_execution_count": 1, + "_executed_tasks": [], + "_count": 0, + "_context": {}, + 
"_node": node, + } + for key, value in wgdata["context"].items(): + key = key.replace("__", ".") + update_nested_dict(self.ctx._workgraph[pk], key, value) + # set up the workgraph + self.setup_ctx_workgraph(pk, wgdata) - def setup_ctx_workgraph(self, wgdata: t.Dict[str, t.Any]) -> None: + def setup_ctx_workgraph(self, pk: int, wgdata: t.Dict[str, t.Any]) -> None: """setup the workgraph in the context.""" import cloudpickle as pickle - pk = wgdata["pk"] - self.ctx._workgraph[pk]["_tasks"] = wgdata["tasks"] - self.ctx._workgraph[pk]["_links"] = wgdata["links"] - self.ctx._workgraph[pk]["_connectivity"] = wgdata["connectivity"] - self.ctx._workgraph[pk]["_ctrl_links"] = wgdata["ctrl_links"] - self.ctx._workgraph[pk]["_workgraph"] = wgdata + self.report(f"Setup workgraph: {pk}") + self.ctx._workgraph[pk]["_tasks"] = wgdata.pop("tasks") + self.ctx._workgraph[pk]["_links"] = wgdata.pop("links") + self.ctx._workgraph[pk]["_connectivity"] = wgdata.pop("connectivity") + self.ctx._workgraph[pk]["_ctrl_links"] = wgdata.pop("ctrl_links") self.ctx._workgraph[pk]["_error_handlers"] = pickle.loads( - wgdata["error_handlers"] + wgdata.pop("error_handlers") ) + self.ctx._workgraph[pk]["_workgraph"] = wgdata self.ctx._workgraph[pk]["_awaitable_actions"] = [] def read_wgdata_from_base(self, pk: int) -> t.Dict[str, t.Any]: @@ -503,19 +552,19 @@ def read_wgdata_from_base(self, pk: int) -> t.Dict[str, t.Any]: wgdata["error_handlers"] = deserialize_unsafe(wgdata["error_handlers"]) return wgdata, node - def update_workgraph_from_base(self) -> None: + def update_workgraph_from_base(self, pk: int) -> None: """Update the ctx from base.extras.""" wgdata, _ = self.read_wgdata_from_base() for name, task in wgdata["tasks"].items(): task["results"] = self.ctx._workgraph[pk]["_tasks"][name].get("results") - self.setup_ctx_workgraph(wgdata) + self.setup_ctx_workgraph(pk, wgdata) def get_task(self, name: str): """Get task from the context.""" task = Task.from_dict(self.ctx._workgraph[pk]["_tasks"][name]) return task - def update_task(self, task: Task): + def update_task(self, pk, task: Task): """Update task in the context. 
This is used in error handlers to update the task parameters.""" self.ctx._workgraph[pk]["_tasks"][task.name][ @@ -546,31 +595,13 @@ def set_task_state_info(self, pk: int, name: str, key: str, value: any) -> None: ) self.ctx._workgraph[pk]["_tasks"][name][key] = value - def init_ctx_workgraph(self, pk: int) -> None: - """Init the context from the workgraph data.""" - from aiida_workgraph.utils import update_nested_dict - - # read the latest workgraph data - wgdata, node = self.read_wgdata_from_base(pk) - self.ctx._workgraph[pk] = { - "_awaitable_actions": {}, - "_new_data": {}, - "_executed_tasks": {}, - "_count": 0, - "_context": {}, - "_node": node, - } - for key, value in wgdata["_context"].items(): - key = key.replace("__", ".") - update_nested_dict(self.ctx._workgraph[pk], key, value) - # set up the workgraph - self.setup_ctx_workgraph(wgdata) - - def set_task_results(self, pk) -> None: + def init_task_results(self, pk) -> None: + """Init the task results.""" for name, task in self.ctx._workgraph[pk]["_tasks"].items(): if self.get_task_state_info(pk, name, "action").upper() == "RESET": self.reset_task(pk, task["name"]) - self.update_task_state(pk, name) + # only init the task results, and do not need to continue the workgraph + self.update_task_state(pk, name, continue_workgraph=False) def apply_action(self, msg: dict) -> None: @@ -605,6 +636,7 @@ def apply_task_actions(self, msg: dict) -> None: def reset_task( self, + pk: int, name: str, reset_process: bool = True, recursive: bool = True, @@ -629,7 +661,7 @@ def reset_task( self.reset_task(child_task, reset_process=False, recursive=False) if recursive: # reset its child tasks - names = self.ctx._connectivity["child_node"][name] + names = self.ctx._workgraph[pk]["_connectivity"]["child_node"][name] for name in names: self.reset_task(name, recursive=False) @@ -667,7 +699,13 @@ def kill_task(self, pk, name: str) -> None: self.logger.error(f"Error in killing task {name}: {e}") def continue_workgraph(self, pk: int) -> None: - print("Continue workgraph.") + is_finished, _ = self.is_workgraph_finished(pk) + if is_finished: + self.report(f"Workgraph {pk} finished.") + self.ctx._workgraph[pk]["_node"].set_process_state(Finished.LABEL) + self.ctx._workgraph[pk]["_node"].set_exit_status(0) + self.ctx._workgraph[pk]["_node"].seal() + return self.report("Continue workgraph.") # self.update_workgraph_from_base() task_to_run = [] @@ -689,17 +727,28 @@ def continue_workgraph(self, pk: int) -> None: if ready: task_to_run.append(name) # - self.report("tasks ready to run: {}".format(",".join(task_to_run))) - self.run_tasks(pk, task_to_run) + self.report( + "tasks ready to run in WorkGraph {}, tasks: {}".format( + pk, ",".join(task_to_run) + ) + ) + if len(task_to_run) > 0: + self.run_tasks(pk, task_to_run) - def update_task_state(self, pk: int, name: str) -> None: + def update_task_state( + self, pk: int, name: str, continue_workgraph: bool = True + ) -> None: """Update task state when the task is finished.""" + + print("update task state: ", pk, name) task = self.ctx._workgraph[pk]["_tasks"][name] # print(f"set task result: {name}") node = self.get_task_state_info(pk, name, "process") + print("node", node) if isinstance(node, orm.ProcessNode): # print(f"set task result: {name} process") state = node.process_state.value.upper() + print("state", state) if node.is_finished_ok: self.set_task_state_info(pk, task["name"], "state", state) if task["metadata"]["node_type"].upper() == "WORKGRAPH": @@ -719,12 +768,15 @@ def update_task_state(self, pk: int, name: 
str) -> None: self.report(f"Workgraph: {pk}, Task: {name} finished.") # all other states are considered as failed else: + print(f"set task result: {name} failed") task["results"] = node.outputs # self.ctx._new_data[name] = task["results"] self.set_task_state_info(pk, task["name"], "state", "FAILED") # set child tasks state to SKIPPED self.set_tasks_state( - self.ctx._workgraph["_connectivity"]["child_node"][name], "SKIPPED" + pk, + self.ctx._workgraph[pk]["_connectivity"]["child_node"][name], + "SKIPPED", ) self.report(f"Workgraph: {pk}, Task: {name} failed.") self.run_error_handlers(pk, name) @@ -737,6 +789,8 @@ def update_task_state(self, pk: int, name: str) -> None: task.setdefault("results", None) self.update_parent_task_state(pk, name) + if continue_workgraph: + self.continue_workgraph(pk) def set_normal_task_results(self, pk, name, results): """Set the results of a normal task. @@ -757,7 +811,7 @@ def set_normal_task_results(self, pk, name, results): self.report(f"Workgraph: {pk}, Task: {name} finished.") self.update_parent_task_state(pk, name) - def update_parent_task_state(self, name: str) -> None: + def update_parent_task_state(self, pk, name: str) -> None: """Update parent task state.""" parent_task = self.ctx._workgraph[pk]["_tasks"][name]["parent_task"] if parent_task[0]: @@ -796,7 +850,7 @@ def update_zone_task_state(self, name: str) -> None: self.update_parent_task_state(pk, name) self.report(f"Task: {name} finished.") - def should_run_while_task(self, name: str) -> tuple[bool, t.Any]: + def should_run_while_task(self, pk: int, name: str) -> tuple[bool, t.Any]: """Check if the while task should run.""" # check the conditions of the while task not_excess_max_iterations = ( @@ -806,7 +860,7 @@ def should_run_while_task(self, name: str) -> tuple[bool, t.Any]: ] ) conditions = [not_excess_max_iterations] - _, kwargs, _, _, _ = self.get_inputs(name) + _, kwargs, _, _, _ = self.get_inputs(pk, name) if isinstance(kwargs["conditions"], list): for condition in kwargs["conditions"]: value = get_nested_dict(self.ctx, condition) @@ -820,7 +874,7 @@ def should_run_while_task(self, name: str) -> tuple[bool, t.Any]: def should_run_if_task(self, name: str) -> tuple[bool, t.Any]: """Check if the IF task should run.""" - _, kwargs, _, _, _ = self.get_inputs(name) + _, kwargs, _, _, _ = self.get_inputs(pk, name) flag = kwargs["conditions"] if kwargs["invert_condition"]: return not flag @@ -840,12 +894,12 @@ def are_childen_finished(self, pk, name: str) -> tuple[bool, t.Any]: break return finished, None - def run_error_handlers(self, pk, task_name: str) -> None: + def run_error_handlers(self, pk: int, task_name: str) -> None: """Run error handler.""" - node = self.get_task_state_info(task_name, "process") + node = self.get_task_state_info(pk, task_name, "process") if not node or not node.exit_status: return - for _, data in self.ctx._error_handlers.items(): + for _, data in self.ctx._workgraph[pk]["_error_handlers"].items(): if task_name in data["tasks"]: handler = data["handler"] metadata = data["tasks"][task_name] @@ -862,7 +916,7 @@ def is_workgraph_finished(self, pk) -> bool: is_finished = True failed_tasks = [] for name, task in self.ctx._workgraph[pk]["_tasks"].items(): - # self.update_task_state(name) + # self.update_task_state(pk, name) if self.get_task_state_info(pk, task["name"], "state") in [ "RUNNING", "CREATED", @@ -873,10 +927,13 @@ def is_workgraph_finished(self, pk) -> bool: elif self.get_task_state_info(pk, task["name"], "state") == "FAILED": failed_tasks.append(name) if 
is_finished: - if self.ctx._workgraph["workgraph_type"].upper() == "WHILE": + if ( + self.ctx._workgraph[pk]["_workgraph"]["workgraph_type"].upper() + == "WHILE" + ): should_run = self.check_while_conditions(pk) is_finished = not should_run - if self.ctx._workgraph["workgraph_type"].upper() == "FOR": + if self.ctx._workgraph[pk]["_workgraph"]["workgraph_type"].upper() == "FOR": should_run = self.check_for_conditions(pk) is_finished = not should_run if is_finished and len(failed_tasks) > 0: @@ -892,17 +949,20 @@ def check_while_conditions(self, pk: int) -> bool: Run all condition tasks and check if all the conditions are True. """ self.report("Check while conditions.") - if self.ctx._execution_count >= self.ctx._max_iteration: + if ( + self.ctx._workgraph[pk]["_execution_count"] + >= self.ctx._workgraph[pk]["_max_iteration"] + ): self.report("Max iteration reached.") return False condition_tasks = [] - for c in self.ctx._workgraph["conditions"]: + for c in self.ctx._workgraph[pk]["conditions"]: task_name, socket_name = c.split(".") if "task_name" != "context": condition_tasks.append(task_name) self.run_tasks(condition_tasks, continue_workgraph=False) conditions = [] - for c in self.ctx._workgraph["conditions"]: + for c in self.ctx._workgraph[pk]["conditions"]: task_name, socket_name = c.split(".") if task_name == "context": conditions.append(self.ctx[socket_name]) @@ -912,21 +972,21 @@ def check_while_conditions(self, pk: int) -> bool: ) should_run = False not in conditions if should_run: - self.reset() - self.set_tasks_state(condition_tasks, "SKIPPED") + self.reset_workgraph(pk) + self.set_tasks_state(pk, condition_tasks, "SKIPPED") return should_run def check_for_conditions(self, pk: int) -> bool: - condition_tasks = [c[0] for c in self.ctx._workgraph["conditions"]] + condition_tasks = [c[0] for c in self.ctx._workgraph[pk]["conditions"]] self.run_tasks(condition_tasks) conditions = [self.ctx._count < len(self.ctx._sequence)] + [ self.ctx._workgraph[pk]["_tasks"][c[0]]["results"][c[1]] - for c in self.ctx._workgraph["conditions"] + for c in self.ctx._workgraph[pk]["conditions"] ] should_run = False not in conditions if should_run: - self.reset() - self.set_tasks_state(condition_tasks, "SKIPPED") + self.reset_workgraph(pk) + self.set_tasks_state(pk, condition_tasks, "SKIPPED") self.ctx["i"] = self.ctx._sequence[self.ctx._count] self.ctx._count += 1 return should_run @@ -939,6 +999,19 @@ def remove_executed_task(self, pk, name: str) -> None: if label.split(".")[0] != name ] + def add_task_link(self, pk, node: ProcessNode) -> None: + from aiida.common.links import LinkType + + parent_calc = self.ctx._workgraph[pk]["_node"] + if isinstance(node, orm.CalculationNode): + node.base.links.add_incoming( + parent_calc, LinkType.CALL_CALC, "CALL" + ) # TODO, self.metadata.call_link_label) + elif isinstance(node, orm.WorkflowNode): + node.base.links.add_incoming( + parent_calc, LinkType.CALL_WORK, "CALL" + ) # TODO, self.metadata.call_link_label) + def run_tasks( self, pk: int, names: t.List[str], continue_workgraph: bool = True ) -> None: @@ -955,6 +1028,8 @@ def run_tasks( create_data_node, update_nested_dict_with_special_keys, ) + from aiida_workgraph.engine.workgraph import WorkGraphEngine + from aiida_workgraph.engine import launch for name in names: # skip if the max number of awaitables is reached @@ -983,7 +1058,7 @@ def run_tasks( self.report(f"Run task: {name}, type: {task['metadata']['node_type']}") executor, _ = get_executor(task["executor"]) # print("executor: ", executor) - args, 
kwargs, var_args, var_kwargs, args_dict = self.get_inputs(name) + args, kwargs, var_args, var_kwargs, args_dict = self.get_inputs(pk, name) for i, key in enumerate( self.ctx._workgraph[pk]["_tasks"][name]["metadata"]["args"] ): @@ -999,7 +1074,7 @@ def run_tasks( if task["metadata"]["node_type"].upper() == "NODE": results = self.run_executor(executor, [], kwargs, var_args, var_kwargs) self.set_task_state_info(pk, name, "process", results) - self.update_task_state(name) + self.update_task_state(pk, name) if continue_workgraph: self.continue_workgraph(pk) elif task["metadata"]["node_type"].upper() == "DATA": @@ -1007,7 +1082,7 @@ def run_tasks( kwargs.pop(key, None) results = create_data_node(executor, args, kwargs) self.set_task_state_info(pk, name, "process", results) - self.update_task_state(name) + self.update_task_state(pk, name) self.ctx._new_data[name] = results if continue_workgraph: self.continue_workgraph(pk) @@ -1017,24 +1092,29 @@ def run_tasks( ]: kwargs.setdefault("metadata", {}) kwargs["metadata"].update({"call_link_label": name}) + kwargs["_parent_pid"] = pk try: # since aiida 2.5.0, we need to use args_dict to pass the args to the run_get_node if var_kwargs is None: - results, process = run_get_node(executor, **kwargs) + results, process = launch.run_get_node( + executor.process_class, **kwargs + ) else: - results, process = run_get_node( - executor, **kwargs, **var_kwargs + results, process = launch.run_get_node( + executor.process_class, **kwargs, **var_kwargs ) process.label = name # print("results: ", results) self.set_task_state_info(pk, name, "process", process) - self.update_task_state(name) + self.update_task_state(pk, name) except Exception as e: self.logger.error(f"Error in task {name}: {e}") self.set_task_state_info(pk, name, "state", "FAILED") # set child state to FAILED self.set_tasks_state( - self.ctx._connectivity["child_node"][name], "SKIPPED" + pk, + self.ctx._workgraph[pk]["_connectivity"]["child_node"][name], + "SKIPPED", ) self.report(f"Task: {name} failed.") # exclude the current tasks from the next run @@ -1044,6 +1124,7 @@ def run_tasks( # process = run_get_node(executor, *args, **kwargs) kwargs.setdefault("metadata", {}) kwargs["metadata"].update({"call_link_label": name}) + kwargs["_parent_pid"] = pk # transfer the args to kwargs if self.get_task_state_info(pk, name, "action").upper() == "PAUSE": self.set_task_state_info(pk, name, "action", "") @@ -1057,9 +1138,10 @@ def run_tasks( self.set_task_state_info(pk, name, "state", "CREATED") process = process.node else: - process = self.submit(executor, **kwargs) + process = launch.submit(executor, **kwargs) self.set_task_state_info(pk, name, "state", "RUNNING") process.label = name + process.workgraph_pk = pk self.set_task_state_info(pk, name, "process", process) self.to_context(**{name: process}) elif task["metadata"]["node_type"].upper() in ["GRAPH_BUILDER"]: @@ -1070,7 +1152,9 @@ def run_tasks( ] wg.parent_uuid = self.node.uuid inputs = wg.prepare_inputs(metadata={"call_link_label": name}) - # process = self.submit(WorkGraphEngine, inputs=inputs) + inputs["parent_pid"] = pk + process = launch.submit(WorkGraphEngine, inputs=inputs) + process.workgraph_pk = pk self.set_task_state_info(pk, name, "process", process) self.set_task_state_info(pk, name, "state", "RUNNING") self.to_context(**{name: process}) @@ -1078,7 +1162,9 @@ def run_tasks( from .utils import prepare_for_workgraph_task inputs, _ = prepare_for_workgraph_task(task, kwargs) - # process = self.submit(WorkGraphEngine, inputs=inputs) + 
inputs["parent_pid"] = pk + process = launch.submit(WorkGraphEngine, inputs=inputs) + process.workgraph_pk = pk self.set_task_state_info(pk, name, "process", process) self.set_task_state_info(pk, name, "state", "RUNNING") self.to_context(**{name: process}) @@ -1087,6 +1173,7 @@ def run_tasks( from .utils import prepare_for_python_task inputs = prepare_for_python_task(task, kwargs, var_kwargs) + inputs["parent_pid"] = pk # since aiida 2.5.0, we can pass inputs directly to the submit, no need to use **inputs if self.get_task_state_info(pk, name, "action").upper() == "PAUSE": self.set_task_state_info(pk, name, "action", "") @@ -1100,9 +1187,10 @@ def run_tasks( self.set_task_state_info(pk, name, "state", "CREATED") process = process.node else: - process = self.submit(PythonJob, **inputs) + process = launch.submit(PythonJob, **inputs) self.set_task_state_info(pk, name, "state", "RUNNING") process.label = name + process.workgraph_pk = pk self.set_task_state_info(pk, name, "process", process) self.to_context(**{name: process}) elif task["metadata"]["node_type"].upper() in ["SHELLJOB"]: @@ -1110,6 +1198,7 @@ def run_tasks( from .utils import prepare_for_shell_task inputs = prepare_for_shell_task(task, kwargs) + inputs["parent_pid"] = pk if self.get_task_state_info(pk, name, "action").upper() == "PAUSE": self.set_task_state_info(pk, name, "action", "") self.report(f"Task {name} is created and paused.") @@ -1122,9 +1211,10 @@ def run_tasks( self.set_task_state_info(pk, name, "state", "CREATED") process = process.node else: - process = self.submit(ShellJob, **inputs) + process = launch.submit(ShellJob, **inputs) self.set_task_state_info(pk, name, "state", "RUNNING") process.label = name + process.workgraph_pk = pk self.set_task_state_info(pk, name, "process", process) self.to_context(**{name: process}) elif task["metadata"]["node_type"].upper() in ["WHILE"]: @@ -1133,7 +1223,9 @@ def run_tasks( if not should_run: self.set_task_state_info(pk, name, "state", "FINISHED") self.set_tasks_state( - self.ctx._workgraph[pk]["_tasks"][name]["children"], "SKIPPED" + pk, + self.ctx._workgraph[pk]["_tasks"][name]["children"], + "SKIPPED", ) self.update_parent_task_state(pk, name) self.report( @@ -1148,7 +1240,7 @@ def run_tasks( if should_run: self.set_task_state_info(pk, name, "state", "RUNNING") else: - self.set_tasks_state(task["children"], "SKIPPED") + self.set_tasks_state(pk, task["children"], "SKIPPED") self.update_zone_task_state(name) self.continue_workgraph(pk) elif task["metadata"]["node_type"].upper() in ["ZONE"]: @@ -1203,13 +1295,15 @@ def run_tasks( results = self.run_executor( executor, args, kwargs, var_args, var_kwargs ) - self.set_normal_task_results(name, results) + self.set_normal_task_results(pk, name, results) except Exception as e: self.logger.error(f"Error in task {name}: {e}") self.set_task_state_info(pk, name, "state", "FAILED") # set child tasks state to SKIPPED self.set_tasks_state( - self.ctx._connectivity["child_node"][name], "SKIPPED" + pk, + self.ctx._workgraph[pk]["_connectivity"]["child_node"][name], + "SKIPPED", ) self.report(f"Task: {name} failed.") self.run_error_handlers(pk, name) @@ -1238,7 +1332,7 @@ def construct_awaitable_function( return awaitable def get_inputs( - self, name: str + self, pk: int, name: str ) -> t.Tuple[ t.List[t.Any], t.Dict[str, t.Any], @@ -1348,7 +1442,7 @@ def update_context_variable(self, value: t.Any) -> t.Any: return get_nested_dict(self.ctx, name) return value - def task_set_context(self, name: str) -> None: + def task_set_context(self, pk, 
name: str) -> None: """Export task result to context.""" from aiida_workgraph.utils import update_nested_dict @@ -1357,7 +1451,7 @@ def task_set_context(self, name: str) -> None: result = self.ctx._workgraph[pk]["_tasks"][name]["results"][key] update_nested_dict(self.ctx, value, result) - def is_task_ready_to_run(self, name: str) -> t.Tuple[bool, t.Optional[str]]: + def is_task_ready_to_run(self, pk, name: str) -> t.Tuple[bool, t.Optional[str]]: """Check if the task ready to run. For normal task and a zone task, we need to check its input tasks in the connectivity["zone"]. For task inside a zone, we need to check if the zone (parent task) is ready. @@ -1367,13 +1461,15 @@ def is_task_ready_to_run(self, name: str) -> t.Tuple[bool, t.Optional[str]]: parent_states = [True, True] # if the task belongs to a parent zone if parent_task[0]: - state = self.get_task_state_info(parent_task[0], "state") + state = self.get_task_state_info(pk, parent_task[0], "state") if state not in ["RUNNING"]: parent_states[1] = False # check the input tasks of the zone # check if the zone input tasks are ready - for child_task_name in self.ctx._connectivity["zone"][name]["input_tasks"]: - if self.get_task_state_info(child_task_name, "state") not in [ + for child_task_name in self.ctx._workgraph[pk]["_connectivity"]["zone"][name][ + "input_tasks" + ]: + if self.get_task_state_info(pk, child_task_name, "state") not in [ "FINISHED", "SKIPPED", "FAILED", @@ -1383,20 +1479,20 @@ def is_task_ready_to_run(self, name: str) -> t.Tuple[bool, t.Optional[str]]: return all(parent_states), parent_states - def reset(self) -> None: - self.ctx._execution_count += 1 - self.set_tasks_state(self.ctx._workgraph[pk]["_tasks"].keys(), "PLANNED") + def reset_workgraph(self, pk) -> None: + self.ctx._workgraph[pk]["_execution_count"] += 1 + self.set_tasks_state(pk, self.ctx._workgraph[pk]["_tasks"].keys(), "PLANNED") self.ctx._workgraph[pk]["_executed_tasks"] = [] def set_tasks_state( - self, tasks: t.Union[t.List[str], t.Sequence[str]], value: str + self, pk: int, tasks: t.Union[t.List[str], t.Sequence[str]], value: str ) -> None: """Set tasks state""" for name in tasks: self.set_task_state_info(pk, name, "state", value) if "children" in self.ctx._workgraph[pk]["_tasks"][name]: self.set_tasks_state( - self.ctx._workgraph[pk]["_tasks"][name]["children"], value + pk, self.ctx._workgraph[pk]["_tasks"][name]["children"], value ) def run_executor( @@ -1516,7 +1612,10 @@ def finalize(self) -> t.Optional[ExitCode]: self.out_many(group_outputs) # output the new data self.out("new_data", self.ctx._new_data) - self.out("execution_count", orm.Int(self.ctx._execution_count).store()) + self.out( + "execution_count", + orm.Int(self.ctx._workgraph[pk]["_execution_count"]).store(), + ) self.report("Finalize workgraph.") for _, task in self.ctx._workgraph[pk]["_tasks"].items(): if self.get_task_state_info(pk, task["name"], "state") == "FAILED": diff --git a/aiida_workgraph/workgraph.py b/aiida_workgraph/workgraph.py index 21b132cf..88bcce5f 100644 --- a/aiida_workgraph/workgraph.py +++ b/aiida_workgraph/workgraph.py @@ -134,7 +134,7 @@ def submit( self.save(metadata=metadata) if self.process.process_state.value.upper() not in ["CREATED"]: raise ValueError(f"Process {self.process.pk} has already been submitted.") - self.continue_process_in_scheduler() + self.continue_process() # as long as we submit the process, it is a new submission, we should set restart_process to None self.restart_process = None if wait: @@ -415,7 +415,7 @@ def continue_process(self): 
process_controller = get_manager().get_process_controller() process_controller.continue_process(self.pk) - def continue_process_in_scheduler(self, scheduler_pk: int = 122006): + def continue_process_in_scheduler(self, scheduler_pk: int = 122744): """Ask the scheduler to pick up the process from the database and run it.""" from aiida_workgraph.utils.control import create_task_action From ed47f0b6391bf3679f4d4f310530bedbacbba6ad Mon Sep 17 00:00:00 2001 From: superstar54 Date: Tue, 27 Aug 2024 10:01:14 +0200 Subject: [PATCH 03/14] add scheduler cli, query scheduler automatically --- aiida_workgraph/cli/__init__.py | 3 +- aiida_workgraph/cli/cmd_scheduler.py | 125 +++++++++++++++++++++++++++ aiida_workgraph/engine/utils.py | 20 +++++ aiida_workgraph/workgraph.py | 9 +- pyproject.toml | 2 +- 5 files changed, 156 insertions(+), 3 deletions(-) create mode 100644 aiida_workgraph/cli/cmd_scheduler.py diff --git a/aiida_workgraph/cli/__init__.py b/aiida_workgraph/cli/__init__.py index db1f1f89..437e58d7 100644 --- a/aiida_workgraph/cli/__init__.py +++ b/aiida_workgraph/cli/__init__.py @@ -5,6 +5,7 @@ from aiida_workgraph.cli import cmd_graph from aiida_workgraph.cli import cmd_web from aiida_workgraph.cli import cmd_task +from aiida_workgraph.cli import cmd_scheduler -__all__ = ["cmd_graph", "cmd_web", "cmd_task"] +__all__ = ["cmd_graph", "cmd_web", "cmd_task", "cmd_scheduler"] diff --git a/aiida_workgraph/cli/cmd_scheduler.py b/aiida_workgraph/cli/cmd_scheduler.py new file mode 100644 index 00000000..2f00e8d5 --- /dev/null +++ b/aiida_workgraph/cli/cmd_scheduler.py @@ -0,0 +1,125 @@ +from aiida_workgraph.cli.cmd_workgraph import workgraph +from aiida import orm +import click +import os +from pathlib import Path +from aiida.cmdline.utils import echo +from .cmd_graph import REPAIR_INSTRUCTIONS + + +REACT_PORT = "3000" + + +def get_package_root(): + """Returns the root directory of the package.""" + current_file = Path(__file__) + # Root directory of your package + return current_file.parent + + +def get_pid_file_path(): + """Get the path to the PID file in the desired directory.""" + from aiida.manage.configuration.settings import AIIDA_CONFIG_FOLDER + + return AIIDA_CONFIG_FOLDER / "scheduler_processes.pid" + + +@workgraph.group("scheduler") +def scheduler(): + """Commands to manage the scheduler process.""" + + +@scheduler.command() +def start(): + """Start the scheduler application.""" + from aiida_workgraph.engine.scheduler import WorkGraphScheduler + from aiida.engine import submit + + click.echo("Starting the scheduler process...") + + pid_file_path = get_pid_file_path() + # if the PID file already exists, check if the process is running + if pid_file_path.exists(): + with open(pid_file_path, "r") as pid_file: + for line in pid_file: + _, pid = line.strip().split(":") + if pid: + try: + node = orm.load_node(pid) + if node.is_sealed: + click.echo( + "PID file exists but no running scheduler process found." + ) + else: + click.echo( + f"Scheduler process with PID {node.pk} already running." + ) + return + except Exception: + click.echo( + "PID file exists but no running scheduler process found." 
+ ) + + with open(pid_file_path, "w") as pid_file: + node = submit(WorkGraphScheduler) + pid_file.write(f"Scheduler:{node.pk}\n") + click.echo(f"Scheduler process started with PID {node.pk}.") + + +@scheduler.command() +def stop(): + """Stop the scheduler application.""" + from aiida.engine.processes import control + + pid_file_path = get_pid_file_path() + + if not pid_file_path.exists(): + click.echo("No running scheduler application found.") + return + + with open(pid_file_path, "r") as pid_file: + for line in pid_file: + _, pid = line.strip().split(":") + if pid: + click.confirm( + "Are you sure you want to kill the scheduler process?", abort=True + ) + process = orm.load_node(pid) + try: + message = "Killed through `verdi process kill`" + control.kill_processes( + [process], + timeout=5, + wait=True, + message=message, + ) + except control.ProcessTimeoutException as exception: + echo.echo_critical(f"{exception}\n{REPAIR_INSTRUCTIONS}") + os.remove(pid_file_path) + + +@scheduler.command() +def status(): + """Check the status of the scheduler application.""" + from aiida.orm import QueryBuilder + from aiida_workgraph.engine.scheduler import WorkGraphScheduler + + qb = QueryBuilder() + projections = ["id"] + filters = { + "or": [ + {"attributes.sealed": False}, + {"attributes": {"!has_key": "sealed"}}, + ] + } + qb.append( + WorkGraphScheduler, + filters=filters, + project=projections, + tag="process", + ) + results = qb.all() + if len(results) == 0: + click.echo("No scheduler found. Please start the scheduler first.") + else: + click.echo(f"Scheduler process is running with PID: {results[0][0]}") diff --git a/aiida_workgraph/engine/utils.py b/aiida_workgraph/engine/utils.py index 82332a78..e35e10d4 100644 --- a/aiida_workgraph/engine/utils.py +++ b/aiida_workgraph/engine/utils.py @@ -139,3 +139,23 @@ def prepare_for_shell_task(task: dict, kwargs: dict) -> dict: "metadata": metadata or {}, } return inputs + + +def get_scheduler(): + from aiida.orm import QueryBuilder + from aiida_workgraph.engine.scheduler import WorkGraphScheduler + + qb = QueryBuilder() + projections = ["id"] + filters = { + "or": [ + {"attributes.sealed": False}, + {"attributes": {"!has_key": "sealed"}}, + ] + } + qb.append(WorkGraphScheduler, filters=filters, project=projections, tag="process") + results = qb.all() + if len(results) == 0: + raise ValueError("No scheduler found. Please start the scheduler first.") + scheduler_id = results[0][0] + return scheduler_id diff --git a/aiida_workgraph/workgraph.py b/aiida_workgraph/workgraph.py index 88bcce5f..7b60ab60 100644 --- a/aiida_workgraph/workgraph.py +++ b/aiida_workgraph/workgraph.py @@ -115,6 +115,7 @@ def submit( wait: bool = False, timeout: int = 60, metadata: Optional[Dict[str, Any]] = None, + to_scheduler: bool = False, ) -> aiida.orm.ProcessNode: """Submit the AiiDA workgraph process and optionally wait for it to finish. Args: @@ -123,6 +124,8 @@ def submit( restart (bool): Restart the process, and reset the modified tasks, then only re-run the modified tasks. new (bool): Submit a new process. 
""" + from aiida_workgraph.engine.utils import get_scheduler + # set task inputs if inputs is not None: for name, input in inputs.items(): @@ -134,7 +137,11 @@ def submit( self.save(metadata=metadata) if self.process.process_state.value.upper() not in ["CREATED"]: raise ValueError(f"Process {self.process.pk} has already been submitted.") - self.continue_process() + if to_scheduler: + scheduler_pk = get_scheduler() + self.continue_process_in_scheduler(scheduler_pk) + else: + self.continue_process() # as long as we submit the process, it is a new submission, we should set restart_process to None self.restart_process = None if wait: diff --git a/pyproject.toml b/pyproject.toml index 134ca928..a13f5e48 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,7 +83,7 @@ workgraph = "aiida_workgraph.cli.cmd_workgraph:workgraph" [project.entry-points.'aiida.workflows'] "workgraph.engine" = "aiida_workgraph.engine.workgraph:WorkGraphEngine" -"workgraph.scheduler" = "aiida_workgraph.engine.workgraph:WorkGraphScheduler" +"workgraph.scheduler" = "aiida_workgraph.engine.scheduler:WorkGraphScheduler" [project.entry-points."aiida.data"] "workgraph.general" = "aiida_workgraph.orm.general_data:GeneralData" From 8d90643554b1f0bedf9ab554ba1cd9a5a57ea916 Mon Sep 17 00:00:00 2001 From: superstar54 Date: Tue, 27 Aug 2024 15:56:10 +0200 Subject: [PATCH 04/14] add start_worker for scheduler command --- aiida_workgraph/cli/cmd_scheduler.py | 10 ++++ aiida_workgraph/engine/launch.py | 72 ++++++++++++++++++++++++++-- aiida_workgraph/engine/scheduler.py | 3 +- 3 files changed, 80 insertions(+), 5 deletions(-) diff --git a/aiida_workgraph/cli/cmd_scheduler.py b/aiida_workgraph/cli/cmd_scheduler.py index 2f00e8d5..d265012e 100644 --- a/aiida_workgraph/cli/cmd_scheduler.py +++ b/aiida_workgraph/cli/cmd_scheduler.py @@ -29,6 +29,16 @@ def scheduler(): """Commands to manage the scheduler process.""" +@scheduler.command() +def start_worker(): + """Start the scheduler application.""" + from aiida_workgraph.engine.launch import start_scheduler_worker + + click.echo("Starting the scheduler worker...") + + start_scheduler_worker() + + @scheduler.command() def start(): """Start the scheduler application.""" diff --git a/aiida_workgraph/engine/launch.py b/aiida_workgraph/engine/launch.py index dfd3052c..02f21499 100644 --- a/aiida_workgraph/engine/launch.py +++ b/aiida_workgraph/engine/launch.py @@ -70,7 +70,10 @@ def run_get_node( raise ValueError( f"{function.__name__} does not support these kwargs: {kwargs.keys()}" ) - process = process_class(inputs=inputs, runner=runner, parent_pid=parent_pid) + if parent_pid: + process = process_class(inputs=inputs, runner=runner, parent_pid=parent_pid) + else: + process = process_class(inputs=inputs, runner=runner) # Only add handlers for interrupt signal to kill the process if we are in a local and not a daemon runner. # Without this check, running process functions in a daemon worker would be killed if the daemon is shutdown current_runner = manager.get_runner() @@ -109,7 +112,7 @@ def kill_process(_num, _frame): def instantiate_process( runner: "Runner", process: Union["Process", Type["Process"], "ProcessBuilder"], - _parent_pid, + _parent_pid=None, **inputs, ) -> "Process": """Return an instance of the process with the given inputs. 
The function can deal with various types @@ -143,7 +146,10 @@ def instantiate_process( f"invalid process {type(process)}, needs to be Process or ProcessBuilder" ) - process = process_class(runner=runner, inputs=inputs, parent_pid=_parent_pid) + if _parent_pid: + process = process_class(runner=runner, inputs=inputs, parent_pid=_parent_pid) + else: + process = process_class(runner=runner, inputs=inputs) return process @@ -172,6 +178,7 @@ def submit( :return: the calculation node of the process """ _parent_pid = kwargs.pop("_parent_pid", None) + runner = kwargs.pop("runner", None) inputs = prepare_inputs(inputs, **kwargs) # Submitting from within another process requires ``self.submit``` unless it is a work function, in which case the @@ -179,7 +186,8 @@ def submit( # if is_process_scoped() and not isinstance(Process.current(), FunctionProcess): # raise InvalidOperation('Cannot use top-level `submit` from within another process, use `self.submit` instead') - runner = manager.get_manager().get_runner() + if not runner: + runner = manager.get_manager().get_runner() assert runner.persister is not None, "runner does not have a persister" assert runner.controller is not None, "runner does not have a controller" @@ -215,3 +223,59 @@ def submit( time.sleep(wait_interval) return node + + +def start_scheduler_worker(foreground: bool = False) -> None: + """Start a scheduler worker for the currently configured profile. + + :param foreground: If true, the logging will be configured to write to stdout, otherwise it will be configured to + write to the scheduler log file. + """ + import asyncio + import signal + import sys + + from aiida.common.log import configure_logging + from aiida.engine.daemon.client import get_daemon_client + from aiida.manage import get_config_option, get_manager + from aiida_workgraph.engine.scheduler import WorkGraphScheduler + + daemon_client = get_daemon_client() + configure_logging( + daemon=not foreground, daemon_log_file=daemon_client.daemon_log_file + ) + + LOGGER.debug(f"sys.executable: {sys.executable}") + LOGGER.debug(f"sys.path: {sys.path}") + + try: + manager = get_manager() + # runner = manager.create_daemon_runner() + runner = manager.create_runner(broker_submit=True) + manager.set_runner(runner) + except Exception: + LOGGER.exception("daemon worker failed to start") + raise + + if isinstance(rlimit := get_config_option("daemon.recursion_limit"), int): + LOGGER.info("Setting maximum recursion limit of daemon worker to %s", rlimit) + sys.setrecursionlimit(rlimit) + + signals = (signal.SIGTERM, signal.SIGINT) + for s in signals: + # https://github.com/python/mypy/issues/12557 + runner.loop.add_signal_handler(s, lambda s=s: asyncio.create_task(shutdown_worker(runner))) # type: ignore[misc] + + print("runner", runner) + process_inited = instantiate_process(runner, WorkGraphScheduler) + runner.loop.create_task(process_inited.step_until_terminated()) + print("node", process_inited.node) + + try: + LOGGER.info("Starting a daemon worker") + runner.start() + except SystemError as exception: + LOGGER.info("Received a SystemError: %s", exception) + runner.close() + + LOGGER.info("Daemon worker started") diff --git a/aiida_workgraph/engine/scheduler.py b/aiida_workgraph/engine/scheduler.py index ba7aa798..fd854819 100644 --- a/aiida_workgraph/engine/scheduler.py +++ b/aiida_workgraph/engine/scheduler.py @@ -454,9 +454,10 @@ def _on_awaitable_finished(self, awaitable: Awaitable) -> None: # try to resume the workgraph, if the workgraph is already resumed # by other awaitable, this 
will not work try: + print("Resume scheduler.") self.resume() except Exception as e: - print(e) + print("Error in resume: ", e) def _build_process_label(self) -> str: """Use the workgraph name as the process label.""" From 86696753784b2311f981d09d90cbbcee8c92a27c Mon Sep 17 00:00:00 2001 From: superstar54 Date: Tue, 27 Aug 2024 17:43:01 +0200 Subject: [PATCH 05/14] wait for awaitable node's process_state to be ready, do not resume --- aiida_workgraph/engine/scheduler.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/aiida_workgraph/engine/scheduler.py b/aiida_workgraph/engine/scheduler.py index fd854819..3c29cec6 100644 --- a/aiida_workgraph/engine/scheduler.py +++ b/aiida_workgraph/engine/scheduler.py @@ -378,12 +378,19 @@ def _on_awaitable_finished(self, awaitable: Awaitable) -> None: :param awaitable: an Awaitable instance """ + import time + print(f"Awaitable {awaitable.key} finished.") self.logger.debug(f"Awaitable {awaitable.key} finished.") pk = awaitable.workgraph_pk node = load_node(awaitable.pk) - print("node: ", node) - print("state: ", node.process_state) + # there is a bug in aiida.engine.process, it send the msg before setting the process state and outputs + # so we need to wait for a while + # TODO make a PR to fix the aiida-core, the `super().on_entered()` should be + # called after setting the process state + tstart = time.time() + while not node.is_finished and time.time() - tstart < 5: + time.sleep(0.1) if isinstance(awaitable.pk, int): self.logger.info( @@ -453,11 +460,10 @@ def _on_awaitable_finished(self, awaitable: Awaitable) -> None: self.update_task_state(awaitable.workgraph_pk, awaitable.key) # try to resume the workgraph, if the workgraph is already resumed # by other awaitable, this will not work - try: - print("Resume scheduler.") - self.resume() - except Exception as e: - print("Error in resume: ", e) + # try: + # self.resume() + # except Exception as e: + # print(e) def _build_process_label(self) -> str: """Use the workgraph name as the process label.""" From 56ef5065049520969a2236b867b1397443483337 Mon Sep 17 00:00:00 2001 From: superstar54 Date: Wed, 28 Aug 2024 01:21:58 +0200 Subject: [PATCH 06/14] Add SchedulerClient, use circus to start scheduler --- aiida_workgraph/cli/cmd_scheduler.py | 227 +++++++++++------- aiida_workgraph/engine/scheduler/__init__.py | 3 + aiida_workgraph/engine/scheduler/client.py | 209 ++++++++++++++++ .../engine/{ => scheduler}/scheduler.py | 0 4 files changed, 347 insertions(+), 92 deletions(-) create mode 100644 aiida_workgraph/engine/scheduler/__init__.py create mode 100644 aiida_workgraph/engine/scheduler/client.py rename aiida_workgraph/engine/{ => scheduler}/scheduler.py (100%) diff --git a/aiida_workgraph/cli/cmd_scheduler.py b/aiida_workgraph/cli/cmd_scheduler.py index d265012e..556c1ba2 100644 --- a/aiida_workgraph/cli/cmd_scheduler.py +++ b/aiida_workgraph/cli/cmd_scheduler.py @@ -1,11 +1,10 @@ from aiida_workgraph.cli.cmd_workgraph import workgraph -from aiida import orm import click -import os from pathlib import Path -from aiida.cmdline.utils import echo -from .cmd_graph import REPAIR_INSTRUCTIONS - +from aiida.cmdline.utils import decorators, echo +from aiida.cmdline.params import options +from aiida_workgraph.engine.scheduler.client import get_scheduler_client +import sys REACT_PORT = "3000" @@ -29,107 +28,151 @@ def scheduler(): """Commands to manage the scheduler process.""" -@scheduler.command() -def start_worker(): - """Start the scheduler application.""" - from 
aiida_workgraph.engine.launch import start_scheduler_worker +# @scheduler.command() +# def worker(): +# """Start the scheduler application.""" +# from aiida_workgraph.engine.launch import start_scheduler_worker - click.echo("Starting the scheduler worker...") +# click.echo("Starting the scheduler worker...") - start_scheduler_worker() +# start_scheduler_worker() @scheduler.command() -def start(): +@click.option("--foreground", is_flag=True, help="Run in foreground.") +@options.TIMEOUT(default=None, required=False, type=int) +@decorators.with_dbenv() +@decorators.requires_broker +@decorators.check_circus_zmq_version +def start(foreground, timeout): """Start the scheduler application.""" - from aiida_workgraph.engine.scheduler import WorkGraphScheduler - from aiida.engine import submit click.echo("Starting the scheduler process...") - pid_file_path = get_pid_file_path() - # if the PID file already exists, check if the process is running - if pid_file_path.exists(): - with open(pid_file_path, "r") as pid_file: - for line in pid_file: - _, pid = line.strip().split(":") - if pid: - try: - node = orm.load_node(pid) - if node.is_sealed: - click.echo( - "PID file exists but no running scheduler process found." - ) - else: - click.echo( - f"Scheduler process with PID {node.pk} already running." - ) - return - except Exception: - click.echo( - "PID file exists but no running scheduler process found." - ) - - with open(pid_file_path, "w") as pid_file: - node = submit(WorkGraphScheduler) - pid_file.write(f"Scheduler:{node.pk}\n") - click.echo(f"Scheduler process started with PID {node.pk}.") + try: + client = get_scheduler_client() + client.start_daemon(foreground=foreground) + except Exception as exception: + echo.echo(f"Failed to start the scheduler process: {exception}") @scheduler.command() -def stop(): - """Stop the scheduler application.""" - from aiida.engine.processes import control - - pid_file_path = get_pid_file_path() - - if not pid_file_path.exists(): - click.echo("No running scheduler application found.") - return - - with open(pid_file_path, "r") as pid_file: - for line in pid_file: - _, pid = line.strip().split(":") - if pid: - click.confirm( - "Are you sure you want to kill the scheduler process?", abort=True - ) - process = orm.load_node(pid) - try: - message = "Killed through `verdi process kill`" - control.kill_processes( - [process], - timeout=5, - wait=True, - message=message, - ) - except control.ProcessTimeoutException as exception: - echo.echo_critical(f"{exception}\n{REPAIR_INSTRUCTIONS}") - os.remove(pid_file_path) +@click.option("--no-wait", is_flag=True, help="Do not wait for confirmation.") +@click.option("--all", "all_profiles", is_flag=True, help="Stop all daemons.") +@options.TIMEOUT(default=None, required=False, type=int) +@decorators.requires_broker +@click.pass_context +def stop(ctx, no_wait, all_profiles, timeout): + """Stop the scheduler daemon. + + Returns exit code 0 if the daemon was shut down successfully (or was not running), non-zero if there was an error. + """ + if all_profiles is True: + profiles = [ + profile + for profile in ctx.obj.config.profiles + if not profile.is_test_profile + ] + else: + profiles = [ctx.obj.profile] + + for profile in profiles: + echo.echo("Profile: ", fg=echo.COLORS["report"], bold=True, nl=False) + echo.echo(f"{profile.name}", bold=True) + echo.echo("Stopping the daemon... 
", nl=False) + try: + client = get_scheduler_client() + client.stop_daemon(wait=not no_wait, timeout=timeout) + except Exception as exception: + echo.echo_error(f"Failed to stop the daemon: {exception}") + + +@scheduler.command(hidden=True) +@click.option("--foreground", is_flag=True, help="Run in foreground.") +@decorators.with_dbenv() +@decorators.requires_broker +@decorators.check_circus_zmq_version +def start_circus(foreground): + """This will actually launch the circus daemon, either daemonized in the background or in the foreground. + + If run in the foreground all logs are redirected to stdout. + + .. note:: this should not be called directly from the commandline! + """ + + get_scheduler_client()._start_daemon(foreground=foreground) @scheduler.command() -def status(): - """Check the status of the scheduler application.""" - from aiida.orm import QueryBuilder - from aiida_workgraph.engine.scheduler import WorkGraphScheduler - - qb = QueryBuilder() - projections = ["id"] - filters = { - "or": [ - {"attributes.sealed": False}, - {"attributes": {"!has_key": "sealed"}}, +@click.option("--all", "all_profiles", is_flag=True, help="Show status of all daemons.") +@options.TIMEOUT(default=None, required=False, type=int) +@click.pass_context +@decorators.requires_loaded_profile() +@decorators.requires_broker +def status(ctx, all_profiles, timeout): + """Print the status of the scheduler daemon. + + Returns exit code 0 if all requested daemons are running, else exit code 3. + """ + from tabulate import tabulate + + from aiida.cmdline.utils.common import format_local_time + from aiida.engine.daemon.client import DaemonException + + if all_profiles is True: + profiles = [ + profile + for profile in ctx.obj.config.profiles + if not profile.is_test_profile ] - } - qb.append( - WorkGraphScheduler, - filters=filters, - project=projections, - tag="process", - ) - results = qb.all() - if len(results) == 0: - click.echo("No scheduler found. Please start the scheduler first.") else: - click.echo(f"Scheduler process is running with PID: {results[0][0]}") + profiles = [ctx.obj.profile] + + daemons_running = [] + + for profile in profiles: + client = get_scheduler_client(profile.name) + echo.echo("Profile: ", fg=echo.COLORS["report"], bold=True, nl=False) + echo.echo(f"{profile.name}", bold=True) + + try: + client.get_status(timeout=timeout) + except DaemonException as exception: + echo.echo_error(str(exception)) + daemons_running.append(False) + continue + + worker_response = client.get_worker_info() + daemon_response = client.get_daemon_info() + + workers = [] + for pid, info in worker_response["info"].items(): + if isinstance(info, dict): + row = [ + pid, + info["mem"], + info["cpu"], + format_local_time(info["create_time"]), + ] + else: + row = [pid, "-", "-", "-"] + workers.append(row) + + if workers: + workers_info = tabulate( + workers, headers=["PID", "MEM %", "CPU %", "started"], tablefmt="simple" + ) + else: + workers_info = ( + "--> No workers are running. 
Use `verdi daemon incr` to start some!\n" + ) + + start_time = format_local_time(daemon_response["info"]["create_time"]) + echo.echo( + f'Daemon is running as PID {daemon_response["info"]["pid"]} since {start_time}\n' + f"Active workers [{len(workers)}]:\n{workers_info}\n" + "Use `verdi daemon [incr | decr] [num]` to increase / decrease the number of workers" + ) + + if not all(daemons_running): + sys.exit(3) diff --git a/aiida_workgraph/engine/scheduler/__init__.py b/aiida_workgraph/engine/scheduler/__init__.py new file mode 100644 index 00000000..95a95abf --- /dev/null +++ b/aiida_workgraph/engine/scheduler/__init__.py @@ -0,0 +1,3 @@ +from .scheduler import WorkGraphScheduler + +__all__ = ("WorkGraphScheduler",) diff --git a/aiida_workgraph/engine/scheduler/client.py b/aiida_workgraph/engine/scheduler/client.py new file mode 100644 index 00000000..f6ad9593 --- /dev/null +++ b/aiida_workgraph/engine/scheduler/client.py @@ -0,0 +1,209 @@ +from aiida.engine.daemon.client import DaemonClient +import shutil +from aiida.manage.manager import get_manager +from aiida.common.exceptions import ConfigurationError +import os + +WORKGRAPH_BIN = shutil.which("workgraph") + + +class SchedulerClient(DaemonClient): + """Client for interacting with the scheduler daemon.""" + + _DAEMON_NAME = "scheduler-{name}" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @property + def _workgraph_bin(self) -> str: + """Return the absolute path to the ``verdi`` binary. + + :raises ConfigurationError: If the path to ``verdi`` could not be found + """ + if WORKGRAPH_BIN is None: + raise ConfigurationError( + "Unable to find 'verdi' in the path. Make sure that you are working " + "in a virtual environment, or that at least the 'verdi' executable is on the PATH" + ) + + return WORKGRAPH_BIN + + @property + def filepaths(self): + """Return the filepaths used by this profile. + + :return: a dictionary of filepaths + """ + from aiida.manage.configuration.settings import DAEMON_DIR, DAEMON_LOG_DIR + + return { + "circus": { + "log": str( + DAEMON_LOG_DIR / f"circus-scheduler-{self.profile.name}.log" + ), + "pid": str(DAEMON_DIR / f"circus-scheduler-{self.profile.name}.pid"), + "port": str(DAEMON_DIR / f"circus-scheduler-{self.profile.name}.port"), + "socket": { + "file": str( + DAEMON_DIR / f"circus-scheduler-{self.profile.name}.sockets" + ), + "controller": "circus.c.sock", + "pubsub": "circus.p.sock", + "stats": "circus.s.sock", + }, + }, + "daemon": { + "log": str(DAEMON_LOG_DIR / f"aiida-{self.profile.name}.log"), + "pid": str(DAEMON_DIR / f"aiida-{self.profile.name}.pid"), + }, + } + + @property + def circus_log_file(self) -> str: + return self.filepaths["circus"]["log"] + + @property + def circus_pid_file(self) -> str: + return self.filepaths["circus"]["pid"] + + @property + def circus_port_file(self) -> str: + return self.filepaths["circus"]["port"] + + @property + def circus_socket_file(self) -> str: + return self.filepaths["circus"]["socket"]["file"] + + @property + def circus_socket_endpoints(self) -> dict[str, str]: + return self.filepaths["circus"]["socket"] + + @property + def daemon_log_file(self) -> str: + return self.filepaths["daemon"]["log"] + + @property + def daemon_pid_file(self) -> str: + return self.filepaths["daemon"]["pid"] + + def cmd_start_daemon( + self, number_workers: int = 1, foreground: bool = False + ) -> list[str]: + """Return the command to start the daemon. + + :param number_workers: Number of daemon workers to start. 
+ :param foreground: Whether to launch the subprocess in the background or not. + """ + command = [ + self._workgraph_bin, + "-p", + self.profile.name, + "scheduler", + "start-circus", + ] + + if foreground: + command.append("--foreground") + + return command + + @property + def cmd_start_daemon_worker(self) -> list[str]: + """Return the command to start a daemon worker process.""" + return [self._workgraph_bin, "-p", self.profile.name, "scheduler", "worker"] + + def _start_daemon(self, foreground: bool = False) -> None: + """Start the daemon. + + .. warning:: This will daemonize the current process and put it in the background. It is most likely not what + you want to call if you want to start the daemon from the Python API. Instead you probably will want to use + the :meth:`aiida.engine.daemon.client.DaemonClient.start_daemon` function instead. + + :param number_workers: Number of daemon workers to start. + :param foreground: Whether to launch the subprocess in the background or not. + """ + from circus import get_arbiter + from circus import logger as circus_logger + from circus.circusd import daemonize + from circus.pidfile import Pidfile + from circus.util import check_future_exception_and_log, configure_logger + + loglevel = self.loglevel + logoutput = "-" + + if not foreground: + logoutput = self.circus_log_file + + arbiter_config = { + "controller": self.get_controller_endpoint(), + "pubsub_endpoint": self.get_pubsub_endpoint(), + "stats_endpoint": self.get_stats_endpoint(), + "logoutput": logoutput, + "loglevel": loglevel, + "debug": False, + "statsd": True, + "pidfile": self.circus_pid_file, + "watchers": [ + { + "cmd": " ".join(self.cmd_start_daemon_worker), + "name": self.daemon_name, + "numprocesses": 1, + "virtualenv": self.virtualenv, + "copy_env": True, + "stdout_stream": { + "class": "FileStream", + "filename": self.daemon_log_file, + }, + "stderr_stream": { + "class": "FileStream", + "filename": self.daemon_log_file, + }, + "env": self.get_env(), + } + ], + } + + if not foreground: + daemonize() + + arbiter = get_arbiter(**arbiter_config) + pidfile = Pidfile(arbiter.pidfile) + pidfile.create(os.getpid()) + + # Configure the logger + loggerconfig = arbiter.loggerconfig or None + configure_logger(circus_logger, loglevel, logoutput, loggerconfig) + + # Main loop + should_restart = True + + while should_restart: + try: + future = arbiter.start() + should_restart = False + if check_future_exception_and_log(future) is None: + should_restart = arbiter._restarting + except Exception as exception: + # Emergency stop + arbiter.loop.run_sync(arbiter._emergency_stop) + raise exception + except KeyboardInterrupt: + pass + finally: + arbiter = None + if pidfile is not None: + pidfile.unlink() + + +def get_scheduler_client(profile_name: str | None = None) -> "SchedulerClient": + """Return the daemon client for the given profile or the current profile if not specified. + + :param profile_name: Optional profile name to load. + :return: The daemon client. + + :raises aiida.common.MissingConfigurationError: if the configuration file cannot be found. + :raises aiida.common.ProfileConfigurationError: if the given profile does not exist. 
+ """ + profile = get_manager().load_profile(profile_name) + return SchedulerClient(profile) diff --git a/aiida_workgraph/engine/scheduler.py b/aiida_workgraph/engine/scheduler/scheduler.py similarity index 100% rename from aiida_workgraph/engine/scheduler.py rename to aiida_workgraph/engine/scheduler/scheduler.py From 7916e2716a40ba10be13e790b7ab063bb52bae02 Mon Sep 17 00:00:00 2001 From: superstar54 Date: Wed, 28 Aug 2024 15:11:42 +0200 Subject: [PATCH 07/14] Continue Scheduler process when start the scheduler daemon --- aiida_workgraph/cli/cmd_scheduler.py | 19 +++++------ aiida_workgraph/engine/launch.py | 31 ++++++++++++++---- aiida_workgraph/engine/scheduler/client.py | 20 ++++++++++++ aiida_workgraph/engine/scheduler/scheduler.py | 1 - aiida_workgraph/engine/utils.py | 20 ------------ aiida_workgraph/workgraph.py | 32 ++++++++++++++++--- pyproject.toml | 2 +- 7 files changed, 81 insertions(+), 44 deletions(-) diff --git a/aiida_workgraph/cli/cmd_scheduler.py b/aiida_workgraph/cli/cmd_scheduler.py index 556c1ba2..d8c52d09 100644 --- a/aiida_workgraph/cli/cmd_scheduler.py +++ b/aiida_workgraph/cli/cmd_scheduler.py @@ -28,14 +28,14 @@ def scheduler(): """Commands to manage the scheduler process.""" -# @scheduler.command() -# def worker(): -# """Start the scheduler application.""" -# from aiida_workgraph.engine.launch import start_scheduler_worker +@scheduler.command() +def worker(): + """Start the scheduler application.""" + from aiida_workgraph.engine.launch import start_scheduler_worker -# click.echo("Starting the scheduler worker...") + click.echo("Starting the scheduler worker...") -# start_scheduler_worker() + start_scheduler_worker() @scheduler.command() @@ -49,11 +49,8 @@ def start(foreground, timeout): click.echo("Starting the scheduler process...") - try: - client = get_scheduler_client() - client.start_daemon(foreground=foreground) - except Exception as exception: - echo.echo(f"Failed to start the scheduler process: {exception}") + client = get_scheduler_client() + client.start_daemon(foreground=foreground) @scheduler.command() diff --git a/aiida_workgraph/engine/launch.py b/aiida_workgraph/engine/launch.py index 02f21499..441c9204 100644 --- a/aiida_workgraph/engine/launch.py +++ b/aiida_workgraph/engine/launch.py @@ -236,11 +236,17 @@ def start_scheduler_worker(foreground: bool = False) -> None: import sys from aiida.common.log import configure_logging - from aiida.engine.daemon.client import get_daemon_client from aiida.manage import get_config_option, get_manager from aiida_workgraph.engine.scheduler import WorkGraphScheduler + from aiida_workgraph.engine.scheduler.client import ( + get_scheduler_client, + get_scheduler, + ) + from aiida.engine.processes.launcher import ProcessLauncher + from aiida.engine import persistence + from plumpy.persistence import LoadSaveContext - daemon_client = get_daemon_client() + daemon_client = get_scheduler_client() configure_logging( daemon=not foreground, daemon_log_file=daemon_client.daemon_log_file ) @@ -266,10 +272,23 @@ def start_scheduler_worker(foreground: bool = False) -> None: # https://github.com/python/mypy/issues/12557 runner.loop.add_signal_handler(s, lambda s=s: asyncio.create_task(shutdown_worker(runner))) # type: ignore[misc] - print("runner", runner) - process_inited = instantiate_process(runner, WorkGraphScheduler) - runner.loop.create_task(process_inited.step_until_terminated()) - print("node", process_inited.node) + try: + running_scheduler = get_scheduler() + runner_loop = runner.loop + task_receiver = 
ProcessLauncher( + loop=runner_loop, + persister=manager.get_persister(), + load_context=LoadSaveContext(runner=runner), + loader=persistence.get_object_loader(), + ) + asyncio.run( + task_receiver._continue( + communicator=None, pid=running_scheduler, nowait=True + ) + ) + except ValueError: + process_inited = instantiate_process(runner, WorkGraphScheduler) + runner.loop.create_task(process_inited.step_until_terminated()) try: LOGGER.info("Starting a daemon worker") diff --git a/aiida_workgraph/engine/scheduler/client.py b/aiida_workgraph/engine/scheduler/client.py index f6ad9593..e32dc4cb 100644 --- a/aiida_workgraph/engine/scheduler/client.py +++ b/aiida_workgraph/engine/scheduler/client.py @@ -207,3 +207,23 @@ def get_scheduler_client(profile_name: str | None = None) -> "SchedulerClient": """ profile = get_manager().load_profile(profile_name) return SchedulerClient(profile) + + +def get_scheduler(): + from aiida.orm import QueryBuilder + from aiida_workgraph.engine.scheduler import WorkGraphScheduler + + qb = QueryBuilder() + projections = ["id"] + filters = { + "or": [ + {"attributes.sealed": False}, + {"attributes": {"!has_key": "sealed"}}, + ] + } + qb.append(WorkGraphScheduler, filters=filters, project=projections, tag="process") + results = qb.all() + if len(results) == 0: + raise ValueError("No scheduler found. Please start the scheduler first.") + scheduler_id = results[0][0] + return scheduler_id diff --git a/aiida_workgraph/engine/scheduler/scheduler.py b/aiida_workgraph/engine/scheduler/scheduler.py index 3c29cec6..7ae1f55d 100644 --- a/aiida_workgraph/engine/scheduler/scheduler.py +++ b/aiida_workgraph/engine/scheduler/scheduler.py @@ -160,7 +160,6 @@ def load_instance_state( self._update_process_status() self.resume() # For other awaitables, because they exist in the db, we only need to re-register the callbacks - self.ctx._workgraph[pk]["_awaitable_actions"] = [] self._action_awaitables() def _resolve_nested_context(self, key: str) -> tuple[AttributeDict, str]: diff --git a/aiida_workgraph/engine/utils.py b/aiida_workgraph/engine/utils.py index e35e10d4..82332a78 100644 --- a/aiida_workgraph/engine/utils.py +++ b/aiida_workgraph/engine/utils.py @@ -139,23 +139,3 @@ def prepare_for_shell_task(task: dict, kwargs: dict) -> dict: "metadata": metadata or {}, } return inputs - - -def get_scheduler(): - from aiida.orm import QueryBuilder - from aiida_workgraph.engine.scheduler import WorkGraphScheduler - - qb = QueryBuilder() - projections = ["id"] - filters = { - "or": [ - {"attributes.sealed": False}, - {"attributes": {"!has_key": "sealed"}}, - ] - } - qb.append(WorkGraphScheduler, filters=filters, project=projections, tag="process") - results = qb.all() - if len(results) == 0: - raise ValueError("No scheduler found. Please start the scheduler first.") - scheduler_id = results[0][0] - return scheduler_id diff --git a/aiida_workgraph/workgraph.py b/aiida_workgraph/workgraph.py index 7b60ab60..95212640 100644 --- a/aiida_workgraph/workgraph.py +++ b/aiida_workgraph/workgraph.py @@ -124,7 +124,14 @@ def submit( restart (bool): Restart the process, and reset the modified tasks, then only re-run the modified tasks. new (bool): Submit a new process. 
""" - from aiida_workgraph.engine.utils import get_scheduler + from aiida_workgraph.engine.scheduler.client import get_scheduler + + if to_scheduler: + try: + get_scheduler() + except ValueError as e: + print(e) + return # set task inputs if inputs is not None: @@ -138,8 +145,7 @@ def submit( if self.process.process_state.value.upper() not in ["CREATED"]: raise ValueError(f"Process {self.process.pk} has already been submitted.") if to_scheduler: - scheduler_pk = get_scheduler() - self.continue_process_in_scheduler(scheduler_pk) + self.continue_process_in_scheduler() else: self.continue_process() # as long as we submit the process, it is a new submission, we should set restart_process to None @@ -422,11 +428,27 @@ def continue_process(self): process_controller = get_manager().get_process_controller() process_controller.continue_process(self.pk) - def continue_process_in_scheduler(self, scheduler_pk: int = 122744): + def continue_process_in_scheduler(self): """Ask the scheduler to pick up the process from the database and run it.""" from aiida_workgraph.utils.control import create_task_action + from aiida_workgraph.engine.scheduler.client import get_scheduler + import kiwipy - create_task_action(scheduler_pk, [self.pk], action="launch_workgraph") + try: + scheduler_pk = get_scheduler() + create_task_action(scheduler_pk, [self.pk], action="launch_workgraph") + except ValueError: + print( + """Scheduler is not running. +Please start the scheduler first with `aiida-workgraph scheduler start`""" + ) + except kiwipy.exceptions.UnroutableError: + print( + """Scheduler exists, but the daemon is not running. +Please start the scheduler first with `aiida-workgraph scheduler start`""" + ) + except Exception as e: + print("""An unexpected error occurred:""", e) def play(self): import os diff --git a/pyproject.toml b/pyproject.toml index a13f5e48..4f76b8aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,7 +83,7 @@ workgraph = "aiida_workgraph.cli.cmd_workgraph:workgraph" [project.entry-points.'aiida.workflows'] "workgraph.engine" = "aiida_workgraph.engine.workgraph:WorkGraphEngine" -"workgraph.scheduler" = "aiida_workgraph.engine.scheduler:WorkGraphScheduler" +"workgraph.scheduler" = "aiida_workgraph.engine.scheduler.scheduler:WorkGraphScheduler" [project.entry-points."aiida.data"] "workgraph.general" = "aiida_workgraph.orm.general_data:GeneralData" From c36d72b747181a06d9f28027158b0ed0e948eba6 Mon Sep 17 00:00:00 2001 From: superstar54 Date: Wed, 28 Aug 2024 18:05:50 +0200 Subject: [PATCH 08/14] Add checkpoint and docs --- aiida_workgraph/engine/scheduler/client.py | 7 +- aiida_workgraph/engine/scheduler/scheduler.py | 23 ++ docs/source/howto/scheduler.ipynb | 300 ++++++++++++++++++ 3 files changed, 327 insertions(+), 3 deletions(-) create mode 100644 docs/source/howto/scheduler.ipynb diff --git a/aiida_workgraph/engine/scheduler/client.py b/aiida_workgraph/engine/scheduler/client.py index e32dc4cb..1774eb64 100644 --- a/aiida_workgraph/engine/scheduler/client.py +++ b/aiida_workgraph/engine/scheduler/client.py @@ -3,6 +3,7 @@ from aiida.manage.manager import get_manager from aiida.common.exceptions import ConfigurationError import os +from typing import Optional WORKGRAPH_BIN = shutil.which("workgraph") @@ -54,8 +55,8 @@ def filepaths(self): }, }, "daemon": { - "log": str(DAEMON_LOG_DIR / f"aiida-{self.profile.name}.log"), - "pid": str(DAEMON_DIR / f"aiida-{self.profile.name}.pid"), + "log": str(DAEMON_LOG_DIR / f"aiida-scheduler-{self.profile.name}.log"), + "pid": str(DAEMON_DIR / 
f"aiida-scheduler-{self.profile.name}.pid"), }, } @@ -196,7 +197,7 @@ def _start_daemon(self, foreground: bool = False) -> None: pidfile.unlink() -def get_scheduler_client(profile_name: str | None = None) -> "SchedulerClient": +def get_scheduler_client(profile_name: Optional[str] = None) -> "SchedulerClient": """Return the daemon client for the given profile or the current profile if not specified. :param profile_name: Optional profile name to load. diff --git a/aiida_workgraph/engine/scheduler/scheduler.py b/aiida_workgraph/engine/scheduler/scheduler.py index 7ae1f55d..e18cb225 100644 --- a/aiida_workgraph/engine/scheduler/scheduler.py +++ b/aiida_workgraph/engine/scheduler/scheduler.py @@ -161,6 +161,18 @@ def load_instance_state( self.resume() # For other awaitables, because they exist in the db, we only need to re-register the callbacks self._action_awaitables() + # load checkpoint + launched_workgraphs = self.node.base.extras.get("_launched_workgraphs", []) + for pk in launched_workgraphs: + print("load workgraph: ", pk) + node = load_node(pk) + wgdata = node.base.extras.get("_workgraph", None) + if wgdata is None: + self.launch_workgraph(pk) + else: + self.ctx._workgraph[pk] = deserialize_unsafe(wgdata) + print("continue workgraph: ", pk) + self.continue_workgraph(pk) def _resolve_nested_context(self, key: str) -> tuple[AttributeDict, str]: """ @@ -478,6 +490,7 @@ def setup(self) -> None: """Setup the variables in the context.""" # track if the awaitable callback is added to the runner + self.ctx.launched_workgraphs = [] self.ctx._workgraph = {} self.ctx._max_number_awaitables = 10000 awaitable = Awaitable( @@ -501,6 +514,9 @@ def launch_workgraph(self, pk: str) -> None: """Launch the workgraph.""" # create the workgraph process self.report(f"Launch workgraph: {pk}") + # append the pk to the self.node.base.extras + self.ctx.launched_workgraphs.append(pk) + self.node.base.extras.set("_launched_workgraphs", self.ctx.launched_workgraphs) self.init_ctx_workgraph(pk) self.ctx._workgraph[pk]["_node"].set_process_state(Running.LABEL) self.init_task_results(pk) @@ -795,6 +811,7 @@ def update_task_state( task.setdefault("results", None) self.update_parent_task_state(pk, name) + self.save_workgraph_checkpoint(pk) if continue_workgraph: self.continue_workgraph(pk) @@ -817,6 +834,12 @@ def set_normal_task_results(self, pk, name, results): self.report(f"Workgraph: {pk}, Task: {name} finished.") self.update_parent_task_state(pk, name) + def save_workgraph_checkpoint(self, pk: int): + """Save the workgraph checkpoint.""" + self.ctx._workgraph[pk]["_node"].set_extra( + "_checkpoint", serialize(self.ctx._workgraph[pk]) + ) + def update_parent_task_state(self, pk, name: str) -> None: """Update parent task state.""" parent_task = self.ctx._workgraph[pk]["_tasks"][name]["parent_task"] diff --git a/docs/source/howto/scheduler.ipynb b/docs/source/howto/scheduler.ipynb new file mode 100644 index 00000000..ad046289 --- /dev/null +++ b/docs/source/howto/scheduler.ipynb @@ -0,0 +1,300 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Scheduler\n", + "\n", + "Start a scheduler daemon:\n", + "\n", + "```console\n", + "workgraph scheduler start\n", + "```\n", + "\n", + "Check the status of the scheduler:\n", + "\n", + "```console\n", + "workgraph scheduler status\n", + "```\n", + "\n", + "Stop the scheduler:\n", + "\n", + "```console\n", + "workgraph scheduler stop\n", + "```\n", + "\n", + "## Submit workgraph to the scheduler\n", + "Set `to_scheduler` to `True` when 
submitting a workgraph to the scheduler:\n", + "\n", + "```python\n", + "wg.submit(to_scheduler=True)\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "WorkGraph process created, PK: 134971\n", + "State of WorkGraph : FINISHED\n", + "Result of add2 : 4\n" + ] + } + ], + "source": [ + "from aiida_workgraph import WorkGraph, task\n", + "from aiida.calculations.arithmetic.add import ArithmeticAddCalculation\n", + "from aiida import load_profile, orm\n", + "\n", + "load_profile()\n", + "\n", + "@task.calcfunction()\n", + "def add(x, y):\n", + " return x + y\n", + "\n", + "code = orm.load_code(\"add@localhost\")\n", + "\n", + "wg = WorkGraph(f\"test_scheduler\")\n", + "add1 = wg.add_task(ArithmeticAddCalculation, name=\"add1\", x=1, y=2, code=code)\n", + "add2 = wg.add_task(ArithmeticAddCalculation, name=\"add2\", x=1, y=add1.outputs[\"sum\"], code=code)\n", + "wg.submit(to_scheduler=True,\n", + " wait=True)\n", + "print(\"State of WorkGraph : {}\".format(wg.state))\n", + "print('Result of add2 : {}'.format(wg.tasks[\"add2\"].node.outputs.sum.value))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Generate node graph from the AiiDA process,and we can see the provenance graph of the workgraph:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "N134971\n", + "\n", + "WorkGraph<test_scheduler> (134971)\n", + "State: finished\n", + "Exit Code: 0\n", + "\n", + "\n", + "\n", + "N134974\n", + "\n", + "ArithmeticAddCalculation (134974)\n", + "State: finished\n", + "Exit Code: 0\n", + "\n", + "\n", + "\n", + "N134971->N134974\n", + "\n", + "\n", + "CALL_CALC\n", + "add1\n", + "\n", + "\n", + "\n", + "N134979\n", + "\n", + "ArithmeticAddCalculation (134979)\n", + "State: finished\n", + "Exit Code: 0\n", + "\n", + "\n", + "\n", + "N134971->N134979\n", + "\n", + "\n", + "CALL_CALC\n", + "add2\n", + "\n", + "\n", + "\n", + "N37\n", + "\n", + "InstalledCode (37)\n", + "add@localhost\n", + "\n", + "\n", + "\n", + "N37->N134971\n", + "\n", + "\n", + "INPUT_WORK\n", + "wg__tasks__add2__properties__code__value\n", + "\n", + "\n", + "\n", + "N37->N134971\n", + "\n", + "\n", + "INPUT_WORK\n", + "wg__tasks__add1__properties__code__value\n", + "\n", + "\n", + "\n", + "N134975\n", + "\n", + "RemoteData (134975)\n", + "@localhost\n", + "\n", + "\n", + "\n", + "N134974->N134975\n", + "\n", + "\n", + "CREATE\n", + "remote_folder\n", + "\n", + "\n", + "\n", + "N134976\n", + "\n", + "FolderData (134976)\n", + "\n", + "\n", + "\n", + "N134974->N134976\n", + "\n", + "\n", + "CREATE\n", + "retrieved\n", + "\n", + "\n", + "\n", + "N134977\n", + "\n", + "Int (134977)\n", + "\n", + "\n", + "\n", + "N134974->N134977\n", + "\n", + "\n", + "CREATE\n", + "sum\n", + "\n", + "\n", + "\n", + "N134977->N134979\n", + "\n", + "\n", + "INPUT_CALC\n", + "y\n", + "\n", + "\n", + "\n", + "N134980\n", + "\n", + "RemoteData (134980)\n", + "@localhost\n", + "\n", + "\n", + "\n", + "N134979->N134980\n", + "\n", + "\n", + "CREATE\n", + "remote_folder\n", + "\n", + "\n", + "\n", + "N134981\n", + "\n", + "FolderData (134981)\n", + "\n", + "\n", + "\n", + "N134979->N134981\n", + "\n", + "\n", + "CREATE\n", + "retrieved\n", + "\n", + "\n", + "\n", + "N134982\n", + "\n", + "Int (134982)\n", + "\n", + "\n", + "\n", + "N134979->N134982\n", + "\n", + "\n", 
+ "CREATE\n", + "sum\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from aiida_workgraph.utils import generate_node_graph\n", + "generate_node_graph(wg.pk)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Conclusion" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "aiida", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.0" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 9348801d5ed08124f5550a6860ec3565d5da5ad6 Mon Sep 17 00:00:00 2001 From: superstar54 Date: Sun, 1 Sep 2024 07:43:01 +0200 Subject: [PATCH 09/14] - Delete a workgraph data when it is finished - Pickup the Scheduler process instead of launching a new one. - submit a workgraph inside the scheduler - Move report from Scheduler process to the workgraph process --- aiida_workgraph/engine/launch.py | 69 ++------- aiida_workgraph/engine/scheduler/scheduler.py | 140 ++++++++++-------- aiida_workgraph/engine/utils.py | 51 +++++++ aiida_workgraph/engine/workgraph.py | 5 +- aiida_workgraph/tasks/test.py | 17 +++ aiida_workgraph/workgraph.py | 10 +- docs/source/howto/index.rst | 1 + tests/conftest.py | 40 +++++ tests/test_scheduler.py | 22 +++ 9 files changed, 231 insertions(+), 124 deletions(-) create mode 100644 tests/test_scheduler.py diff --git a/aiida_workgraph/engine/launch.py b/aiida_workgraph/engine/launch.py index 441c9204..96324c75 100644 --- a/aiida_workgraph/engine/launch.py +++ b/aiida_workgraph/engine/launch.py @@ -11,16 +11,11 @@ from aiida.engine.processes.builder import ProcessBuilder from aiida.engine.processes.functions import get_stack_size from aiida.engine.processes.process import Process -from aiida.engine.utils import prepare_inputs, is_process_function +from aiida.engine.utils import prepare_inputs +from .utils import instantiate_process import signal import sys -import inspect -from typing import ( - Type, - Union, -) - from aiida.manage import get_manager @@ -40,7 +35,7 @@ def run_get_node( :param kwargs: input keyword arguments to construct the FunctionProcess :return: tuple of the outputs of the process and the process node """ - parent_pid = kwargs.pop("_parent_pid", None) + parent_pid = kwargs.pop("parent_pid", None) frame_delta = 1000 frame_count = get_stack_size() stack_limit = sys.getrecursionlimit() @@ -70,10 +65,7 @@ def run_get_node( raise ValueError( f"{function.__name__} does not support these kwargs: {kwargs.keys()}" ) - if parent_pid: - process = process_class(inputs=inputs, runner=runner, parent_pid=parent_pid) - else: - process = process_class(inputs=inputs, runner=runner) + process = process_class(inputs=inputs, runner=runner, parent_pid=parent_pid) # Only add handlers for interrupt signal to kill the process if we are in a local and not a daemon runner. 
# Without this check, running process functions in a daemon worker would be killed if the daemon is shutdown current_runner = manager.get_runner() @@ -109,57 +101,14 @@ def kill_process(_num, _frame): return result, process.node -def instantiate_process( - runner: "Runner", - process: Union["Process", Type["Process"], "ProcessBuilder"], - _parent_pid=None, - **inputs, -) -> "Process": - """Return an instance of the process with the given inputs. The function can deal with various types - of the `process`: - - * Process instance: will simply return the instance - * ProcessBuilder instance: will instantiate the Process from the class and inputs defined within it - * Process class: will instantiate with the specified inputs - - If anything else is passed, a ValueError will be raised - - :param process: Process instance or class, CalcJobNode class or ProcessBuilder instance - :param inputs: the inputs for the process to be instantiated with - """ - - if isinstance(process, Process): - assert not inputs - assert runner is process.runner - return process - - if isinstance(process, ProcessBuilder): - builder = process - process_class = builder.process_class - inputs.update(**builder._inputs(prune=True)) - elif is_process_function(process): - process_class = process.process_class # type: ignore[attr-defined] - elif inspect.isclass(process) and issubclass(process, Process): - process_class = process - else: - raise ValueError( - f"invalid process {type(process)}, needs to be Process or ProcessBuilder" - ) - - if _parent_pid: - process = process_class(runner=runner, inputs=inputs, parent_pid=_parent_pid) - else: - process = process_class(runner=runner, inputs=inputs) - - return process - - def submit( process: TYPE_SUBMIT_PROCESS, inputs: dict[str, t.Any] | None = None, *, wait: bool = False, wait_interval: int = 5, + parent_pid: int | None = None, + runner: "Runner" | None = None, **kwargs: t.Any, ) -> ProcessNode: """Submit the process with the supplied inputs to the daemon immediately returning control to the interpreter. @@ -177,8 +126,6 @@ def submit( :param kwargs: inputs to be passed to the process. This is an alternative to the positional ``inputs`` argument. :return: the calculation node of the process """ - _parent_pid = kwargs.pop("_parent_pid", None) - runner = kwargs.pop("runner", None) inputs = prepare_inputs(inputs, **kwargs) # Submitting from within another process requires ``self.submit``` unless it is a work function, in which case the @@ -192,7 +139,7 @@ def submit( assert runner.controller is not None, "runner does not have a controller" process_inited = instantiate_process( - runner, process, _parent_pid=_parent_pid, **inputs + runner, process, parent_pid=parent_pid, **inputs ) # If a dry run is requested, simply forward to `run`, because it is not compatible with `submit`. 
We choose for this @@ -245,6 +192,7 @@ def start_scheduler_worker(foreground: bool = False) -> None: from aiida.engine.processes.launcher import ProcessLauncher from aiida.engine import persistence from plumpy.persistence import LoadSaveContext + from aiida.engine.daemon.worker import shutdown_worker daemon_client = get_scheduler_client() configure_logging( @@ -287,6 +235,7 @@ def start_scheduler_worker(foreground: bool = False) -> None: ) ) except ValueError: + print("Starting a new Scheduler") process_inited = instantiate_process(runner, WorkGraphScheduler) runner.loop.create_task(process_inited.step_until_terminated()) diff --git a/aiida_workgraph/engine/scheduler/scheduler.py b/aiida_workgraph/engine/scheduler/scheduler.py index e18cb225..b7cd02fd 100644 --- a/aiida_workgraph/engine/scheduler/scheduler.py +++ b/aiida_workgraph/engine/scheduler/scheduler.py @@ -33,6 +33,7 @@ from aiida_workgraph.task import Task from aiida_workgraph.utils import get_nested_dict, update_nested_dict from aiida_workgraph.executors.monitors import monitor +from aiida.common.log import LOG_LEVEL_REPORT if t.TYPE_CHECKING: from aiida.engine.runners import Runner # pylint: disable=unused-import @@ -58,6 +59,7 @@ def __init__( logger: logging.Logger | None = None, runner: "Runner" | None = None, enable_persistence: bool = True, + **kwargs: t.Any, ) -> None: """Construct a WorkGraph instance. @@ -68,7 +70,9 @@ def __init__( """ - super().__init__(inputs, logger, runner, enable_persistence=enable_persistence) + super().__init__( + inputs, logger, runner, enable_persistence=enable_persistence, **kwargs + ) self._awaitables: list[Awaitable] = [] self._context = AttributeDict() @@ -166,7 +170,7 @@ def load_instance_state( for pk in launched_workgraphs: print("load workgraph: ", pk) node = load_node(pk) - wgdata = node.base.extras.get("_workgraph", None) + wgdata = node.base.extras.get("_checkpoint", None) if wgdata is None: self.launch_workgraph(pk) else: @@ -526,7 +530,6 @@ def init_ctx_workgraph(self, pk: int) -> None: """Init the context from the workgraph data.""" from aiida_workgraph.utils import update_nested_dict - self.report(f"Init workgraph: {pk}") # read the latest workgraph data wgdata, node = self.read_wgdata_from_base(pk) self.ctx._workgraph[pk] = { @@ -548,7 +551,6 @@ def setup_ctx_workgraph(self, pk: int, wgdata: t.Dict[str, t.Any]) -> None: """setup the workgraph in the context.""" import cloudpickle as pickle - self.report(f"Setup workgraph: {pk}") self.ctx._workgraph[pk]["_tasks"] = wgdata.pop("tasks") self.ctx._workgraph[pk]["_links"] = wgdata.pop("links") self.ctx._workgraph[pk]["_connectivity"] = wgdata.pop("connectivity") @@ -556,6 +558,7 @@ def setup_ctx_workgraph(self, pk: int, wgdata: t.Dict[str, t.Any]) -> None: self.ctx._workgraph[pk]["_error_handlers"] = pickle.loads( wgdata.pop("error_handlers") ) + self.ctx._workgraph[pk]["_metadata"] = wgdata.pop("metadata") self.ctx._workgraph[pk]["_workgraph"] = wgdata self.ctx._workgraph[pk]["_awaitable_actions"] = [] @@ -720,16 +723,32 @@ def kill_task(self, pk, name: str) -> None: except Exception as e: self.logger.error(f"Error in killing task {name}: {e}") + def report(self, msg, pk=None): + """Report the message.""" + if pk: + self.ctx._workgraph[pk]["_node"].logger.log(LOG_LEVEL_REPORT, msg) + else: + super().report(msg) + def continue_workgraph(self, pk: int) -> None: + # if the workgraph is finished, skip + if pk not in self.ctx._workgraph: + # the workgraph is finished + return is_finished, _ = self.is_workgraph_finished(pk) if 
is_finished: - self.report(f"Workgraph {pk} finished.") + self.finalize_workgraph(pk) self.ctx._workgraph[pk]["_node"].set_process_state(Finished.LABEL) self.ctx._workgraph[pk]["_node"].set_exit_status(0) self.ctx._workgraph[pk]["_node"].seal() + # remove the workgraph from the context + del self.ctx._workgraph[pk] + self.ctx.launched_workgraphs.remove(pk) + self.node.base.extras.set( + "_launched_workgraphs", self.ctx.launched_workgraphs + ) return - self.report("Continue workgraph.") - # self.update_workgraph_from_base() + self.report("Continue.", pk) task_to_run = [] for name, task in self.ctx._workgraph[pk]["_tasks"].items(): # update task state @@ -749,11 +768,7 @@ def continue_workgraph(self, pk: int) -> None: if ready: task_to_run.append(name) # - self.report( - "tasks ready to run in WorkGraph {}, tasks: {}".format( - pk, ",".join(task_to_run) - ) - ) + self.report("tasks ready to run: {}".format(",".join(task_to_run)), pk) if len(task_to_run) > 0: self.run_tasks(pk, task_to_run) @@ -784,15 +799,15 @@ def update_task_state( task["results"][link.link_label] = link.node.outputs else: task["results"] = node.outputs - # self.ctx._new_data[name] = task["results"] + # self.ctx._workgraph[pk]["_new_data"][name] = task["results"] self.set_task_state_info(pk, task["name"], "state", "FINISHED") self.task_set_context(pk, name) - self.report(f"Workgraph: {pk}, Task: {name} finished.") + self.report(f"Task: {name} finished.", pk=pk) # all other states are considered as failed else: print(f"set task result: {name} failed") task["results"] = node.outputs - # self.ctx._new_data[name] = task["results"] + # self.ctx._workgraph[pk]["_new_data"][name] = task["results"] self.set_task_state_info(pk, task["name"], "state", "FAILED") # set child tasks state to SKIPPED self.set_tasks_state( @@ -800,20 +815,23 @@ def update_task_state( self.ctx._workgraph[pk]["_connectivity"]["child_node"][name], "SKIPPED", ) - self.report(f"Workgraph: {pk}, Task: {name} failed.") + self.report(f"Task: {name} failed.", pk) self.run_error_handlers(pk, name) elif isinstance(node, orm.Data): task["results"] = {task["outputs"][0]["name"]: node} self.set_task_state_info(pk, task["name"], "state", "FINISHED") self.task_set_context(pk, name) - self.report(f"Workgraph: {pk}, Task: {name} finished.") + self.report(f"Task: {name} finished.", pk) else: task.setdefault("results", None) self.update_parent_task_state(pk, name) self.save_workgraph_checkpoint(pk) if continue_workgraph: - self.continue_workgraph(pk) + try: + self.continue_workgraph(pk) + except Exception as e: + print(e) def set_normal_task_results(self, pk, name, results): """Set the results of a normal task. 
@@ -831,12 +849,12 @@ def set_normal_task_results(self, pk, name, results): task["results"][task["outputs"][0]["name"]] = results self.task_set_context(pk, name) self.set_task_state_info(pk, name, "state", "FINISHED") - self.report(f"Workgraph: {pk}, Task: {name} finished.") + self.report(f"Task: {name} finished.", pk=pk) self.update_parent_task_state(pk, name) def save_workgraph_checkpoint(self, pk: int): """Save the workgraph checkpoint.""" - self.ctx._workgraph[pk]["_node"].set_extra( + self.ctx._workgraph[pk]["_node"].base.extras.set( "_checkpoint", serialize(self.ctx._workgraph[pk]) ) @@ -848,19 +866,20 @@ def update_parent_task_state(self, pk, name: str) -> None: "node_type" ].upper() if task_type == "WHILE": - self.update_while_task_state(parent_task[0]) + self.update_while_task_state(pk, parent_task[0]) elif task_type == "IF": - self.update_zone_task_state(parent_task[0]) + self.update_zone_task_state(pk, parent_task[0]) elif task_type == "ZONE": - self.update_zone_task_state(parent_task[0]) + self.update_zone_task_state(pk, parent_task[0]) - def update_while_task_state(self, name: str) -> None: + def update_while_task_state(self, pk: int, name: str) -> None: """Update while task state.""" finished, _ = self.are_childen_finished(name) if finished: self.report( - f"Wihle Task {name}: this iteration finished. Try to reset for the next iteration." + f"Wihle Task {name}: this iteration finished. Try to reset for the next iteration.", + pk, ) # reset the condition tasks for input in self.ctx._workgraph[pk]["_tasks"][name]["inputs"]: @@ -871,13 +890,13 @@ def update_while_task_state(self, name: str) -> None: # do not reset the execution count self.reset_task(name, reset_execution_count=False) - def update_zone_task_state(self, name: str) -> None: + def update_zone_task_state(self, pk: int, name: str) -> None: """Update zone task state.""" finished, _ = self.are_childen_finished(name) if finished: self.set_task_state_info(pk, name, "state", "FINISHED") self.update_parent_task_state(pk, name) - self.report(f"Task: {name} finished.") + self.report(f"Task: {name} finished.", pk) def should_run_while_task(self, pk: int, name: str) -> tuple[bool, t.Any]: """Check if the while task should run.""" @@ -933,7 +952,7 @@ def run_error_handlers(self, pk: int, task_name: str) -> None: handler = data["handler"] metadata = data["tasks"][task_name] if node.exit_code.status in metadata.get("exit_codes", []): - self.report(f"Run error handler: {metadata}") + self.report(f"Run error handler: {metadata}", pk) metadata.setdefault("retry", 0) if metadata["retry"] < metadata["max_retries"]: handler(self, task_name, **metadata.get("kwargs", {})) @@ -967,7 +986,7 @@ def is_workgraph_finished(self, pk) -> bool: is_finished = not should_run if is_finished and len(failed_tasks) > 0: message = f"WorkGraph finished, but tasks: {failed_tasks} failed. Thus all their child tasks are skipped." - self.report(message) + self.report(message, pk) result = ExitCode(302, message) else: result = None @@ -977,12 +996,11 @@ def check_while_conditions(self, pk: int) -> bool: """Check while conditions. Run all condition tasks and check if all the conditions are True. 
""" - self.report("Check while conditions.") if ( self.ctx._workgraph[pk]["_execution_count"] >= self.ctx._workgraph[pk]["_max_iteration"] ): - self.report("Max iteration reached.") + self.report("Max iteration reached.", pk) return False condition_tasks = [] for c in self.ctx._workgraph[pk]["conditions"]: @@ -1084,7 +1102,7 @@ def run_tasks( self.ctx._workgraph[pk]["_executed_tasks"].append(name) print("-" * 60) - self.report(f"Run task: {name}, type: {task['metadata']['node_type']}") + self.report(f"Run task: {name}, type: {task['metadata']['node_type']}", pk) executor, _ = get_executor(task["executor"]) # print("executor: ", executor) args, kwargs, var_args, var_kwargs, args_dict = self.get_inputs(pk, name) @@ -1112,7 +1130,7 @@ def run_tasks( results = create_data_node(executor, args, kwargs) self.set_task_state_info(pk, name, "process", results) self.update_task_state(pk, name) - self.ctx._new_data[name] = results + self.ctx._workgraph[pk]["_new_data"][name] = results if continue_workgraph: self.continue_workgraph(pk) elif task["metadata"]["node_type"].upper() in [ @@ -1121,7 +1139,7 @@ def run_tasks( ]: kwargs.setdefault("metadata", {}) kwargs["metadata"].update({"call_link_label": name}) - kwargs["_parent_pid"] = pk + kwargs["parent_pid"] = pk try: # since aiida 2.5.0, we need to use args_dict to pass the args to the run_get_node if var_kwargs is None: @@ -1145,7 +1163,7 @@ def run_tasks( self.ctx._workgraph[pk]["_connectivity"]["child_node"][name], "SKIPPED", ) - self.report(f"Task: {name} failed.") + self.report(f"Task: {name} failed.", pk) # exclude the current tasks from the next run if continue_workgraph: self.continue_workgraph(pk) @@ -1153,11 +1171,11 @@ def run_tasks( # process = run_get_node(executor, *args, **kwargs) kwargs.setdefault("metadata", {}) kwargs["metadata"].update({"call_link_label": name}) - kwargs["_parent_pid"] = pk + kwargs["parent_pid"] = pk # transfer the args to kwargs if self.get_task_state_info(pk, name, "action").upper() == "PAUSE": self.set_task_state_info(pk, name, "action", "") - self.report(f"Task {name} is created and paused.") + self.report(f"Task {name} is created and paused.", pk) process = create_and_pause_process( self.runner, executor, @@ -1179,20 +1197,23 @@ def run_tasks( wg.group_outputs = self.ctx._workgraph[pk]["_tasks"][name]["metadata"][ "group_outputs" ] - wg.parent_uuid = self.node.uuid - inputs = wg.prepare_inputs(metadata={"call_link_label": name}) - inputs["parent_pid"] = pk - process = launch.submit(WorkGraphEngine, inputs=inputs) - process.workgraph_pk = pk - self.set_task_state_info(pk, name, "process", process) - self.set_task_state_info(pk, name, "state", "RUNNING") - self.to_context(**{name: process}) + metadata = {"call_link_label": name} + try: + wg.save(metadata=metadata, parent_pid=pk) + process = wg.process + self.launch_workgraph(process.pk) + process.workgraph_pk = pk + self.set_task_state_info(pk, name, "process", process) + self.set_task_state_info(pk, name, "state", "RUNNING") + self.report(f"GraphBuilder {name} is created.", pk) + self.to_context(**{name: process}) + except Exception as e: + print("Error in launching the workgraph: ", e) elif task["metadata"]["node_type"].upper() in ["WORKGRAPH"]: from .utils import prepare_for_workgraph_task inputs, _ = prepare_for_workgraph_task(task, kwargs) - inputs["parent_pid"] = pk - process = launch.submit(WorkGraphEngine, inputs=inputs) + process = launch.submit(WorkGraphEngine, inputs=inputs, parent_pid=pk) process.workgraph_pk = pk self.set_task_state_info(pk, name, 
"process", process) self.set_task_state_info(pk, name, "state", "RUNNING") @@ -1202,11 +1223,10 @@ def run_tasks( from .utils import prepare_for_python_task inputs = prepare_for_python_task(task, kwargs, var_kwargs) - inputs["parent_pid"] = pk # since aiida 2.5.0, we can pass inputs directly to the submit, no need to use **inputs if self.get_task_state_info(pk, name, "action").upper() == "PAUSE": self.set_task_state_info(pk, name, "action", "") - self.report(f"Task {name} is created and paused.") + self.report(f"Task {name} is created and paused.", pk) process = create_and_pause_process( self.runner, PythonJob, @@ -1216,7 +1236,7 @@ def run_tasks( self.set_task_state_info(pk, name, "state", "CREATED") process = process.node else: - process = launch.submit(PythonJob, **inputs) + process = launch.submit(PythonJob, inputs=inputs, parent_pid=pk) self.set_task_state_info(pk, name, "state", "RUNNING") process.label = name process.workgraph_pk = pk @@ -1227,10 +1247,9 @@ def run_tasks( from .utils import prepare_for_shell_task inputs = prepare_for_shell_task(task, kwargs) - inputs["parent_pid"] = pk if self.get_task_state_info(pk, name, "action").upper() == "PAUSE": self.set_task_state_info(pk, name, "action", "") - self.report(f"Task {name} is created and paused.") + self.report(f"Task {name} is created and paused.", pk) process = create_and_pause_process( self.runner, ShellJob, @@ -1240,7 +1259,7 @@ def run_tasks( self.set_task_state_info(pk, name, "state", "CREATED") process = process.node else: - process = launch.submit(ShellJob, **inputs) + process = launch.submit(ShellJob, inputs=inputs, parent_pid=pk) self.set_task_state_info(pk, name, "state", "RUNNING") process.label = name process.workgraph_pk = pk @@ -1258,7 +1277,8 @@ def run_tasks( ) self.update_parent_task_state(pk, name) self.report( - f"While Task {name}: Condition not fullilled, task finished. Skip all its children." + f"While Task {name}: Condition not fullilled, task finished. 
Skip all its children.", + pk, ) else: task["execution_count"] += 1 @@ -1270,7 +1290,7 @@ def run_tasks( self.set_task_state_info(pk, name, "state", "RUNNING") else: self.set_tasks_state(pk, task["children"], "SKIPPED") - self.update_zone_task_state(name) + self.update_zone_task_state(pk, name) self.continue_workgraph(pk) elif task["metadata"]["node_type"].upper() in ["ZONE"]: self.set_task_state_info(pk, name, "state", "RUNNING") @@ -1334,7 +1354,7 @@ def run_tasks( self.ctx._workgraph[pk]["_connectivity"]["child_node"][name], "SKIPPED", ) - self.report(f"Task: {name} failed.") + self.report(f"Task: {name} failed.", pk) self.run_error_handlers(pk, name) if continue_workgraph: self.continue_workgraph(pk) @@ -1602,11 +1622,11 @@ def message_receive( # Didn't match any known intents raise RuntimeError("Unknown intent") - def finalize(self) -> t.Optional[ExitCode]: + def finalize_workgraph(self, pk: int) -> t.Optional[ExitCode]: """""" # expose outputs of the workgraph group_outputs = {} - for output in self.ctx._workgraph["metadata"]["group_outputs"]: + for output in self.ctx._workgraph[pk]["_metadata"]["group_outputs"]: names = output["from"].split(".", 1) if names[0] == "context": if len(names) == 1: @@ -1640,12 +1660,12 @@ def finalize(self) -> t.Optional[ExitCode]: ) self.out_many(group_outputs) # output the new data - self.out("new_data", self.ctx._new_data) + self.out("new_data", self.ctx._workgraph[pk]["_new_data"]) self.out( "execution_count", orm.Int(self.ctx._workgraph[pk]["_execution_count"]).store(), ) - self.report("Finalize workgraph.") + self.report("Finalize workgraph.", pk) for _, task in self.ctx._workgraph[pk]["_tasks"].items(): if self.get_task_state_info(pk, task["name"], "state") == "FAILED": return self.exit_codes.TASK_FAILED diff --git a/aiida_workgraph/engine/utils.py b/aiida_workgraph/engine/utils.py index 82332a78..e35f4fc5 100644 --- a/aiida_workgraph/engine/utils.py +++ b/aiida_workgraph/engine/utils.py @@ -1,6 +1,15 @@ from aiida_workgraph.orm.serializer import serialize_to_aiida_nodes from aiida import orm from aiida.common.extendeddicts import AttributeDict +from aiida.engine.utils import is_process_function +from aiida.engine.processes.builder import ProcessBuilder +from aiida.engine.processes.process import Process + +import inspect +from typing import ( + Type, + Union, +) def prepare_for_workgraph_task(task: dict, kwargs: dict) -> tuple: @@ -139,3 +148,45 @@ def prepare_for_shell_task(task: dict, kwargs: dict) -> dict: "metadata": metadata or {}, } return inputs + + +def instantiate_process( + runner: "Runner", + process: Union["Process", Type["Process"], "ProcessBuilder"], + parent_pid=None, + **inputs, +) -> "Process": + """Return an instance of the process with the given inputs. 
The function can deal with various types + of the `process`: + + * Process instance: will simply return the instance + * ProcessBuilder instance: will instantiate the Process from the class and inputs defined within it + * Process class: will instantiate with the specified inputs + + If anything else is passed, a ValueError will be raised + + :param process: Process instance or class, CalcJobNode class or ProcessBuilder instance + :param inputs: the inputs for the process to be instantiated with + """ + + if isinstance(process, Process): + assert not inputs + assert runner is process.runner + return process + + if isinstance(process, ProcessBuilder): + builder = process + process_class = builder.process_class + inputs.update(**builder._inputs(prune=True)) + elif is_process_function(process): + process_class = process.process_class # type: ignore[attr-defined] + elif inspect.isclass(process) and issubclass(process, Process): + process_class = process + else: + raise ValueError( + f"invalid process {type(process)}, needs to be Process or ProcessBuilder" + ) + + process = process_class(runner=runner, inputs=inputs, parent_pid=parent_pid) + + return process diff --git a/aiida_workgraph/engine/workgraph.py b/aiida_workgraph/engine/workgraph.py index 1de8c292..b7e4ef1e 100644 --- a/aiida_workgraph/engine/workgraph.py +++ b/aiida_workgraph/engine/workgraph.py @@ -60,6 +60,7 @@ def __init__( logger: logging.Logger | None = None, runner: "Runner" | None = None, enable_persistence: bool = True, + **kwargs: t.Any, ) -> None: """Construct a WorkGraph instance. @@ -70,7 +71,9 @@ def __init__( """ - super().__init__(inputs, logger, runner, enable_persistence=enable_persistence) + super().__init__( + inputs, logger, runner, enable_persistence=enable_persistence, **kwargs + ) self._awaitables: list[Awaitable] = [] self._context = AttributeDict() diff --git a/aiida_workgraph/tasks/test.py b/aiida_workgraph/tasks/test.py index e0c2bfa4..78ab58f7 100644 --- a/aiida_workgraph/tasks/test.py +++ b/aiida_workgraph/tasks/test.py @@ -1,5 +1,6 @@ from typing import Dict from aiida_workgraph.task import Task +from aiida_workgraph.decorator import task class TestAdd(Task): @@ -114,3 +115,19 @@ def get_executor(self) -> Dict[str, str]: "name": "core.arithmetic.multiply_add", "type": "WorkflowFactory", } + + +@task.graph_builder() +def create_workgraph_recrusively(n): + from aiida_workgraph import WorkGraph + from aiida_workgraph.tasks.test import create_workgraph_recrusively + from aiida.calculations.arithmetic.add import ArithmeticAddCalculation + from aiida import orm + + code = orm.load_code("add@localhost") + wg = WorkGraph(f"n-{n}") + if n > 0: + wg.add_task(create_workgraph_recrusively, name=f"n_{n}", n=n - 1) + else: + wg.add_task(ArithmeticAddCalculation, code=code, x=1, y=2) + return wg diff --git a/aiida_workgraph/workgraph.py b/aiida_workgraph/workgraph.py index 95212640..341cdae0 100644 --- a/aiida_workgraph/workgraph.py +++ b/aiida_workgraph/workgraph.py @@ -154,20 +154,24 @@ def submit( self.wait(timeout=timeout) return self.process - def save(self, metadata: Optional[Dict[str, Any]] = None) -> None: + def save( + self, metadata: Optional[Dict[str, Any]] = None, parent_pid: int = None + ) -> None: """Save the udpated workgraph to the process This is only used for a running workgraph. Save the AiiDA workgraph process and update the process status. 
""" from aiida.manage import manager - from aiida.engine.utils import instantiate_process + from aiida_workgraph.engine.utils import instantiate_process from aiida_workgraph.engine.workgraph import WorkGraphEngine inputs = self.prepare_inputs(metadata) if self.process is None: runner = manager.get_manager().get_runner() # init a process node - process_inited = instantiate_process(runner, WorkGraphEngine, **inputs) + process_inited = instantiate_process( + runner, WorkGraphEngine, parent_pid=parent_pid, **inputs + ) process_inited.runner.persister.save_checkpoint(process_inited) self.process = process_inited.node self.process_inited = process_inited diff --git a/docs/source/howto/index.rst b/docs/source/howto/index.rst index 274f74be..3991edfe 100644 --- a/docs/source/howto/index.rst +++ b/docs/source/howto/index.rst @@ -24,5 +24,6 @@ This section contains a collection of HowTos for various topics. queue cli control + scheduler transfer_workchain workchain_call_workgraph diff --git a/tests/conftest.py b/tests/conftest.py index 36b1fe61..efb2d1f0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -25,6 +25,46 @@ def fixture_localhost(aiida_localhost): return localhost +@pytest.fixture(scope="session") +def scheduler_client(aiida_profile): + """Return a daemon client for the configured test profile for the test session. + + The daemon will be automatically stopped at the end of the test session. + """ + from aiida_workgraph.engine.scheduler.client import SchedulerClient + from aiida.engine.daemon.client import ( + DaemonNotRunningException, + DaemonTimeoutException, + ) + + scheduler_client = SchedulerClient(aiida_profile) + + try: + yield scheduler_client + finally: + try: + scheduler_client.stop_daemon(wait=True) + except DaemonNotRunningException: + pass + # Give an additional grace period by manually waiting for the daemon to be stopped. In certain unit test + # scenarios, the built in wait time in ``scheduler_client.stop_daemon`` is not sufficient and even though the + # daemon is stopped, ``scheduler_client.is_daemon_running`` will return false for a little bit longer. 
+ scheduler_client._await_condition( + lambda: not scheduler_client.is_daemon_running, + DaemonTimeoutException("The daemon failed to stop."), + ) + + +@pytest.fixture() +def started_scheduler_client(scheduler_client): + """Ensure that the daemon is running for the test profile and return the associated client.""" + if not scheduler_client.is_daemon_running: + scheduler_client.start_daemon() + assert scheduler_client.is_daemon_running + + yield scheduler_client + + @pytest.fixture def add_code(fixture_localhost): from aiida.orm import InstalledCode diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py new file mode 100644 index 00000000..f51a1320 --- /dev/null +++ b/tests/test_scheduler.py @@ -0,0 +1,22 @@ +import pytest +from typing import Callable +from aiida_workgraph import WorkGraph +from aiida.cmdline.utils.common import get_workchain_report +from aiida_workgraph.engine.scheduler.client import get_scheduler +from aiida import orm + + +@pytest.skip("Skip for now") +@pytest.mark.usefixtures("started_daemon_client") +def test_scheduler(decorated_add: Callable, started_scheduler_client) -> None: + """Test graph build.""" + wg = WorkGraph("test_scheduler") + add1 = wg.add_task(decorated_add, x=2, y=3) + add2 = wg.add_task(decorated_add, "add2", x=3, y=add1.outputs["result"]) + # use run to check if graph builder workgraph can be submit inside the engine + pk = get_scheduler() + wg.submit(to_scheduler=True, wait=True) + pk = get_scheduler() + report = get_workchain_report(orm.load(pk), "REPORT") + print("report: ", report) + assert add2.outputs["result"].value == 8 From 35aab9b9499251a6c104e9272beb21d56da7cd96 Mon Sep 17 00:00:00 2001 From: superstar54 Date: Sun, 1 Sep 2024 08:15:38 +0200 Subject: [PATCH 10/14] handle failed task --- aiida_workgraph/engine/scheduler/scheduler.py | 47 ++++++++++++------- 1 file changed, 29 insertions(+), 18 deletions(-) diff --git a/aiida_workgraph/engine/scheduler/scheduler.py b/aiida_workgraph/engine/scheduler/scheduler.py index b7cd02fd..ffadc38f 100644 --- a/aiida_workgraph/engine/scheduler/scheduler.py +++ b/aiida_workgraph/engine/scheduler/scheduler.py @@ -1172,25 +1172,36 @@ def run_tasks( kwargs.setdefault("metadata", {}) kwargs["metadata"].update({"call_link_label": name}) kwargs["parent_pid"] = pk - # transfer the args to kwargs - if self.get_task_state_info(pk, name, "action").upper() == "PAUSE": - self.set_task_state_info(pk, name, "action", "") - self.report(f"Task {name} is created and paused.", pk) - process = create_and_pause_process( - self.runner, - executor, - kwargs, - state_msg="Paused through WorkGraph", + try: + # transfer the args to kwargs + if self.get_task_state_info(pk, name, "action").upper() == "PAUSE": + self.set_task_state_info(pk, name, "action", "") + self.report(f"Task {name} is created and paused.", pk) + process = create_and_pause_process( + self.runner, + executor, + kwargs, + state_msg="Paused through WorkGraph", + ) + self.set_task_state_info(pk, name, "state", "CREATED") + process = process.node + else: + process = launch.submit(executor, **kwargs) + self.set_task_state_info(pk, name, "state", "RUNNING") + process.label = name + process.workgraph_pk = pk + self.set_task_state_info(pk, name, "process", process) + self.to_context(**{name: process}) + except Exception as e: + self.set_task_state_info(pk, name, "state", "FAILED") + # set child state to FAILED + self.set_tasks_state( + pk, + self.ctx._workgraph[pk]["_connectivity"]["child_node"][name], + "SKIPPED", ) - self.set_task_state_info(pk, name, "state", 
"CREATED") - process = process.node - else: - process = launch.submit(executor, **kwargs) - self.set_task_state_info(pk, name, "state", "RUNNING") - process.label = name - process.workgraph_pk = pk - self.set_task_state_info(pk, name, "process", process) - self.to_context(**{name: process}) + self.report(f"Error in task {name}: {e}", pk) + self.continue_workgraph(pk) elif task["metadata"]["node_type"].upper() in ["GRAPH_BUILDER"]: wg = self.run_executor(executor, [], kwargs, var_args, var_kwargs) wg.name = name From d2672637edce57c092e8ec3914b909e0c6eab754 Mon Sep 17 00:00:00 2001 From: superstar54 Date: Mon, 2 Sep 2024 12:01:15 +0200 Subject: [PATCH 11/14] Scheduler add workgraph subsriber The scheduler will listen to the task from scheduler_queue --- aiida_workgraph/engine/scheduler/scheduler.py | 19 ++++++++++++++++- aiida_workgraph/utils/control.py | 11 ++++++++++ aiida_workgraph/workgraph.py | 21 +++++++++++++------ 3 files changed, 44 insertions(+), 7 deletions(-) diff --git a/aiida_workgraph/engine/scheduler/scheduler.py b/aiida_workgraph/engine/scheduler/scheduler.py index ffadc38f..5d91c489 100644 --- a/aiida_workgraph/engine/scheduler/scheduler.py +++ b/aiida_workgraph/engine/scheduler/scheduler.py @@ -149,6 +149,7 @@ def load_instance_state( self._temp = {"awaitables": {}} self.set_logger(self.node.logger) + self.add_workgraph_subsriber() if self._awaitables: # For the "ascyncio.tasks.Task" awaitable, because there are only in-memory, @@ -512,7 +513,7 @@ def setup(self) -> None: # self.ctx._workgraph[pk]["_execution_count"] = {} # data not to be persisted, because they are not serializable self._temp = {"awaitables": {}} - # self.launch_workgraph(122305) + self.add_workgraph_subsriber() def launch_workgraph(self, pk: str) -> None: """Launch the workgraph.""" @@ -1633,6 +1634,22 @@ def message_receive( # Didn't match any known intents raise RuntimeError("Unknown intent") + def call_on_receive_workgraph_message(self, _comm, msg): + """Call on receive workgraph message.""" + # self.report(f"Received workgraph message: {msg}") + pk = int(msg) + # To avoid "DbNode is not persistent", we need to schedule the call + self._schedule_rpc(self.launch_workgraph, pk=pk) + return True + + def add_workgraph_subsriber(self) -> None: + """Add workgraph subscriber.""" + queue_name = "scheduler_queue" + # self.report(f"Add workgraph subscriber on queue: {queue_name}") + comm = self.runner.communicator._communicator + queue = comm.task_queue(queue_name, prefetch_count=1000) + queue.add_task_subscriber(self.callback) + def finalize_workgraph(self, pk: int) -> t.Optional[ExitCode]: """""" # expose outputs of the workgraph diff --git a/aiida_workgraph/utils/control.py b/aiida_workgraph/utils/control.py index 376f8fc3..8b54f439 100644 --- a/aiida_workgraph/utils/control.py +++ b/aiida_workgraph/utils/control.py @@ -17,6 +17,17 @@ def create_task_action( controller._communicator.rpc_send(pk, message) +def create_scheduler_action( + pk: int, +): + """Send workgraph task to scheduler.""" + + controller = get_manager().get_process_controller() + message = str(pk) + queue = controller._communicator.task_queue("scheduler_queue") + queue.task_send(message) + + def get_task_state_info(node, name: str, key: str) -> str: """Get task state info from base.extras.""" from aiida.orm.utils.serialize import deserialize_unsafe diff --git a/aiida_workgraph/workgraph.py b/aiida_workgraph/workgraph.py index 341cdae0..647f8778 100644 --- a/aiida_workgraph/workgraph.py +++ b/aiida_workgraph/workgraph.py @@ -145,7 +145,7 @@ 
diff --git a/aiida_workgraph/workgraph.py b/aiida_workgraph/workgraph.py
index 341cdae0..647f8778 100644
--- a/aiida_workgraph/workgraph.py
+++ b/aiida_workgraph/workgraph.py
@@ -145,7 +145,7 @@ def submit(
         if self.process.process_state.value.upper() not in ["CREATED"]:
             raise ValueError(f"Process {self.process.pk} has already been submitted.")
         if to_scheduler:
-            self.continue_process_in_scheduler()
+            self.continue_process_in_scheduler(to_scheduler)
         else:
             self.continue_process()
         # as long as we submit the process, it is a new submission, we should set restart_process to None
@@ -432,15 +432,24 @@ def continue_process(self):
         process_controller = get_manager().get_process_controller()
         process_controller.continue_process(self.pk)

-    def continue_process_in_scheduler(self):
-        """Ask the scheduler to pick up the process from the database and run it."""
-        from aiida_workgraph.utils.control import create_task_action
+    def continue_process_in_scheduler(self, to_scheduler: Union[int, bool]) -> None:
+        """Ask the scheduler to pick up the process from the database and run it.
+
+        If to_scheduler is an integer, it is used as the scheduler pk.
+        Otherwise, the message is sent to the queue and a scheduler will pick it up.
+        """
+        from aiida_workgraph.utils.control import (
+            create_task_action,
+            create_scheduler_action,
+        )
         from aiida_workgraph.engine.scheduler.client import get_scheduler
         import kiwipy

         try:
-            scheduler_pk = get_scheduler()
-            create_task_action(scheduler_pk, [self.pk], action="launch_workgraph")
+            if isinstance(to_scheduler, int):
+                scheduler_pk = get_scheduler()
+                create_task_action(scheduler_pk, [self.pk], action="launch_workgraph")
+            else:
+                create_scheduler_action(self.pk)
         except ValueError:
             print(
                 """Scheduler is not running.

From d35d63e748c350fd4ee3056074de1cb0101e280a Mon Sep 17 00:00:00 2001
From: superstar54
Date: Mon, 2 Sep 2024 16:53:17 +0200
Subject: [PATCH 12/14] Support multiple schedulers.

1) Multiple runners (daemons) can be started for the scheduler. Each runner
   listens to the `scheduler_queue` with a prefetch_count of 1, so each runner
   launches at most one Scheduler process.
2) The Scheduler process listens to the `workgraph_queue` to launch WorkGraphs.
3) The scheduler receives RPC calls to launch a WorkGraph.
4) Users can submit a workgraph to the workgraph queue, or select a scheduler
   to run it by pk.
---
 aiida_workgraph/cli/cmd_scheduler.py          |  32 +--
 aiida_workgraph/engine/launch.py              |  84 +-------
 aiida_workgraph/engine/override.py            |  71 +++++++
 aiida_workgraph/engine/scheduler/client.py    | 106 +++++++++-
 aiida_workgraph/engine/scheduler/scheduler.py |  11 +-
 aiida_workgraph/utils/control.py              |  22 +-
 aiida_workgraph/workgraph.py                  |  30 +--
 docs/source/howto/scheduler.ipynb             | 200 +++++++++++-------
 tests/test_scheduler.py                       |   3 +-
 9 files changed, 343 insertions(+), 216 deletions(-)
 create mode 100644 aiida_workgraph/engine/override.py

diff --git a/aiida_workgraph/cli/cmd_scheduler.py b/aiida_workgraph/cli/cmd_scheduler.py
index d8c52d09..52cb54d3 100644
--- a/aiida_workgraph/cli/cmd_scheduler.py
+++ b/aiida_workgraph/cli/cmd_scheduler.py
@@ -1,27 +1,11 @@
 from aiida_workgraph.cli.cmd_workgraph import workgraph
 import click
-from pathlib import Path
 from aiida.cmdline.utils import decorators, echo
+from aiida.cmdline.commands.cmd_daemon import validate_daemon_workers
 from aiida.cmdline.params import options
 from aiida_workgraph.engine.scheduler.client import get_scheduler_client
 import sys

-REACT_PORT = "3000"
-
-
-def get_package_root():
-    """Returns the root directory of the package."""
-    current_file = Path(__file__)
-    # Root directory of your package
-    return current_file.parent
-
-
-def get_pid_file_path():
-    """Get the path to the PID file in the desired directory."""
-    from aiida.manage.configuration.settings import AIIDA_CONFIG_FOLDER
-
-    return AIIDA_CONFIG_FOLDER / "scheduler_processes.pid"
-

 @workgraph.group("scheduler")
 def scheduler():
@@ -31,7 +15,7 @@ def scheduler():
 @scheduler.command()
 def worker():
     """Start the scheduler application."""
-    from aiida_workgraph.engine.launch import start_scheduler_worker
+    from aiida_workgraph.engine.scheduler.client import start_scheduler_worker

     click.echo("Starting the scheduler worker...")

@@ -40,17 +24,20 @@ def worker():

 @scheduler.command()
 @click.option("--foreground", is_flag=True, help="Run in foreground.")
+@click.argument("number", required=False, type=int, callback=validate_daemon_workers)
 @options.TIMEOUT(default=None, required=False, type=int)
 @decorators.with_dbenv()
 @decorators.requires_broker
 @decorators.check_circus_zmq_version
-def start(foreground, timeout):
+def start(foreground, number, timeout):
     """Start the scheduler application."""
+    from aiida_workgraph.engine.scheduler.client import start_scheduler_process

     click.echo("Starting the scheduler process...")

     client = get_scheduler_client()
-    client.start_daemon(foreground=foreground)
+    client.start_daemon(number_workers=number, foreground=foreground, timeout=timeout)
+    start_scheduler_process(number)


 @scheduler.command()
@@ -86,10 +73,11 @@ def stop(ctx, no_wait, all_profiles, timeout):

 @scheduler.command(hidden=True)
 @click.option("--foreground", is_flag=True, help="Run in foreground.")
+@click.argument("number", required=False, type=int, callback=validate_daemon_workers)
 @decorators.with_dbenv()
 @decorators.requires_broker
 @decorators.check_circus_zmq_version
-def start_circus(foreground):
+def start_circus(foreground, number):
     """This will actually launch the circus daemon, either daemonized in the background
     or in the foreground.

     If run in the foreground all logs are redirected to stdout.

     ..
note:: this should not be called directly from the commandline! """ - get_scheduler_client()._start_daemon(foreground=foreground) + get_scheduler_client()._start_daemon(number_workers=number, foreground=foreground) @scheduler.command() diff --git a/aiida_workgraph/engine/launch.py b/aiida_workgraph/engine/launch.py index 96324c75..25b8c1e8 100644 --- a/aiida_workgraph/engine/launch.py +++ b/aiida_workgraph/engine/launch.py @@ -27,6 +27,13 @@ LOGGER = AIIDA_LOGGER.getChild("engine.launch") +""" +Note: I modified the run_get_node and submit functions to include the parent_pid argument. +This is necessary for keeping track of the provenance of the processes. + +""" + + def run_get_node( process_class, *args, **kwargs ) -> tuple[dict[str, t.Any] | None, "ProcessNode"]: @@ -170,80 +177,3 @@ def submit( time.sleep(wait_interval) return node - - -def start_scheduler_worker(foreground: bool = False) -> None: - """Start a scheduler worker for the currently configured profile. - - :param foreground: If true, the logging will be configured to write to stdout, otherwise it will be configured to - write to the scheduler log file. - """ - import asyncio - import signal - import sys - - from aiida.common.log import configure_logging - from aiida.manage import get_config_option, get_manager - from aiida_workgraph.engine.scheduler import WorkGraphScheduler - from aiida_workgraph.engine.scheduler.client import ( - get_scheduler_client, - get_scheduler, - ) - from aiida.engine.processes.launcher import ProcessLauncher - from aiida.engine import persistence - from plumpy.persistence import LoadSaveContext - from aiida.engine.daemon.worker import shutdown_worker - - daemon_client = get_scheduler_client() - configure_logging( - daemon=not foreground, daemon_log_file=daemon_client.daemon_log_file - ) - - LOGGER.debug(f"sys.executable: {sys.executable}") - LOGGER.debug(f"sys.path: {sys.path}") - - try: - manager = get_manager() - # runner = manager.create_daemon_runner() - runner = manager.create_runner(broker_submit=True) - manager.set_runner(runner) - except Exception: - LOGGER.exception("daemon worker failed to start") - raise - - if isinstance(rlimit := get_config_option("daemon.recursion_limit"), int): - LOGGER.info("Setting maximum recursion limit of daemon worker to %s", rlimit) - sys.setrecursionlimit(rlimit) - - signals = (signal.SIGTERM, signal.SIGINT) - for s in signals: - # https://github.com/python/mypy/issues/12557 - runner.loop.add_signal_handler(s, lambda s=s: asyncio.create_task(shutdown_worker(runner))) # type: ignore[misc] - - try: - running_scheduler = get_scheduler() - runner_loop = runner.loop - task_receiver = ProcessLauncher( - loop=runner_loop, - persister=manager.get_persister(), - load_context=LoadSaveContext(runner=runner), - loader=persistence.get_object_loader(), - ) - asyncio.run( - task_receiver._continue( - communicator=None, pid=running_scheduler, nowait=True - ) - ) - except ValueError: - print("Starting a new Scheduler") - process_inited = instantiate_process(runner, WorkGraphScheduler) - runner.loop.create_task(process_inited.step_until_terminated()) - - try: - LOGGER.info("Starting a daemon worker") - runner.start() - except SystemError as exception: - LOGGER.info("Received a SystemError: %s", exception) - runner.close() - - LOGGER.info("Daemon worker started") diff --git a/aiida_workgraph/engine/override.py b/aiida_workgraph/engine/override.py new file mode 100644 index 00000000..020be102 --- /dev/null +++ b/aiida_workgraph/engine/override.py @@ -0,0 +1,71 @@ +from 
plumpy.process_comms import RemoteProcessThreadController
+from typing import Any, Optional
+
+"""
+Note: I modified the create_daemon_runner function and RemoteProcessThreadController
+to include the queue_name argument.
+
+"""
+
+
+def create_daemon_runner(
+    manager, queue_name: str = None, loop: Optional["asyncio.AbstractEventLoop"] = None
+) -> "Runner":
+    """Create and return a new daemon runner.
+    This is used by workers when the daemon is running and in testing.
+    :param loop: the (optional) asyncio event loop to use
+    :return: a runner configured to work in the daemon configuration
+    """
+    from plumpy.persistence import LoadSaveContext
+    from aiida.engine import persistence
+    from aiida.engine.processes.launcher import ProcessLauncher
+    from plumpy.communications import convert_to_comm
+
+    runner = manager.create_runner(broker_submit=True, loop=loop)
+    runner_loop = runner.loop
+    # Listen for incoming launch requests
+    task_receiver = ProcessLauncher(
+        loop=runner_loop,
+        persister=manager.get_persister(),
+        load_context=LoadSaveContext(runner=runner),
+        loader=persistence.get_object_loader(),
+    )
+
+    def callback(_comm, msg):
+        print("Received message: {}".format(msg))
+        import asyncio
+
+        asyncio.run(task_receiver(_comm, msg))
+        print("task_receiver._continue done")
+        return True
+
+    assert runner.communicator is not None, "communicator not set for runner"
+    if queue_name is not None:
+        print("queue_name: {}".format(queue_name))
+        queue = runner.communicator._communicator.task_queue(
+            queue_name, prefetch_count=1
+        )
+        # queue.add_task_subscriber(callback)
+        # important to convert the callback
+        converted = convert_to_comm(task_receiver, runner.communicator._loop)
+        queue.add_task_subscriber(converted)
+    else:
+        runner.communicator.add_task_subscriber(task_receiver)
+    return runner
+
+
+class ControllerWithQueueName(RemoteProcessThreadController):
+    def __init__(self, queue_name: str, **kwargs):
+        super().__init__(**kwargs)
+        self.queue_name = queue_name
+
+    def task_send(self, message: Any, no_reply: bool = False) -> Optional[Any]:
+        """
+        Send a task to be performed using the communicator
+
+        :param message: the task message
+        :param no_reply: if True, this call will be fire-and-forget, i.e. no return value
+        :return: the response from the remote side (if no_reply=False)
+        """
+        queue = self._communicator.task_queue(self.queue_name)
+        return queue.task_send(message, no_reply=no_reply)
diff --git a/aiida_workgraph/engine/scheduler/client.py b/aiida_workgraph/engine/scheduler/client.py
index 1774eb64..1beaeac8 100644
--- a/aiida_workgraph/engine/scheduler/client.py
+++ b/aiida_workgraph/engine/scheduler/client.py
@@ -4,8 +4,11 @@
 from aiida.common.exceptions import ConfigurationError
 import os
 from typing import Optional
+from aiida.common.log import AIIDA_LOGGER
+from typing import List

 WORKGRAPH_BIN = shutil.which("workgraph")
+LOGGER = AIIDA_LOGGER.getChild("engine.launch")


 class SchedulerClient(DaemonClient):
@@ -102,6 +105,7 @@ def cmd_start_daemon(
             self.profile.name,
             "scheduler",
             "start-circus",
+            str(number_workers),
         ]

         if foreground:
@@ -114,7 +118,7 @@ def cmd_start_daemon_worker(self) -> list[str]:
         """Return the command to start a daemon worker process."""
         return [self._workgraph_bin, "-p", self.profile.name, "scheduler", "worker"]

-    def _start_daemon(self, foreground: bool = False) -> None:
+    def _start_daemon(self, number_workers: int = 1, foreground: bool = False) -> None:
         """Start the daemon.

         ..
warning:: This will daemonize the current process and put it in the background. It is most likely not what @@ -149,7 +153,7 @@ def _start_daemon(self, foreground: bool = False) -> None: { "cmd": " ".join(self.cmd_start_daemon_worker), "name": self.daemon_name, - "numprocesses": 1, + "numprocesses": number_workers, "virtualenv": self.virtualenv, "copy_env": True, "stdout_stream": { @@ -210,7 +214,7 @@ def get_scheduler_client(profile_name: Optional[str] = None) -> "SchedulerClient return SchedulerClient(profile) -def get_scheduler(): +def get_scheduler() -> List[int]: from aiida.orm import QueryBuilder from aiida_workgraph.engine.scheduler import WorkGraphScheduler @@ -224,7 +228,95 @@ def get_scheduler(): } qb.append(WorkGraphScheduler, filters=filters, project=projections, tag="process") results = qb.all() - if len(results) == 0: - raise ValueError("No scheduler found. Please start the scheduler first.") - scheduler_id = results[0][0] - return scheduler_id + pks = [r[0] for r in results] + return pks + + +def start_scheduler_worker(foreground: bool = False) -> None: + """Start a scheduler worker for the currently configured profile. + + :param foreground: If true, the logging will be configured to write to stdout, otherwise it will be configured to + write to the scheduler log file. + """ + import asyncio + import signal + import sys + from aiida_workgraph.engine.scheduler.client import get_scheduler_client + from aiida_workgraph.engine.override import create_daemon_runner + + from aiida.common.log import configure_logging + from aiida.manage import get_config_option + from aiida.engine.daemon.worker import shutdown_worker + + daemon_client = get_scheduler_client() + configure_logging( + daemon=not foreground, daemon_log_file=daemon_client.daemon_log_file + ) + + LOGGER.debug(f"sys.executable: {sys.executable}") + LOGGER.debug(f"sys.path: {sys.path}") + + try: + manager = get_manager() + runner = create_daemon_runner(manager, queue_name="scheduler_queue") + except Exception: + LOGGER.exception("daemon worker failed to start") + raise + + if isinstance(rlimit := get_config_option("daemon.recursion_limit"), int): + LOGGER.info("Setting maximum recursion limit of daemon worker to %s", rlimit) + sys.setrecursionlimit(rlimit) + + signals = (signal.SIGTERM, signal.SIGINT) + for s in signals: + # https://github.com/python/mypy/issues/12557 + runner.loop.add_signal_handler(s, lambda s=s: asyncio.create_task(shutdown_worker(runner))) # type: ignore[misc] + + try: + LOGGER.info("Starting a daemon worker") + runner.start() + except SystemError as exception: + LOGGER.info("Received a SystemError: %s", exception) + runner.close() + + LOGGER.info("Daemon worker started") + + +def start_scheduler_process(number: int = 1) -> None: + """Start or restart the specified number of scheduler processes.""" + from aiida_workgraph.engine.scheduler import WorkGraphScheduler + from aiida_workgraph.engine.scheduler.client import get_scheduler + from aiida_workgraph.utils.control import create_scheduler_action + from aiida_workgraph.engine.utils import instantiate_process + + try: + schedulers: List[int] = get_scheduler() + existing_schedulers_count = len(schedulers) + print( + "Found {} existing scheduler(s): {}".format( + existing_schedulers_count, " ".join([str(pk) for pk in schedulers]) + ) + ) + + count = 0 + + # Restart existing schedulers if they exceed the number to start + for pk in schedulers[:number]: + create_scheduler_action(pk) + print(f"Scheduler with pk {pk} running.") + count += 1 + # not running + 
for pk in schedulers[number:]:
+            print(f"Scheduler with pk {pk} not running.")
+
+        # Start new schedulers if more are needed
+        runner = get_manager().get_runner()
+        for i in range(count, number):
+            process_inited = instantiate_process(runner, WorkGraphScheduler)
+            process_inited.runner.persister.save_checkpoint(process_inited)
+            process_inited.close()
+            create_scheduler_action(process_inited.node.pk)
+            print(f"Scheduler with pk {process_inited.node.pk} running.")
+
+    except Exception as e:
+        raise RuntimeError(f"An error occurred while starting schedulers: {e}") from e
diff --git a/aiida_workgraph/engine/scheduler/scheduler.py b/aiida_workgraph/engine/scheduler/scheduler.py
index 5d91c489..53f77cfa 100644
--- a/aiida_workgraph/engine/scheduler/scheduler.py
+++ b/aiida_workgraph/engine/scheduler/scheduler.py
@@ -778,15 +778,12 @@ def update_task_state(
     ) -> None:
         """Update task state when the task is finished."""
-        print("update task state: ", pk, name)
         task = self.ctx._workgraph[pk]["_tasks"][name]
         # print(f"set task result: {name}")
         node = self.get_task_state_info(pk, name, "process")
-        print("node", node)
         if isinstance(node, orm.ProcessNode):
             # print(f"set task result: {name} process")
             state = node.process_state.value.upper()
-            print("state", state)
             if node.is_finished_ok:
                 self.set_task_state_info(pk, task["name"], "state", state)
                 if task["metadata"]["node_type"].upper() == "WORKGRAPH":
@@ -1637,18 +1634,18 @@ def message_receive(

     def call_on_receive_workgraph_message(self, _comm, msg):
         """Call on receive workgraph message."""
         # self.report(f"Received workgraph message: {msg}")
-        pk = int(msg)
+        pk = msg["args"]["pid"]
         # To avoid "DbNode is not persistent", we need to schedule the call
         self._schedule_rpc(self.launch_workgraph, pk=pk)
         return True

     def add_workgraph_subsriber(self) -> None:
         """Add workgraph subscriber."""
-        queue_name = "scheduler_queue"
-        # self.report(f"Add workgraph subscriber on queue: {queue_name}")
+        queue_name = "workgraph_queue"
+        self.report(f"Add workgraph subscriber on queue: {queue_name}")
         comm = self.runner.communicator._communicator
         queue = comm.task_queue(queue_name, prefetch_count=1000)
-        queue.add_task_subscriber(self.callback)
+        queue.add_task_subscriber(self.call_on_receive_workgraph_message)

     def finalize_workgraph(self, pk: int) -> t.Optional[ExitCode]:
         """"""
diff --git a/aiida_workgraph/utils/control.py b/aiida_workgraph/utils/control.py
index 8b54f439..4fa50be6 100644
--- a/aiida_workgraph/utils/control.py
+++ b/aiida_workgraph/utils/control.py
@@ -1,6 +1,7 @@
 from aiida.manage import get_manager
 from aiida import orm
 from aiida.engine.processes import control
+from aiida_workgraph.engine.override import ControllerWithQueueName


 def create_task_action(
@@ -22,10 +23,23 @@ def create_task_action(
     """Send a workgraph task to the scheduler queue."""

-    controller = get_manager().get_process_controller()
-    message = str(pk)
-    queue = controller._communicator.task_queue("scheduler_queue")
-    queue.task_send(message)
+    manager = get_manager()
+    controller = ControllerWithQueueName(
+        queue_name="scheduler_queue", communicator=manager.get_communicator()
+    )
+    controller.continue_process(pk, nowait=False)
+
+
+def create_workgraph_action(
+    pk: int,
+):
+    """Send a workgraph task to the workgraph queue."""
+
+    manager = get_manager()
+    controller = ControllerWithQueueName(
+        queue_name="workgraph_queue", communicator=manager.get_communicator()
+    )
+    controller.continue_process(pk, nowait=False)
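For orientation, the workgraph.py hunk below routes a submission either to one scheduler by pk (via RPC) or to the shared queue. A condensed sketch of that routing, using the helpers from this patch; the `dispatch` wrapper itself is illustrative only.

```python
# Sketch of the two dispatch paths introduced in this patch.
from aiida_workgraph.utils.control import create_task_action, create_workgraph_action

def dispatch(workgraph_pk: int, to_scheduler) -> None:
    # bool is a subclass of int in Python, so True must be excluded explicitly;
    # the workgraph.py hunk below performs the same check.
    if isinstance(to_scheduler, int) and not isinstance(to_scheduler, bool):
        # RPC straight to the chosen scheduler process.
        create_task_action(to_scheduler, [workgraph_pk], action="launch_workgraph")
    else:
        # Publish to "workgraph_queue"; any running scheduler may pick it up.
        create_workgraph_action(workgraph_pk)
```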
diff --git a/aiida_workgraph/workgraph.py b/aiida_workgraph/workgraph.py
index 647f8778..e99a34f6 100644
--- a/aiida_workgraph/workgraph.py
+++ b/aiida_workgraph/workgraph.py
@@ -124,15 +124,6 @@ def submit(
             restart (bool): Restart the process, and reset the modified tasks, then only re-run the modified tasks.
             new (bool): Submit a new process.
         """
-        from aiida_workgraph.engine.scheduler.client import get_scheduler
-
-        if to_scheduler:
-            try:
-                get_scheduler()
-            except ValueError as e:
-                print(e)
-                return
-
         # set task inputs
         if inputs is not None:
             for name, input in inputs.items():
@@ -439,27 +430,14 @@ def continue_process_in_scheduler(self, to_scheduler: Union[int, bool]) -> None:
         """
         from aiida_workgraph.utils.control import (
             create_task_action,
-            create_scheduler_action,
+            create_workgraph_action,
         )
-        from aiida_workgraph.engine.scheduler.client import get_scheduler
-        import kiwipy

         try:
-            if isinstance(to_scheduler, int):
-                scheduler_pk = get_scheduler()
-                create_task_action(scheduler_pk, [self.pk], action="launch_workgraph")
+            if isinstance(to_scheduler, int) and not isinstance(to_scheduler, bool):
+                create_task_action(to_scheduler, [self.pk], action="launch_workgraph")
             else:
-                create_scheduler_action(self.pk)
-        except ValueError:
-            print(
-                """Scheduler is not running.
-Please start the scheduler first with `aiida-workgraph scheduler start`"""
-            )
-        except kiwipy.exceptions.UnroutableError:
-            print(
-                """Scheduler exists, but the daemon is not running.
-Please start the scheduler first with `aiida-workgraph scheduler start`"""
-            )
+                create_workgraph_action(self.pk)
         except Exception as e:
             print("""An unexpected error occurred:""", e)
diff --git a/docs/source/howto/scheduler.ipynb b/docs/source/howto/scheduler.ipynb
index ad046289..e88f2f50 100644
--- a/docs/source/howto/scheduler.ipynb
+++ b/docs/source/howto/scheduler.ipynb
@@ -6,42 +6,93 @@
    "source": [
     "# Scheduler\n",
     "\n",
-    "Start a scheduler daemon:\n",
+    "## Overview\n",
+    "\n",
+    "This documentation provides a guide on using the `aiida-workgraph` Scheduler to manage `WorkGraph` processes efficiently.\n",
+    "\n",
+    "### Background\n",
+    "\n",
+    "Traditional workflow processes, particularly in nested structures like `PwBandsWorkChain`, tend to create multiple Workflow processes in a waiting state, while only a few `CalcJob` processes run actively. This results in inefficient resource usage. The `WorkChain` structure makes it challenging to eliminate these waiting processes due to its encapsulated logic.\n",
+    "\n",
+    "In contrast, the `WorkGraph` exposes task dependencies more clearly and allows other processes to run its tasks in a controlled way.
With a scheduler, one only needs to create the `WorkGraph` process in the database, not run it via a daemon worker.\n",
+    "\n",
+    "### Process Comparison: `PwBands` Case\n",
+    "\n",
+    "- **Old Approach**: 300 Workflow processes (Bands, Relax, Base) + 100 CalcJob processes.\n",
+    "- **New Approach**: 1 Scheduler process + 100 CalcJob processes.\n",
+    "\n",
+    "This new approach significantly reduces the number of active processes and mitigates the risk of deadlocks.\n",
+    "\n",
+    "## Getting Started with the Scheduler\n",
+    "\n",
+    "### Starting the Scheduler\n",
+    "\n",
+    "To launch a scheduler daemon:\n",
     "\n",
     "```console\n",
     "workgraph scheduler start\n",
     "```\n",
     "\n",
-    "Check the status of the scheduler:\n",
+    "### Monitoring the Scheduler\n",
+    "\n",
+    "To check the current status of the scheduler:\n",
     "\n",
     "```console\n",
     "workgraph scheduler status\n",
     "```\n",
     "\n",
-    "Stop the scheduler:\n",
+    "### Stopping the Scheduler\n",
+    "\n",
+    "To stop the scheduler daemon:\n",
     "\n",
     "```console\n",
     "workgraph scheduler stop\n",
     "```\n",
     "\n",
-    "## Submit workgraph to the scheduler\n",
-    "Set `to_scheduler` to `True` when submitting a workgraph to the scheduler:\n",
+    "## Submitting WorkGraphs to the Scheduler\n",
+    "\n",
+    "To submit a WorkGraph to the scheduler, set the `to_scheduler` flag to `True`:\n",
     "\n",
     "```python\n",
     "wg.submit(to_scheduler=True)\n",
-    "```"
+    "```\n",
+    "\n",
+    "\n",
+    "### Using Multiple Schedulers\n",
+    "\n",
+    "For environments with a high volume of WorkGraphs, starting multiple schedulers can enhance throughput:\n",
+    "\n",
+    "```console\n",
+    "workgraph scheduler start 2\n",
+    "```\n",
+    "\n",
+    "WorkGraphs will be automatically distributed among available schedulers.\n",
+    "\n",
+    "#### Specifying a Scheduler\n",
+    "\n",
+    "To submit a WorkGraph to a specific scheduler using its primary key (`pk`):\n",
+    "\n",
+    "```python\n",
+    "wg.submit(to_scheduler=pk_scheduler)\n",
+    "```\n",
+    "\n",
+    "### Best Practices for Scheduler Usage\n",
+    "\n",
+    "While a single scheduler suffices for most use cases, scaling up the number of schedulers may be beneficial when significantly increasing the number of task workers (created by `verdi daemon start`).
A general rule is to maintain a ratio of less than 5 workers per scheduler.\n", + "\n", + "## Example" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "WorkGraph process created, PK: 134971\n", + "WorkGraph process created, PK: 142617\n", "State of WorkGraph : FINISHED\n", "Result of add2 : 4\n" ] @@ -78,7 +129,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -94,41 +145,41 @@ " viewBox=\"0.00 0.00 1036.43 720.06\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", "\n", "\n", - "\n", + "\n", "\n", - "N134971\n", + "N142617\n", "\n", - "WorkGraph<test_scheduler> (134971)\n", + "WorkGraph<test_scheduler> (142617)\n", "State: finished\n", "Exit Code: 0\n", "\n", - "\n", + "\n", "\n", - "N134974\n", + "N142620\n", "\n", - "ArithmeticAddCalculation (134974)\n", + "ArithmeticAddCalculation (142620)\n", "State: finished\n", "Exit Code: 0\n", "\n", - "\n", - "\n", - "N134971->N134974\n", + "\n", + "\n", + "N142617->N142620\n", "\n", "\n", "CALL_CALC\n", "add1\n", "\n", - "\n", + "\n", "\n", - "N134979\n", + "N142625\n", "\n", - "ArithmeticAddCalculation (134979)\n", + "ArithmeticAddCalculation (142625)\n", "State: finished\n", "Exit Code: 0\n", "\n", - "\n", - "\n", - "N134971->N134979\n", + "\n", + "\n", + "N142617->N142625\n", "\n", "\n", "CALL_CALC\n", @@ -141,111 +192,111 @@ "InstalledCode (37)\n", "add@localhost\n", "\n", - "\n", + "\n", "\n", - "N37->N134971\n", + "N37->N142617\n", "\n", "\n", "INPUT_WORK\n", - "wg__tasks__add2__properties__code__value\n", + "wg__tasks__add1__properties__code__value\n", "\n", - "\n", + "\n", "\n", - "N37->N134971\n", + "N37->N142617\n", "\n", "\n", "INPUT_WORK\n", - "wg__tasks__add1__properties__code__value\n", + "wg__tasks__add2__properties__code__value\n", "\n", - "\n", + "\n", "\n", - "N134975\n", + "N142621\n", "\n", - "RemoteData (134975)\n", + "RemoteData (142621)\n", "@localhost\n", "\n", - "\n", - "\n", - "N134974->N134975\n", + "\n", + "\n", + "N142620->N142621\n", "\n", "\n", "CREATE\n", "remote_folder\n", "\n", - "\n", + "\n", "\n", - "N134976\n", + "N142622\n", "\n", - "FolderData (134976)\n", + "FolderData (142622)\n", "\n", - "\n", - "\n", - "N134974->N134976\n", + "\n", + "\n", + "N142620->N142622\n", "\n", "\n", "CREATE\n", "retrieved\n", "\n", - "\n", + "\n", "\n", - "N134977\n", + "N142623\n", "\n", - "Int (134977)\n", + "Int (142623)\n", "\n", - "\n", - "\n", - "N134974->N134977\n", + "\n", + "\n", + "N142620->N142623\n", "\n", "\n", "CREATE\n", "sum\n", "\n", - "\n", - "\n", - "N134977->N134979\n", + "\n", + "\n", + "N142623->N142625\n", "\n", "\n", "INPUT_CALC\n", "y\n", "\n", - "\n", + "\n", "\n", - "N134980\n", + "N142626\n", "\n", - "RemoteData (134980)\n", + "RemoteData (142626)\n", "@localhost\n", "\n", - "\n", - "\n", - "N134979->N134980\n", + "\n", + "\n", + "N142625->N142626\n", "\n", "\n", "CREATE\n", "remote_folder\n", "\n", - "\n", + "\n", "\n", - "N134981\n", + "N142627\n", "\n", - "FolderData (134981)\n", + "FolderData (142627)\n", "\n", - "\n", - "\n", - "N134979->N134981\n", + "\n", + "\n", + "N142625->N142627\n", "\n", "\n", "CREATE\n", "retrieved\n", "\n", - "\n", + "\n", "\n", - "N134982\n", + "N142628\n", "\n", - "Int (134982)\n", + "Int (142628)\n", "\n", - "\n", + "\n", "\n", - "N134979->N134982\n", + "N142625->N142628\n", "\n", "\n", "CREATE\n", @@ -255,10 +306,10 @@ "\n" ], "text/plain": [ - "" + "" ] }, 
- "execution_count": 6, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -272,7 +323,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Conclusion" + "\n", + "## Checkpointing\n", + "\n", + "The Scheduler checkpoints its status to the database whenever a WorkGraph is updated, ensuring that the Scheduler can recover its state in case of a crash or restart. This feature is particularly useful for long-running WorkGraphs.\n", + "\n", + "## Conclusion\n", + "\n", + "The Scheduler offers a streamlined approach to managing complex workflows, significantly reducing active process counts and improving resource efficiency." ] } ], diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index f51a1320..c6308551 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -6,7 +6,7 @@ from aiida import orm -@pytest.skip("Skip for now") +@pytest.mark.skip("Skip for now") @pytest.mark.usefixtures("started_daemon_client") def test_scheduler(decorated_add: Callable, started_scheduler_client) -> None: """Test graph build.""" @@ -14,7 +14,6 @@ def test_scheduler(decorated_add: Callable, started_scheduler_client) -> None: add1 = wg.add_task(decorated_add, x=2, y=3) add2 = wg.add_task(decorated_add, "add2", x=3, y=add1.outputs["result"]) # use run to check if graph builder workgraph can be submit inside the engine - pk = get_scheduler() wg.submit(to_scheduler=True, wait=True) pk = get_scheduler() report = get_workchain_report(orm.load(pk), "REPORT") From 5a29c798b7cecabef2cefc19fc59e798c334fb4b Mon Sep 17 00:00:00 2001 From: superstar54 Date: Mon, 2 Sep 2024 21:40:15 +0200 Subject: [PATCH 13/14] Web: add setting for scheduler --- aiida_workgraph/engine/scheduler/client.py | 6 +- aiida_workgraph/engine/scheduler/scheduler.py | 2 + aiida_workgraph/web/backend/app/api.py | 2 + aiida_workgraph/web/backend/app/daemon.py | 12 +- aiida_workgraph/web/backend/app/scheduler.py | 127 ++++++++++++++++ .../web/frontend/src/components/Settings.js | 140 ++++++++++++------ 6 files changed, 237 insertions(+), 52 deletions(-) create mode 100644 aiida_workgraph/web/backend/app/scheduler.py diff --git a/aiida_workgraph/engine/scheduler/client.py b/aiida_workgraph/engine/scheduler/client.py index 1beaeac8..4f5a8212 100644 --- a/aiida_workgraph/engine/scheduler/client.py +++ b/aiida_workgraph/engine/scheduler/client.py @@ -302,8 +302,10 @@ def start_scheduler_process(number: int = 1) -> None: # Restart existing schedulers if they exceed the number to start for pk in schedulers[:number]: - create_scheduler_action(pk) - print(f"Scheduler with pk {pk} running.") + # When the runner stop, the runner does not ack back to rmq, + # so the msg is still in the queue, and the msg is not acked, + # we don't need send the msg to continue again + print(f"Scheduler with pk {pk} restart and running.") count += 1 # not running for pk in schedulers[number:]: diff --git a/aiida_workgraph/engine/scheduler/scheduler.py b/aiida_workgraph/engine/scheduler/scheduler.py index 53f77cfa..36177fa4 100644 --- a/aiida_workgraph/engine/scheduler/scheduler.py +++ b/aiida_workgraph/engine/scheduler/scheduler.py @@ -576,6 +576,8 @@ def read_wgdata_from_base(self, pk: int) -> t.Dict[str, t.Any]: if isinstance(prop["value"], PickledLocalFunction): prop["value"] = prop["value"].value wgdata["error_handlers"] = deserialize_unsafe(wgdata["error_handlers"]) + wgdata["context"] = deserialize_unsafe(wgdata["context"]) + return wgdata, node def update_workgraph_from_base(self, pk: int) -> None: diff --git 
a/aiida_workgraph/web/backend/app/api.py b/aiida_workgraph/web/backend/app/api.py index d05b527b..4bf69c91 100644 --- a/aiida_workgraph/web/backend/app/api.py +++ b/aiida_workgraph/web/backend/app/api.py @@ -2,6 +2,7 @@ from fastapi.middleware.cors import CORSMiddleware from aiida.manage import manager from aiida_workgraph.web.backend.app.daemon import router as daemon_router +from aiida_workgraph.web.backend.app.scheduler import router as scheduler_router from aiida_workgraph.web.backend.app.workgraph import router as workgraph_router from aiida_workgraph.web.backend.app.datanode import router as datanode_router from fastapi.staticfiles import StaticFiles @@ -47,6 +48,7 @@ async def read_root() -> dict: app.include_router(workgraph_router) app.include_router(datanode_router) app.include_router(daemon_router) +app.include_router(scheduler_router) @app.get("/debug") diff --git a/aiida_workgraph/web/backend/app/daemon.py b/aiida_workgraph/web/backend/app/daemon.py index af069d7c..caa22cf4 100644 --- a/aiida_workgraph/web/backend/app/daemon.py +++ b/aiida_workgraph/web/backend/app/daemon.py @@ -22,7 +22,7 @@ class DaemonStatusModel(BaseModel): ) -@router.get("/api/daemon/status", response_model=DaemonStatusModel) +@router.get("/api/daemon/task/status", response_model=DaemonStatusModel) @with_dbenv() async def get_daemon_status() -> DaemonStatusModel: """Return the daemon status.""" @@ -36,7 +36,7 @@ async def get_daemon_status() -> DaemonStatusModel: return DaemonStatusModel(running=True, num_workers=response["numprocesses"]) -@router.get("/api/daemon/worker") +@router.get("/api/daemon/task/worker") @with_dbenv() async def get_daemon_worker(): """Return the daemon status.""" @@ -50,7 +50,7 @@ async def get_daemon_worker(): return response["info"] -@router.post("/api/daemon/start", response_model=DaemonStatusModel) +@router.post("/api/daemon/task/start", response_model=DaemonStatusModel) @with_dbenv() async def get_daemon_start() -> DaemonStatusModel: """Start the daemon.""" @@ -69,7 +69,7 @@ async def get_daemon_start() -> DaemonStatusModel: return DaemonStatusModel(running=True, num_workers=response["numprocesses"]) -@router.post("/api/daemon/stop", response_model=DaemonStatusModel) +@router.post("/api/daemon/task/stop", response_model=DaemonStatusModel) @with_dbenv() async def get_daemon_stop() -> DaemonStatusModel: """Stop the daemon.""" @@ -86,7 +86,7 @@ async def get_daemon_stop() -> DaemonStatusModel: return DaemonStatusModel(running=False, num_workers=None) -@router.post("/api/daemon/increase", response_model=DaemonStatusModel) +@router.post("/api/daemon/task/increase", response_model=DaemonStatusModel) @with_dbenv() async def increase_daemon_worker() -> DaemonStatusModel: """increase the daemon worker.""" @@ -103,7 +103,7 @@ async def increase_daemon_worker() -> DaemonStatusModel: return DaemonStatusModel(running=False, num_workers=None) -@router.post("/api/daemon/decrease", response_model=DaemonStatusModel) +@router.post("/api/daemon/task/decrease", response_model=DaemonStatusModel) @with_dbenv() async def decrease_daemon_worker() -> DaemonStatusModel: """decrease the daemon worker.""" diff --git a/aiida_workgraph/web/backend/app/scheduler.py b/aiida_workgraph/web/backend/app/scheduler.py new file mode 100644 index 00000000..c4110402 --- /dev/null +++ b/aiida_workgraph/web/backend/app/scheduler.py @@ -0,0 +1,127 @@ +# -*- coding: utf-8 -*- +"""Declaration of FastAPI router for daemon endpoints.""" +from __future__ import annotations + +import typing as t + +from 
aiida.cmdline.utils.decorators import with_dbenv +from aiida.engine.daemon.client import DaemonException +from aiida_workgraph.engine.scheduler.client import get_scheduler_client +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel, Field +from aiida_workgraph.engine.scheduler.client import start_scheduler_process + + +router = APIRouter() + + +class DaemonStatusModel(BaseModel): + """Response model for daemon status.""" + + running: bool = Field(description="Whether the daemon is running or not.") + num_workers: t.Optional[int] = Field( + description="The number of workers if the daemon is running." + ) + + +@router.get("/api/daemon/scheduler/status", response_model=DaemonStatusModel) +@with_dbenv() +async def get_daemon_status() -> DaemonStatusModel: + """Return the daemon status.""" + client = get_scheduler_client() + + if not client.is_daemon_running: + return DaemonStatusModel(running=False, num_workers=None) + + response = client.get_numprocesses() + + return DaemonStatusModel(running=True, num_workers=response["numprocesses"]) + + +@router.get("/api/daemon/scheduler/worker") +@with_dbenv() +async def get_daemon_worker(): + """Return the daemon status.""" + client = get_scheduler_client() + + if not client.is_daemon_running: + return {} + + response = client.get_worker_info() + + return response["info"] + + +@router.post("/api/daemon/scheduler/start", response_model=DaemonStatusModel) +@with_dbenv() +async def get_daemon_start() -> DaemonStatusModel: + """Start the daemon.""" + client = get_scheduler_client() + + if client.is_daemon_running: + raise HTTPException(status_code=400, detail="The daemon is already running.") + + try: + client.start_daemon() + except DaemonException as exception: + raise HTTPException(status_code=500, detail=str(exception)) from exception + + response = client.get_numprocesses() + start_scheduler_process() + + return DaemonStatusModel(running=True, num_workers=response["numprocesses"]) + + +@router.post("/api/daemon/scheduler/stop", response_model=DaemonStatusModel) +@with_dbenv() +async def get_daemon_stop() -> DaemonStatusModel: + """Stop the daemon.""" + client = get_scheduler_client() + + if not client.is_daemon_running: + raise HTTPException(status_code=400, detail="The daemon is not running.") + + try: + client.stop_daemon() + except DaemonException as exception: + raise HTTPException(status_code=500, detail=str(exception)) from exception + + return DaemonStatusModel(running=False, num_workers=None) + + +@router.post("/api/daemon/scheduler/increase", response_model=DaemonStatusModel) +@with_dbenv() +async def increase_daemon_worker() -> DaemonStatusModel: + """increase the daemon worker.""" + client = get_scheduler_client() + + if not client.is_daemon_running: + raise HTTPException(status_code=400, detail="The daemon is not running.") + + try: + client.increase_workers(1) + except DaemonException as exception: + raise HTTPException(status_code=500, detail=str(exception)) from exception + + response = client.get_numprocesses() + print(response) + start_scheduler_process(response["numprocesses"]) + + return DaemonStatusModel(running=False, num_workers=None) + + +@router.post("/api/daemon/scheduler/decrease", response_model=DaemonStatusModel) +@with_dbenv() +async def decrease_daemon_worker() -> DaemonStatusModel: + """decrease the daemon worker.""" + client = get_scheduler_client() + + if not client.is_daemon_running: + raise HTTPException(status_code=400, detail="The daemon is not running.") + + try: + 
client.decrease_workers(1) + except DaemonException as exception: + raise HTTPException(status_code=500, detail=str(exception)) from exception + + return DaemonStatusModel(running=False, num_workers=None) diff --git a/aiida_workgraph/web/frontend/src/components/Settings.js b/aiida_workgraph/web/frontend/src/components/Settings.js index 23310618..50a89fee 100644 --- a/aiida_workgraph/web/frontend/src/components/Settings.js +++ b/aiida_workgraph/web/frontend/src/components/Settings.js @@ -3,79 +3,131 @@ import { ToastContainer, toast } from 'react-toastify'; import 'react-toastify/dist/ReactToastify.css'; function Settings() { - const [workers, setWorkers] = useState([]); + const [taskWorkers, setTaskWorkers] = useState([]); + const [schedulerWorkers, setSchedulerWorkers] = useState([]); - const fetchWorkers = () => { - fetch('http://localhost:8000/api/daemon/worker') + // Fetching task workers + const fetchTaskWorkers = () => { + fetch('http://localhost:8000/api/daemon/task/worker') .then(response => response.json()) - .then(data => setWorkers(Object.values(data))) - .catch(error => console.error('Failed to fetch workers:', error)); + .then(data => setTaskWorkers(Object.values(data))) + .catch(error => console.error('Failed to fetch task workers:', error)); + }; + + // Fetching scheduler workers + const fetchSchedulerWorkers = () => { + fetch('http://localhost:8000/api/daemon/scheduler/worker') + .then(response => response.json()) + .then(data => setSchedulerWorkers(Object.values(data))) + .catch(error => console.error('Failed to fetch scheduler workers:', error)); }; useEffect(() => { - fetchWorkers(); - const interval = setInterval(fetchWorkers, 1000); // Poll every 5 seconds - return () => clearInterval(interval); // Clear interval on component unmount + fetchTaskWorkers(); + fetchSchedulerWorkers(); + const taskInterval = setInterval(fetchTaskWorkers, 1000); + const schedulerInterval = setInterval(fetchSchedulerWorkers, 1000); + return () => { + clearInterval(taskInterval); + clearInterval(schedulerInterval); + }; // Clear intervals on component unmount }, []); - const handleDaemonControl = (action) => { - fetch(`http://localhost:8000/api/daemon/${action}`, { method: 'POST' }) + const handleDaemonControl = (daemonType, action) => { + fetch(`http://localhost:8000/api/daemon/${daemonType}/${action}`, { method: 'POST' }) .then(response => { if (!response.ok) { - throw new Error(`Daemon operation failed: ${response.statusText}`); + throw new Error(`${daemonType} daemon operation failed: ${response.statusText}`); } return response.json(); }) - .then(data => { - toast.success(`Daemon ${action}ed successfully`); - fetchWorkers(); + .then(() => { + toast.success(`${daemonType} daemon ${action}ed successfully`); + if (daemonType === 'task') { + fetchTaskWorkers(); + } else { + fetchSchedulerWorkers(); + } }) .catch(error => toast.error(error.message)); }; - const adjustWorkers = (action) => { - fetch(`http://localhost:8000/api/daemon/${action}`, { method: 'POST' }) + const adjustWorkers = (daemonType, action) => { + fetch(`http://localhost:8000/api/daemon/${daemonType}/${action}`, { method: 'POST' }) .then(response => { if (!response.ok) { - throw new Error(`Failed to ${action} workers: ${response.statusText}`); + throw new Error(`Failed to ${action} workers for ${daemonType}: ${response.statusText}`); } return response.json(); }) - .then(data => { - toast.success(`Workers ${action}ed successfully`); - fetchWorkers(); // Refetch workers after adjusting + .then(() => { + 
toast.success(`${daemonType} Workers ${action}ed successfully`); + if (daemonType === 'task') { + fetchTaskWorkers(); + } else { + fetchSchedulerWorkers(); + } }) .catch(error => toast.error(error.message)); }; return (
-    <div>
-      <h2>Daemon Control</h2>
-      <table>
-        <thead>
-          <tr>
-            <th>PID</th>
-            <th>Memory %</th>
-            <th>CPU %</th>
-            <th>Started</th>
-          </tr>
-        </thead>
-        <tbody>
-          {workers.map(worker => (
-            <tr>
-              <td>{worker.pid}</td>
-              <td>{worker.mem}</td>
-              <td>{worker.cpu}</td>
-              <td>{new Date(worker.started * 1000).toLocaleString()}</td>
-            </tr>
-          ))}
-        </tbody>
-      </table>
+    <div>
+      <h2>Task Daemon Control</h2>
+      <table>
+        <thead>
+          <tr>
+            <th>PID</th>
+            <th>Memory %</th>
+            <th>CPU %</th>
+            <th>Started</th>
+          </tr>
+        </thead>
+        <tbody>
+          {taskWorkers.map(worker => (
+            <tr>
+              <td>{worker.pid}</td>
+              <td>{worker.mem}</td>
+              <td>{worker.cpu}</td>
+              <td>{new Date(worker.started * 1000).toLocaleString()}</td>
+            </tr>
+          ))}
+        </tbody>
+      </table>
+
+      <h2>Scheduler Daemon Control</h2>
+      <table>
+        <thead>
+          <tr>
+            <th>PID</th>
+            <th>Memory %</th>
+            <th>CPU %</th>
+            <th>Started</th>
+          </tr>
+        </thead>
+        <tbody>
+          {schedulerWorkers.map(worker => (
+            <tr>
+              <td>{worker.pid}</td>
+              <td>{worker.mem}</td>
+              <td>{worker.cpu}</td>
+              <td>{new Date(worker.started * 1000).toLocaleString()}</td>
+            </tr>
+          ))}
+        </tbody>
+      </table>
+    </div>
); } From eaa2f7e09076d669884db1f83288bba60953c867 Mon Sep 17 00:00:00 2001 From: superstar54 Date: Tue, 3 Sep 2024 17:39:51 +0200 Subject: [PATCH 14/14] update workgraph group outputs, and broadcast_workgraph_state --- aiida_workgraph/engine/scheduler/scheduler.py | 73 +++++++++++++++---- 1 file changed, 59 insertions(+), 14 deletions(-) diff --git a/aiida_workgraph/engine/scheduler/scheduler.py b/aiida_workgraph/engine/scheduler/scheduler.py index 36177fa4..77ba9ca9 100644 --- a/aiida_workgraph/engine/scheduler/scheduler.py +++ b/aiida_workgraph/engine/scheduler/scheduler.py @@ -513,7 +513,6 @@ def setup(self) -> None: # self.ctx._workgraph[pk]["_execution_count"] = {} # data not to be persisted, because they are not serializable self._temp = {"awaitables": {}} - self.add_workgraph_subsriber() def launch_workgraph(self, pk: str) -> None: """Launch the workgraph.""" @@ -741,9 +740,6 @@ def continue_workgraph(self, pk: int) -> None: is_finished, _ = self.is_workgraph_finished(pk) if is_finished: self.finalize_workgraph(pk) - self.ctx._workgraph[pk]["_node"].set_process_state(Finished.LABEL) - self.ctx._workgraph[pk]["_node"].set_exit_status(0) - self.ctx._workgraph[pk]["_node"].seal() # remove the workgraph from the context del self.ctx._workgraph[pk] self.ctx.launched_workgraphs.remove(pk) @@ -1203,6 +1199,8 @@ def run_tasks( self.report(f"Error in task {name}: {e}", pk) self.continue_workgraph(pk) elif task["metadata"]["node_type"].upper() in ["GRAPH_BUILDER"]: + from aiida_workgraph.utils.control import create_workgraph_action + wg = self.run_executor(executor, [], kwargs, var_args, var_kwargs) wg.name = name wg.group_outputs = self.ctx._workgraph[pk]["_tasks"][name]["metadata"][ @@ -1212,7 +1210,7 @@ def run_tasks( try: wg.save(metadata=metadata, parent_pid=pk) process = wg.process - self.launch_workgraph(process.pk) + create_workgraph_action(process.pk) process.workgraph_pk = pk self.set_task_state_info(pk, name, "process", process) self.set_task_state_info(pk, name, "state", "RUNNING") @@ -1685,14 +1683,61 @@ def finalize_workgraph(self, pk: int) -> t.Optional[ExitCode]: names[1] ], ) - self.out_many(group_outputs) # output the new data - self.out("new_data", self.ctx._workgraph[pk]["_new_data"]) - self.out( - "execution_count", - orm.Int(self.ctx._workgraph[pk]["_execution_count"]).store(), - ) + if self.ctx._workgraph[pk]["_new_data"]: + group_outputs["new_data"] = self.ctx._workgraph[pk]["_new_data"] + group_outputs["execution_count"] = orm.Int( + self.ctx._workgraph[pk]["_execution_count"] + ).store() + self.update_workgraph_output(pk, group_outputs) self.report("Finalize workgraph.", pk) - for _, task in self.ctx._workgraph[pk]["_tasks"].items(): - if self.get_task_state_info(pk, task["name"], "state") == "FAILED": - return self.exit_codes.TASK_FAILED + self.report(f"Finalize workgraph {pk}.") + for name, task in self.ctx._workgraph[pk]["_tasks"].items(): + if self.get_task_state_info(pk, name, "state") == "FAILED": + self.report(f"Task {name} failed.", pk) + self.ctx._workgraph[pk]["_node"].set_exit_status(302) + self.ctx._workgraph[pk]["_node"].set_exit_message( + "Some of the tasks failed." 
+                )
+                self.broadcast_workgraph_state(pk, "excepted")
+                self.ctx._workgraph[pk]["_node"].set_process_state(Finished.LABEL)
+                self.ctx._workgraph[pk]["_node"].seal()
+                return
+        # send a broadcast message to announce the workgraph is finished
+        self.broadcast_workgraph_state(pk, "finished")
+        self.ctx._workgraph[pk]["_node"].set_process_state(Finished.LABEL)
+        self.ctx._workgraph[pk]["_node"].seal()
+
+    def update_workgraph_output(self, pk, outputs: dict) -> None:
+        """Update the workgraph output."""
+        from aiida.common.links import LinkType
+
+        node = self.ctx._workgraph[pk]["_node"]
+
+        def add_link(node, outputs):
+            for link_label, output in outputs.items():
+                if isinstance(output, dict):
+                    # Nested namespace: recurse, and skip adding a link for the dict itself.
+                    add_link(node, output)
+                    continue
+                if isinstance(node, orm.CalculationNode):
+                    output.base.links.add_incoming(node, LinkType.CREATE, link_label)
+                elif isinstance(node, orm.WorkflowNode):
+                    output.base.links.add_incoming(node, LinkType.RETURN, link_label)
+
+        add_link(node, outputs)
+
+    def broadcast_workgraph_state(self, pk: int, state: str) -> None:
+        """Broadcast the state change of a workgraph."""
+        from aio_pika.exceptions import ConnectionClosed
+
+        from_label = None
+        subject = f"state_changed.{from_label}.{state}"
+        try:
+            self._communicator.broadcast_send(body=None, sender=pk, subject=subject)
+        except ConnectionClosed:
+            message = "Process<%s>: no connection available to broadcast state change from %s to %s"
+            self.logger.warning(message, pk, from_label, state)
+        except kiwipy.TimeoutError:
+            message = (
+                "Process<%s>: sending broadcast of state change from %s to %s timed out"
+            )
+            self.logger.warning(message, pk, from_label, state)
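Taken together, the series supports the flow exercised in `tests/test_scheduler.py`. A hedged end-to-end sketch follows; it assumes schedulers started with `workgraph scheduler start`, a `decorated_add` calcfunction as in the test fixtures, and that `WorkGraph` is importable from the package top level as in the docs examples.

```python
# End-to-end sketch: build a two-task WorkGraph and hand it to the scheduler.
from aiida_workgraph import WorkGraph  # assumption: top-level export

wg = WorkGraph("test_scheduler")
add1 = wg.add_task(decorated_add, x=2, y=3)  # decorated_add: assumed task fixture
add2 = wg.add_task(decorated_add, "add2", x=3, y=add1.outputs["result"])

# Route through the shared workgraph queue (any scheduler picks it up)...
wg.submit(to_scheduler=True, wait=True)
# ...or pin the WorkGraph to one scheduler by its process pk:
# wg.submit(to_scheduler=scheduler_pk)
```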