Skip to content

Commit

Permalink
PythonJob: check duplicate entry points for data serializer (#140)
Browse files Browse the repository at this point in the history
* check duplicate entry points for the data serializer
* support configuration file for workgraph: `workgraph.json` in the aiida configuration directory
  • Loading branch information
superstar54 authored Jul 10, 2024
1 parent f750392 commit 7192fe6
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 10 deletions.
13 changes: 13 additions & 0 deletions aiida_workgraph/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import json
from aiida.manage.configuration.settings import AIIDA_CONFIG_FOLDER


def load_config() -> dict:
    """Load the WorkGraph configuration from ``workgraph.json``.

    The file is looked up in the AiiDA configuration folder
    (``AIIDA_CONFIG_FOLDER``). A missing file is not an error: it simply
    means that no option has been set.

    Returns:
        The parsed configuration, or an empty dict if the file does not exist.

    Raises:
        json.JSONDecodeError: If the file is not valid JSON. The message
            names the offending file.
        ValueError: If the top-level JSON value is not an object.
    """
    config_file_path = AIIDA_CONFIG_FOLDER / "workgraph.json"
    try:
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with config_file_path.open("r", encoding="utf-8") as f:
            config = json.load(f)
    except FileNotFoundError:
        return {}
    except json.JSONDecodeError as exc:
        # Re-raise with the same exception type, but tell the user which
        # file has to be fixed.
        raise json.JSONDecodeError(
            f"Invalid JSON in {config_file_path}: {exc.msg}", exc.doc, exc.pos
        ) from exc
    if not isinstance(config, dict):
        # Callers use ``config.get(...)``; fail here with a clear message
        # instead of an AttributeError far away from the cause.
        raise ValueError(
            f"Invalid configuration in {config_file_path}: expected a JSON "
            f"object at the top level, got {type(config).__name__}"
        )
    return config
42 changes: 39 additions & 3 deletions aiida_workgraph/orm/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,46 @@
from aiida import orm, common
from importlib.metadata import entry_points
from typing import Any
from aiida_workgraph.config import load_config


# Retrieve the entry points for 'aiida.data' and store them in a dictionary
eps = {ep.name: ep for ep in entry_points().get("aiida.data", [])}
def get_serializer_from_entry_points() -> dict:
    """Collect the data serializers registered in the ``aiida.data`` group.

    The text before the first ``.`` of an entry point name is dropped and
    the remainder is used as the lookup key (e.g. ``workgraph.ase.atoms.Atoms``
    gives the key ``ase.atoms.Atoms``). Entry points whose key contains no
    ``.`` are ignored because the key is not a ``module.Class`` path.

    Two options are read from the configuration returned by ``load_config``:

    * ``excludes``: keys that must be ignored.
    * ``select``: mapping ``key -> entry point name`` that picks the entry
      point to use when several entry points share the same key.

    Returns:
        A dict mapping each key to a list holding the entry point(s)
        retained for that key.

    Raises:
        ValueError: If several entry points share a key and ``select`` has
            no entry for it, or if the selected name matches none of them.
    """
    configs = load_config()
    excludes = configs.get("excludes", [])
    selects = configs.get("select", {})

    try:
        # Keyword selection is the supported API on Python >= 3.10; the
        # dict-style ``entry_points().get(...)`` was removed in Python 3.12.
        data_entry_points = entry_points(group="aiida.data")
    except TypeError:
        # Python < 3.10: ``entry_points()`` accepts no arguments and returns
        # a dict keyed by group name.
        data_entry_points = entry_points().get("aiida.data", [])

    # Group the entry points by key so duplicates can be detected.
    eps = {}
    for ep in data_entry_points:
        key = ep.name.split(".", 1)[-1]
        if "." not in key or key in excludes:
            continue
        eps.setdefault(key, []).append(ep)

    # Resolve duplicates. Only values of existing keys are replaced, so
    # iterating over ``eps.items()`` while assigning is safe.
    for key, candidates in eps.items():
        if len(candidates) < 2:
            continue
        if key not in selects:
            msg = (
                f"Duplicate entry points for {key}: "
                f"{[ep.name for ep in candidates]}"
            )
            raise ValueError(msg)
        chosen = [ep for ep in candidates if ep.name == selects[key]]
        if not chosen:
            raise ValueError(f"Entry point {selects[key]} not found for {key}")
        eps[key] = chosen
    return eps


# Built once at import time: maps keys such as "ase.atoms.Atoms" to the list
# of matching `aiida.data` entry points.
eps = get_serializer_from_entry_points()


def serialize_to_aiida_nodes(inputs: dict = None) -> dict:
Expand Down Expand Up @@ -53,7 +89,7 @@ def general_serializer(data: Any, check_value=True) -> orm.Node:
# search for the key in the entry points
if ep_key in eps:
try:
new_node = eps[ep_key].load()(data)
new_node = eps[ep_key][0].load()(data)
except Exception as e:
raise ValueError(f"Error in serializing {ep_key}: {e}")
finally:
Expand Down
30 changes: 30 additions & 0 deletions docs/source/built-in/pythonjob.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2365,6 +2365,36 @@
"source": [
"We can see that the `result.txt` file is retrieved from the remote computer and stored in the local repository."
]
},
{
"cell_type": "markdown",
"id": "8d4d935b",
"metadata": {},
"source": [
"## Define your data serializer\n",
"Workgraph search data serializer from the `aiida.data` entry point by the module name and class name (e.g., `ase.atoms.Atoms`). \n",
"\n",
"In order to let the workgraph find the serializer, you must register the AiiDA data with the following format:\n",
"```\n",
"[project.entry-points.\"aiida.data\"]\n",
"abc.ase.atoms.Atoms = \"abc.xyz:MyAtomsData\"\n",
"```\n",
"This will register a data serializer for `ase.atoms.Atoms` data. `abc` is the plugin name, module name is `xyz`, and the AiiDA data class name is `AtomsData`. Learn how to create a AiiDA data [here](https://aiida.readthedocs.io/projects/aiida-core/en/stable/topics/data_types.html#adding-support-for-custom-data-types).\n",
"\n",
"\n",
"### Avoid duplicate data serializer\n",
"If you have multiple plugins that register the same data serializer, the workgraph will raise an error. You can avoid this by selecting the plugin that you want to use in the configuration file.\n",
"\n",
"```json\n",
"{\n",
" \"select\": {\n",
" \"ase.atoms.Atoms\": \"abc.ase.atoms.Atoms\"\n",
" },\n",
"}\n",
"```\n",
"\n",
"Save the configuration file as `workgraph.json` in the aiida configuration directory (by default, `~/.aiida` directory)."
]
}
],
"metadata": {
Expand Down
14 changes: 7 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,13 @@ workgraph = "aiida_workgraph.cli.cmd_workgraph:workgraph"

[project.entry-points."aiida.data"]
"workgraph.general" = "aiida_workgraph.orm.general_data:GeneralData"
"ase.atoms.Atoms" = "aiida_workgraph.orm.atoms:AtomsData"
"builtins.int" = "aiida.orm.nodes.data.int:Int"
"builtins.float" = "aiida.orm.nodes.data.float:Float"
"builtins.str" = "aiida.orm.nodes.data.str:Str"
"builtins.bool" = "aiida.orm.nodes.data.bool:Bool"
"builtins.list"="aiida_workgraph.orm.general_data:List"
"builtins.dict"="aiida_workgraph.orm.general_data:Dict"
"workgraph.ase.atoms.Atoms" = "aiida_workgraph.orm.atoms:AtomsData"
"workgraph.builtins.int" = "aiida.orm.nodes.data.int:Int"
"workgraph.builtins.float" = "aiida.orm.nodes.data.float:Float"
"workgraph.builtins.str" = "aiida.orm.nodes.data.str:Str"
"workgraph.builtins.bool" = "aiida.orm.nodes.data.bool:Bool"
"workgraph.builtins.list"="aiida_workgraph.orm.general_data:List"
"workgraph.builtins.dict"="aiida_workgraph.orm.general_data:Dict"


[project.entry-points."aiida.node"]
Expand Down

0 comments on commit 7192fe6

Please sign in to comment.