Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
support building node from a callable (#46)
Browse files Browse the repository at this point in the history
One can build a node from a callable or the path of the callable.
superstar54 committed Apr 18, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 621dbcc commit 11c078b
Showing 15 changed files with 127 additions and 55 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/python-publish.yml
Original file line number Diff line number Diff line change
@@ -30,11 +30,14 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install build
- name: Build package
- name: Build web frontend package and widget
run:
npm install
npm run build
python -m build
cd aiida_worktree/widget/
npm install
npm run build
- name: Publish package
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
with:
2 changes: 1 addition & 1 deletion aiida_worktree/__init__.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,6 @@
from .decorator import node, build_node


__version__ = "0.1.6"
__version__ = "0.1.8"

__all__ = ["WorkTree", "Node", "node", "build_node"]
55 changes: 39 additions & 16 deletions aiida_worktree/decorator.py
Original file line number Diff line number Diff line change
@@ -30,45 +30,68 @@ def add_input_recursive(inputs, port, prefix=None):
return inputs


def build_node(ndata):
def build_node(executor, outputs=None):
"""Build node from executor."""
from aiida_worktree.worktree import WorkTree

if isinstance(ndata, WorkTree):
return build_node_from_worktree(ndata)
elif "path" in ndata:
return build_node_from_AiiDA(ndata)
if isinstance(executor, WorkTree):
return build_node_from_worktree(executor)
elif isinstance(executor, str):
(
path,
executor_name,
) = executor.rsplit(".", 1)
executor, _ = get_executor({"path": path, "name": executor_name})
if callable(executor):
return build_node_from_callable(executor, outputs=outputs)


def build_node_from_AiiDA(ndata):
"""Register a node from a AiiDA component.
For example: CalcJob, WorkChain, CalcFunction, WorkFunction."""
from aiida_worktree.node import Node
def build_node_from_callable(executor, outputs=None):
"""Build node from a callable object."""
import inspect

path, executor_name, = ndata.pop(
"path"
).rsplit(".", 1)
ndata["executor"] = {"path": path, "name": executor_name}
executor, type = get_executor(ndata["executor"])
# print(executor)
ndata = {}
if inspect.isfunction(executor):
# calcfunction and workfunction
if getattr(executor, "node_class", False):
if executor.node_class is CalcFunctionNode:
ndata["node_type"] = "calcfunction"
elif executor.node_class is WorkFunctionNode:
ndata["node_type"] = "workfunction"
ndata["executor"] = executor
return build_node_from_AiiDA(ndata)
else:
ndata["node_type"] = "normal"
ndata["executor"] = executor
return build_node_from_function(executor, outputs=outputs)
else:
if issubclass(executor, CalcJob):
ndata["node_type"] = "calcjob"
ndata["executor"] = executor
return build_node_from_AiiDA(ndata)
elif issubclass(executor, WorkChain):
ndata["node_type"] = "workchain"
ndata["executor"] = executor
return build_node_from_AiiDA(ndata)
else:
ndata["node_type"] = "normal"
ndata["executor"] = executor


def build_node_from_function(executor, outputs=None):
"""Build node from function."""
return NodeDecoratorCollection.decorator_node(outputs=outputs)(executor).node


def build_node_from_AiiDA(ndata):
"""Register a node from a AiiDA component.
For example: CalcJob, WorkChain, CalcFunction, WorkFunction."""
from aiida_worktree.node import Node

# print(executor)
inputs = []
outputs = []
executor = ndata["executor"]
spec = executor.spec()
for _key, port in spec.inputs.ports.items():
add_input_recursive(inputs, port)
@@ -89,7 +112,7 @@ def build_node_from_AiiDA(ndata):
ndata["kwargs"] = kwargs
ndata["inputs"] = inputs
ndata["outputs"] = outputs
ndata["identifier"] = ndata.pop("identifier", ndata["executor"]["name"])
ndata["identifier"] = ndata.pop("identifier", ndata["executor"].__name__)
# TODO In order to reload the WorkTree from process, "is_pickle" should be True
# so I pickled the function here, but this is not necessary
# we need to update the node_graph to support the path and name of the function
17 changes: 11 additions & 6 deletions docs/source/concept/node.ipynb
Original file line number Diff line number Diff line change
@@ -127,19 +127,24 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Build from AiiDA\n",
"One can build a node from an already existing AiiDA component: `calcfunction`, `workfunction`, `calcjob`, `Workchain` with the `build_node` function."
"## Build from Callable\n",
"\n",
"One can build a node from an already existing Python function, or AiiDA component: `calcfunction`, `workfunction`, `calcjob`, `Workchain` with the `build_node` function."
]
},
{
"cell_type": "code",
"execution_count": 48,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from aiida_worktree import build_node\n",
"ndata = {\"path\": \"aiida.calculations.arithmetic.add.ArithmeticAddCalculation\"}\n",
"AddNode = build_node(ndata)"
"\n",
"from scipy.linalg import norm\n",
"NormNode = build_node(norm)\n",
"\n",
"from aiida.calculations.arithmetic.add import ArithmeticAddCalculation\n",
"AddNode = build_node(ArithmeticAddCalculation)\n"
]
},
{
@@ -151,7 +156,7 @@
},
{
"cell_type": "code",
"execution_count": 49,
"execution_count": 6,
"metadata": {},
"outputs": [
{
2 changes: 1 addition & 1 deletion docs/source/howto/append_worktree.ipynb
Original file line number Diff line number Diff line change
@@ -151,7 +151,7 @@
"source": [
"from aiida_worktree import WorkTree, build_node\n",
"\n",
"PwRelaxChainNode = build_node({'path': 'aiida_quantumespresso.workflows.pw.relax.PwRelaxWorkChain'})\n",
"PwRelaxChainNode = build_node('aiida_quantumespresso.workflows.pw.relax.PwRelaxWorkChain')\n",
"\n",
"wt = WorkTree('Electronic Structure')\n",
"relax_node = wt.nodes.new(PwRelaxChainNode, name='relax')\n",
4 changes: 1 addition & 3 deletions docs/source/howto/protocol.ipynb
Original file line number Diff line number Diff line change
@@ -102,9 +102,7 @@
"from pprint import pprint\n",
"\n",
"# register node\n",
"pw_relax_node = build_node(\n",
" {\"path\": \"aiida_quantumespresso.workflows.pw.relax.PwRelaxWorkChain\"}\n",
")\n",
"pw_relax_node = build_node(\"aiida_quantumespresso.workflows.pw.relax.PwRelaxWorkChain\")\n",
"code = orm.load_code(\"pw-7.2@localhost\")\n",
"wt = WorkTree(\"test_pw_relax\")\n",
"structure_si = orm.StructureData(ase=bulk(\"Si\"))\n",
4 changes: 1 addition & 3 deletions docs/source/howto/queue.ipynb
Original file line number Diff line number Diff line change
@@ -83,9 +83,7 @@
"from aiida.orm import Int\n",
"\n",
"# Use the calcjob: ArithmeticAddCalculation\n",
"arithmetic_add = build_node(\n",
" {\"path\": \"aiida.calculations.arithmetic.add.ArithmeticAddCalculation\"}\n",
")\n",
"arithmetic_add = build_node(\"aiida.calculations.arithmetic.add.ArithmeticAddCalculation\")\n",
"code = load_code(\"add@localhost\")\n",
"\n",
"wt = WorkTree(\"test_max_number_jobs\")\n",
6 changes: 3 additions & 3 deletions docs/source/quick_start.ipynb
Original file line number Diff line number Diff line change
@@ -459,8 +459,8 @@
"outputs": [],
"source": [
"from aiida_worktree import build_node\n",
"ndata = {\"path\": \"aiida.calculations.arithmetic.add.ArithmeticAddCalculation\"}\n",
"ArithmeticAddCalculationNode = build_node(ndata)"
"from aiida.calculations.arithmetic.add import ArithmeticAddCalculation\n",
"ArithmeticAddCalculationNode = build_node(ArithmeticAddCalculation)"
]
},
{
@@ -1178,7 +1178,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.0"
"version": "3.10.0"
},
"vscode": {
"interpreter": {
4 changes: 2 additions & 2 deletions docs/source/tutorial/eos.ipynb
Original file line number Diff line number Diff line change
@@ -43,7 +43,7 @@
" \"\"\"Run the scf calculation for each structure.\"\"\"\n",
" from aiida_worktree import WorkTree, build_node\n",
" # register PwCalculation calcjob as a node class\n",
" PwCalculation = build_node({\"path\": \"aiida_quantumespresso.calculations.pw.PwCalculation\"})\n",
" PwCalculation = build_node(\"aiida_quantumespresso.calculations.pw.PwCalculation\")\n",
" wt = WorkTree()\n",
" for key, structure in structures.items():\n",
" pw1 = wt.nodes.new(PwCalculation, name=f\"pw1_{key}\", structure=structure)\n",
@@ -348,7 +348,7 @@
"from aiida_worktree import WorkTree, build_node\n",
"from copy import deepcopy\n",
"# register PwCalculation calcjob as a node class\n",
"PwCalculation = build_node({\"path\": \"aiida_quantumespresso.calculations.pw.PwCalculation\"})\n",
"PwCalculation = build_node(\"aiida_quantumespresso.calculations.pw.PwCalculation\")\n",
"\n",
"#-------------------------------------------------------\n",
"relax_pw_paras = deepcopy(pw_paras)\n",
19 changes: 8 additions & 11 deletions docs/source/tutorial/qe.ipynb
Original file line number Diff line number Diff line change
@@ -47,7 +47,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 1,
"id": "bfffb91f",
"metadata": {},
"outputs": [
@@ -57,7 +57,7 @@
"Profile<uuid='bcf9e395e4bf4b64a0a705d8659c0a9c' name='default'>"
]
},
"execution_count": 5,
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
@@ -87,16 +87,16 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 2,
"id": "11e3bca1-dda6-44e9-9585-54feeda7e7db",
"metadata": {},
"outputs": [],
"source": [
"from aiida_worktree import build_node\n",
"from aiida_quantumespresso.calculations.pw import PwCalculation\n",
"\n",
"# register node\n",
"ndata = {\"path\": \"aiida_quantumespresso.calculations.pw.PwCalculation\"}\n",
"pw_node = build_node(ndata)"
"pw_node = build_node(PwCalculation)"
]
},
{
@@ -109,7 +109,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 3,
"id": "e5a9d4bf",
"metadata": {},
"outputs": [
@@ -1061,8 +1061,7 @@
"from aiida_worktree.decorator import build_node\n",
"\n",
"# register node\n",
"ndata = {\"path\": \"aiida_quantumespresso.workflows.pw.relax.PwRelaxWorkChain\"}\n",
"pw_relax_node = build_node(ndata)"
"pw_relax_node = build_node(\"aiida_quantumespresso.workflows.pw.relax.PwRelaxWorkChain\")"
]
},
{
@@ -1641,9 +1640,7 @@
"from pprint import pprint\n",
"\n",
"# register node\n",
"pw_relax_node = build_node(\n",
" {\"path\": \"aiida_quantumespresso.workflows.pw.relax.PwRelaxWorkChain\"}\n",
")\n",
"pw_relax_node = build_node(\"aiida_quantumespresso.workflows.pw.relax.PwRelaxWorkChain\")\n",
"code = orm.load_code(\"pw-7.2@localhost\")\n",
"wt = WorkTree(\"test_pw_relax\")\n",
"structure_si = orm.StructureData(ase=bulk(\"Si\"))\n",
6 changes: 3 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -8,7 +8,7 @@ def arithmetic_add():
"""Generate a node for test."""

arithmetic_add = build_node(
{"path": "aiida.calculations.arithmetic.add.ArithmeticAddCalculation"}
"aiida.calculations.arithmetic.add.ArithmeticAddCalculation"
)
return arithmetic_add

@@ -173,9 +173,9 @@ def build_workchain():
"""Generate a decorated node for test."""

from aiida_worktree import build_node
from aiida.workflows.arithmetic.multiply_add import MultiplyAddWorkChain

ndata = {"path": "aiida.workflows.arithmetic.multiply_add.MultiplyAddWorkChain"}
multiply_add = build_node(ndata)
multiply_add = build_node(MultiplyAddWorkChain)

return multiply_add

47 changes: 47 additions & 0 deletions tests/test_build_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from aiida_worktree import build_node, Node


def test_calcjob():
"""Generate a node for test."""
from aiida.calculations.arithmetic.add import ArithmeticAddCalculation

ArithmeticAddNode = build_node(ArithmeticAddCalculation)
assert issubclass(ArithmeticAddNode, Node)
# build from path
ArithmeticAddNode = build_node(
"aiida.calculations.arithmetic.add.ArithmeticAddCalculation"
)
assert issubclass(ArithmeticAddNode, Node)


def test_workchain():
from aiida.workflows.arithmetic.multiply_add import MultiplyAddWorkChain

MultiplyAddWorkNode = build_node(MultiplyAddWorkChain)
assert issubclass(MultiplyAddWorkNode, Node)
# build from path
MultiplyAddWorkNode = build_node(
"aiida.workflows.arithmetic.multiply_add.MultiplyAddWorkChain"
)
assert issubclass(MultiplyAddWorkNode, Node)


def test_calcfunction():
"""Generate a node for test."""
from aiida.engine import calcfunction

@calcfunction
def add(x, y):
"""Calculate the square root of a number."""
return x + y

AddNode = build_node(add)
assert issubclass(AddNode, Node)


def test_function():
"""Generate a node for test."""
from scipy.linalg import norm

AddNode = build_node(norm)
assert issubclass(AddNode, Node)
2 changes: 1 addition & 1 deletion tests/test_engine.py
Original file line number Diff line number Diff line change
@@ -36,7 +36,7 @@ def test_max_number_jobs():

# Use the calcjob: ArithmeticAddCalculation
arithmetic_add = build_node(
{"path": "aiida.calculations.arithmetic.add.ArithmeticAddCalculation"}
"aiida.calculations.arithmetic.add.ArithmeticAddCalculation"
)
code = load_code("add@localhost")

4 changes: 2 additions & 2 deletions tests/test_protocol.py
Original file line number Diff line number Diff line change
@@ -11,7 +11,7 @@ def test_pw_relax_protocol(structure_si):

# register node
pw_relax_node = build_node(
{"path": "aiida_quantumespresso.workflows.pw.relax.PwRelaxWorkChain"}
"aiida_quantumespresso.workflows.pw.relax.PwRelaxWorkChain"
)
code = orm.load_code("qe-7.2-pw@localhost")
wt = WorkTree("test_pw_relax")
@@ -35,7 +35,7 @@ def test_pw_relax_protocol_pop(structure_si):

# register node
pw_relax_node = build_node(
{"path": "aiida_quantumespresso.workflows.pw.relax.PwRelaxWorkChain"}
"aiida_quantumespresso.workflows.pw.relax.PwRelaxWorkChain"
)
code = orm.load_code("qe-7.2-pw@localhost")
wt = WorkTree("test_pw_relax")
5 changes: 3 additions & 2 deletions tests/test_qe.py
Original file line number Diff line number Diff line change
@@ -129,8 +129,9 @@ def test_pw_relax_workchain(structure_si):
from aiida.orm import Dict, KpointsData, load_code, load_group

# register node
ndata = {"path": "aiida_quantumespresso.workflows.pw.relax.PwRelaxWorkChain"}
pw_relax_node = build_node(ndata)
pw_relax_node = build_node(
"aiida_quantumespresso.workflows.pw.relax.PwRelaxWorkChain"
)

@node.calcfunction()
def pw_parameters(paras, relax_type):

0 comments on commit 11c078b

Please sign in to comment.