From 11c078b877e15f2a675c185aa6054e94477294e3 Mon Sep 17 00:00:00 2001 From: Xing Wang Date: Thu, 18 Apr 2024 17:40:19 +0200 Subject: [PATCH] support building node from a callable (#46) One can build a node from a callable or the path of the callable. --- .github/workflows/python-publish.yml | 5 ++- aiida_worktree/__init__.py | 2 +- aiida_worktree/decorator.py | 55 ++++++++++++++++++------- docs/source/concept/node.ipynb | 17 +++++--- docs/source/howto/append_worktree.ipynb | 2 +- docs/source/howto/protocol.ipynb | 4 +- docs/source/howto/queue.ipynb | 4 +- docs/source/quick_start.ipynb | 6 +-- docs/source/tutorial/eos.ipynb | 4 +- docs/source/tutorial/qe.ipynb | 19 ++++----- tests/conftest.py | 6 +-- tests/test_build_node.py | 47 +++++++++++++++++++++ tests/test_engine.py | 2 +- tests/test_protocol.py | 4 +- tests/test_qe.py | 5 ++- 15 files changed, 127 insertions(+), 55 deletions(-) create mode 100644 tests/test_build_node.py diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index 28264b18..1c9e7a8c 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -30,11 +30,14 @@ jobs: run: | python -m pip install --upgrade pip pip install build - - name: Build package + - name: Build web frontend package and widget run: npm install npm run build python -m build + cd aiida_worktree/widget/ + npm install + npm run build - name: Publish package uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 with: diff --git a/aiida_worktree/__init__.py b/aiida_worktree/__init__.py index 71654339..983093f2 100644 --- a/aiida_worktree/__init__.py +++ b/aiida_worktree/__init__.py @@ -3,6 +3,6 @@ from .decorator import node, build_node -__version__ = "0.1.6" +__version__ = "0.1.8" __all__ = ["WorkTree", "Node", "node", "build_node"] diff --git a/aiida_worktree/decorator.py b/aiida_worktree/decorator.py index afc42896..42455295 100644 --- a/aiida_worktree/decorator.py +++ b/aiida_worktree/decorator.py @@ -30,27 +30,27 @@ def add_input_recursive(inputs, port, prefix=None): return inputs -def build_node(ndata): +def build_node(executor, outputs=None): + """Build node from executor.""" from aiida_worktree.worktree import WorkTree - if isinstance(ndata, WorkTree): - return build_node_from_worktree(ndata) - elif "path" in ndata: - return build_node_from_AiiDA(ndata) + if isinstance(executor, WorkTree): + return build_node_from_worktree(executor) + elif isinstance(executor, str): + ( + path, + executor_name, + ) = executor.rsplit(".", 1) + executor, _ = get_executor({"path": path, "name": executor_name}) + if callable(executor): + return build_node_from_callable(executor, outputs=outputs) -def build_node_from_AiiDA(ndata): - """Register a node from a AiiDA component. - For example: CalcJob, WorkChain, CalcFunction, WorkFunction.""" - from aiida_worktree.node import Node +def build_node_from_callable(executor, outputs=None): + """Build node from a callable object.""" import inspect - path, executor_name, = ndata.pop( - "path" - ).rsplit(".", 1) - ndata["executor"] = {"path": path, "name": executor_name} - executor, type = get_executor(ndata["executor"]) - # print(executor) + ndata = {} if inspect.isfunction(executor): # calcfunction and workfunction if getattr(executor, "node_class", False): @@ -58,17 +58,40 @@ def build_node_from_AiiDA(ndata): ndata["node_type"] = "calcfunction" elif executor.node_class is WorkFunctionNode: ndata["node_type"] = "workfunction" + ndata["executor"] = executor + return build_node_from_AiiDA(ndata) else: ndata["node_type"] = "normal" + ndata["executor"] = executor + return build_node_from_function(executor, outputs=outputs) else: if issubclass(executor, CalcJob): ndata["node_type"] = "calcjob" + ndata["executor"] = executor + return build_node_from_AiiDA(ndata) elif issubclass(executor, WorkChain): ndata["node_type"] = "workchain" + ndata["executor"] = executor + return build_node_from_AiiDA(ndata) else: ndata["node_type"] = "normal" + ndata["executor"] = executor + + +def build_node_from_function(executor, outputs=None): + """Build node from function.""" + return NodeDecoratorCollection.decorator_node(outputs=outputs)(executor).node + + +def build_node_from_AiiDA(ndata): + """Register a node from a AiiDA component. + For example: CalcJob, WorkChain, CalcFunction, WorkFunction.""" + from aiida_worktree.node import Node + + # print(executor) inputs = [] outputs = [] + executor = ndata["executor"] spec = executor.spec() for _key, port in spec.inputs.ports.items(): add_input_recursive(inputs, port) @@ -89,7 +112,7 @@ def build_node_from_AiiDA(ndata): ndata["kwargs"] = kwargs ndata["inputs"] = inputs ndata["outputs"] = outputs - ndata["identifier"] = ndata.pop("identifier", ndata["executor"]["name"]) + ndata["identifier"] = ndata.pop("identifier", ndata["executor"].__name__) # TODO In order to reload the WorkTree from process, "is_pickle" should be True # so I pickled the function here, but this is not necessary # we need to update the node_graph to support the path and name of the function diff --git a/docs/source/concept/node.ipynb b/docs/source/concept/node.ipynb index 53509323..962206a1 100644 --- a/docs/source/concept/node.ipynb +++ b/docs/source/concept/node.ipynb @@ -127,19 +127,24 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Build from AiiDA\n", - "One can build a node from an already existing AiiDA component: `calcfunction`, `workfunction`, `calcjob`, `Workchain` with the `build_node` function." + "## Build from Callable\n", + "\n", + "One can build a node from an already existing Python function, or AiiDA component: `calcfunction`, `workfunction`, `calcjob`, `Workchain` with the `build_node` function." ] }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "from aiida_worktree import build_node\n", - "ndata = {\"path\": \"aiida.calculations.arithmetic.add.ArithmeticAddCalculation\"}\n", - "AddNode = build_node(ndata)" + "\n", + "from scipy.linalg import norm\n", + "NormNode = build_node(norm)\n", + "\n", + "from aiida.calculations.arithmetic.add import ArithmeticAddCalculation\n", + "AddNode = build_node(ArithmeticAddCalculation)\n" ] }, { @@ -151,7 +156,7 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": 6, "metadata": {}, "outputs": [ { diff --git a/docs/source/howto/append_worktree.ipynb b/docs/source/howto/append_worktree.ipynb index 3841a28d..a61bb7a5 100644 --- a/docs/source/howto/append_worktree.ipynb +++ b/docs/source/howto/append_worktree.ipynb @@ -151,7 +151,7 @@ "source": [ "from aiida_worktree import WorkTree, build_node\n", "\n", - "PwRelaxChainNode = build_node({'path': 'aiida_quantumespresso.workflows.pw.relax.PwRelaxWorkChain'})\n", + "PwRelaxChainNode = build_node('aiida_quantumespresso.workflows.pw.relax.PwRelaxWorkChain')\n", "\n", "wt = WorkTree('Electronic Structure')\n", "relax_node = wt.nodes.new(PwRelaxChainNode, name='relax')\n", diff --git a/docs/source/howto/protocol.ipynb b/docs/source/howto/protocol.ipynb index 96b31555..1f491760 100644 --- a/docs/source/howto/protocol.ipynb +++ b/docs/source/howto/protocol.ipynb @@ -102,9 +102,7 @@ "from pprint import pprint\n", "\n", "# register node\n", - "pw_relax_node = build_node(\n", - " {\"path\": \"aiida_quantumespresso.workflows.pw.relax.PwRelaxWorkChain\"}\n", - ")\n", + "pw_relax_node = build_node(\"aiida_quantumespresso.workflows.pw.relax.PwRelaxWorkChain\")\n", "code = orm.load_code(\"pw-7.2@localhost\")\n", "wt = WorkTree(\"test_pw_relax\")\n", "structure_si = orm.StructureData(ase=bulk(\"Si\"))\n", diff --git a/docs/source/howto/queue.ipynb b/docs/source/howto/queue.ipynb index 9328e220..5ea794f8 100644 --- a/docs/source/howto/queue.ipynb +++ b/docs/source/howto/queue.ipynb @@ -83,9 +83,7 @@ "from aiida.orm import Int\n", "\n", "# Use the calcjob: ArithmeticAddCalculation\n", - "arithmetic_add = build_node(\n", - " {\"path\": \"aiida.calculations.arithmetic.add.ArithmeticAddCalculation\"}\n", - ")\n", + "arithmetic_add = build_node(\"aiida.calculations.arithmetic.add.ArithmeticAddCalculation\")\n", "code = load_code(\"add@localhost\")\n", "\n", "wt = WorkTree(\"test_max_number_jobs\")\n", diff --git a/docs/source/quick_start.ipynb b/docs/source/quick_start.ipynb index b0d0afe2..a1d2ed13 100644 --- a/docs/source/quick_start.ipynb +++ b/docs/source/quick_start.ipynb @@ -459,8 +459,8 @@ "outputs": [], "source": [ "from aiida_worktree import build_node\n", - "ndata = {\"path\": \"aiida.calculations.arithmetic.add.ArithmeticAddCalculation\"}\n", - "ArithmeticAddCalculationNode = build_node(ndata)" + "from aiida.calculations.arithmetic.add import ArithmeticAddCalculation\n", + "ArithmeticAddCalculationNode = build_node(ArithmeticAddCalculation)" ] }, { @@ -1178,7 +1178,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.0" + "version": "3.10.0" }, "vscode": { "interpreter": { diff --git a/docs/source/tutorial/eos.ipynb b/docs/source/tutorial/eos.ipynb index a683115f..f990febf 100644 --- a/docs/source/tutorial/eos.ipynb +++ b/docs/source/tutorial/eos.ipynb @@ -43,7 +43,7 @@ " \"\"\"Run the scf calculation for each structure.\"\"\"\n", " from aiida_worktree import WorkTree, build_node\n", " # register PwCalculation calcjob as a node class\n", - " PwCalculation = build_node({\"path\": \"aiida_quantumespresso.calculations.pw.PwCalculation\"})\n", + " PwCalculation = build_node(\"aiida_quantumespresso.calculations.pw.PwCalculation\")\n", " wt = WorkTree()\n", " for key, structure in structures.items():\n", " pw1 = wt.nodes.new(PwCalculation, name=f\"pw1_{key}\", structure=structure)\n", @@ -348,7 +348,7 @@ "from aiida_worktree import WorkTree, build_node\n", "from copy import deepcopy\n", "# register PwCalculation calcjob as a node class\n", - "PwCalculation = build_node({\"path\": \"aiida_quantumespresso.calculations.pw.PwCalculation\"})\n", + "PwCalculation = build_node(\"aiida_quantumespresso.calculations.pw.PwCalculation\")\n", "\n", "#-------------------------------------------------------\n", "relax_pw_paras = deepcopy(pw_paras)\n", diff --git a/docs/source/tutorial/qe.ipynb b/docs/source/tutorial/qe.ipynb index a295f634..3e894273 100644 --- a/docs/source/tutorial/qe.ipynb +++ b/docs/source/tutorial/qe.ipynb @@ -47,7 +47,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 1, "id": "bfffb91f", "metadata": {}, "outputs": [ @@ -57,7 +57,7 @@ "Profile" ] }, - "execution_count": 5, + "execution_count": 1, "metadata": {}, "output_type": "execute_result" } @@ -87,16 +87,16 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 2, "id": "11e3bca1-dda6-44e9-9585-54feeda7e7db", "metadata": {}, "outputs": [], "source": [ "from aiida_worktree import build_node\n", + "from aiida_quantumespresso.calculations.pw import PwCalculation\n", "\n", "# register node\n", - "ndata = {\"path\": \"aiida_quantumespresso.calculations.pw.PwCalculation\"}\n", - "pw_node = build_node(ndata)" + "pw_node = build_node(PwCalculation)" ] }, { @@ -109,7 +109,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 3, "id": "e5a9d4bf", "metadata": {}, "outputs": [ @@ -1061,8 +1061,7 @@ "from aiida_worktree.decorator import build_node\n", "\n", "# register node\n", - "ndata = {\"path\": \"aiida_quantumespresso.workflows.pw.relax.PwRelaxWorkChain\"}\n", - "pw_relax_node = build_node(ndata)" + "pw_relax_node = build_node(\"aiida_quantumespresso.workflows.pw.relax.PwRelaxWorkChain\")" ] }, { @@ -1641,9 +1640,7 @@ "from pprint import pprint\n", "\n", "# register node\n", - "pw_relax_node = build_node(\n", - " {\"path\": \"aiida_quantumespresso.workflows.pw.relax.PwRelaxWorkChain\"}\n", - ")\n", + "pw_relax_node = build_node(\"aiida_quantumespresso.workflows.pw.relax.PwRelaxWorkChain\")\n", "code = orm.load_code(\"pw-7.2@localhost\")\n", "wt = WorkTree(\"test_pw_relax\")\n", "structure_si = orm.StructureData(ase=bulk(\"Si\"))\n", diff --git a/tests/conftest.py b/tests/conftest.py index ef13e767..67902499 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,7 +8,7 @@ def arithmetic_add(): """Generate a node for test.""" arithmetic_add = build_node( - {"path": "aiida.calculations.arithmetic.add.ArithmeticAddCalculation"} + "aiida.calculations.arithmetic.add.ArithmeticAddCalculation" ) return arithmetic_add @@ -173,9 +173,9 @@ def build_workchain(): """Generate a decorated node for test.""" from aiida_worktree import build_node + from aiida.workflows.arithmetic.multiply_add import MultiplyAddWorkChain - ndata = {"path": "aiida.workflows.arithmetic.multiply_add.MultiplyAddWorkChain"} - multiply_add = build_node(ndata) + multiply_add = build_node(MultiplyAddWorkChain) return multiply_add diff --git a/tests/test_build_node.py b/tests/test_build_node.py new file mode 100644 index 00000000..6e4ec005 --- /dev/null +++ b/tests/test_build_node.py @@ -0,0 +1,47 @@ +from aiida_worktree import build_node, Node + + +def test_calcjob(): + """Generate a node for test.""" + from aiida.calculations.arithmetic.add import ArithmeticAddCalculation + + ArithmeticAddNode = build_node(ArithmeticAddCalculation) + assert issubclass(ArithmeticAddNode, Node) + # build from path + ArithmeticAddNode = build_node( + "aiida.calculations.arithmetic.add.ArithmeticAddCalculation" + ) + assert issubclass(ArithmeticAddNode, Node) + + +def test_workchain(): + from aiida.workflows.arithmetic.multiply_add import MultiplyAddWorkChain + + MultiplyAddWorkNode = build_node(MultiplyAddWorkChain) + assert issubclass(MultiplyAddWorkNode, Node) + # build from path + MultiplyAddWorkNode = build_node( + "aiida.workflows.arithmetic.multiply_add.MultiplyAddWorkChain" + ) + assert issubclass(MultiplyAddWorkNode, Node) + + +def test_calcfunction(): + """Generate a node for test.""" + from aiida.engine import calcfunction + + @calcfunction + def add(x, y): + """Calculate the square root of a number.""" + return x + y + + AddNode = build_node(add) + assert issubclass(AddNode, Node) + + +def test_function(): + """Generate a node for test.""" + from scipy.linalg import norm + + AddNode = build_node(norm) + assert issubclass(AddNode, Node) diff --git a/tests/test_engine.py b/tests/test_engine.py index 97a1e526..88901949 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -36,7 +36,7 @@ def test_max_number_jobs(): # Use the calcjob: ArithmeticAddCalculation arithmetic_add = build_node( - {"path": "aiida.calculations.arithmetic.add.ArithmeticAddCalculation"} + "aiida.calculations.arithmetic.add.ArithmeticAddCalculation" ) code = load_code("add@localhost") diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 55a6aae0..cf5ddee4 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -11,7 +11,7 @@ def test_pw_relax_protocol(structure_si): # register node pw_relax_node = build_node( - {"path": "aiida_quantumespresso.workflows.pw.relax.PwRelaxWorkChain"} + "aiida_quantumespresso.workflows.pw.relax.PwRelaxWorkChain" ) code = orm.load_code("qe-7.2-pw@localhost") wt = WorkTree("test_pw_relax") @@ -35,7 +35,7 @@ def test_pw_relax_protocol_pop(structure_si): # register node pw_relax_node = build_node( - {"path": "aiida_quantumespresso.workflows.pw.relax.PwRelaxWorkChain"} + "aiida_quantumespresso.workflows.pw.relax.PwRelaxWorkChain" ) code = orm.load_code("qe-7.2-pw@localhost") wt = WorkTree("test_pw_relax") diff --git a/tests/test_qe.py b/tests/test_qe.py index f26e982d..653cc692 100644 --- a/tests/test_qe.py +++ b/tests/test_qe.py @@ -129,8 +129,9 @@ def test_pw_relax_workchain(structure_si): from aiida.orm import Dict, KpointsData, load_code, load_group # register node - ndata = {"path": "aiida_quantumespresso.workflows.pw.relax.PwRelaxWorkChain"} - pw_relax_node = build_node(ndata) + pw_relax_node = build_node( + "aiida_quantumespresso.workflows.pw.relax.PwRelaxWorkChain" + ) @node.calcfunction() def pw_parameters(paras, relax_type):