diff --git a/datatree/datatree.py b/datatree/datatree.py index c86c2e2e..e556adbc 100644 --- a/datatree/datatree.py +++ b/datatree/datatree.py @@ -62,6 +62,8 @@ from xarray.core.merge import CoercibleValue from xarray.core.types import ErrorOptions + from datatree.treenode import T_PathLike + # """ # DEVELOPERS' NOTE # ---------------- @@ -76,9 +78,6 @@ # """ -T_Path = Union[str, NodePath] - - def _coerce_to_dataset(data: Dataset | DataArray | None) -> Dataset: if isinstance(data, DataArray): ds = data.to_dataset() @@ -848,7 +847,7 @@ def get( else: return default - def __getitem__(self: DataTree, key: str) -> DataTree | DataArray: + def __getitem__(self: DataTree, key: T_PathLike) -> DataTree | DataArray: """ Access child nodes, variables, or coordinates stored anywhere in this tree. @@ -903,7 +902,7 @@ def _set(self, key: str, val: DataTree | CoercibleValue) -> None: def __setitem__( self, - key: str, + key: T_PathLike, value: Any, ) -> None: """ @@ -1034,7 +1033,7 @@ def drop_nodes( @classmethod def from_dict( cls, - d: MutableMapping[str, Dataset | DataArray | DataTree | None], + d: MutableMapping[T_PathLike, Dataset | DataArray | DataTree | None], name: Optional[str] = None, ) -> DataTree: """ @@ -1442,7 +1441,7 @@ def merge(self, datatree: DataTree) -> DataTree: """Merge all the leaves of a second DataTree into this one.""" raise NotImplementedError - def merge_child_nodes(self, *paths, new_path: T_Path) -> DataTree: + def merge_child_nodes(self, *paths, new_path: T_PathLike) -> DataTree: """Merge a set of child nodes into a single new node.""" raise NotImplementedError diff --git a/datatree/mapping.py b/datatree/mapping.py index 34e227d3..4e577a8f 100644 --- a/datatree/mapping.py +++ b/datatree/mapping.py @@ -4,7 +4,7 @@ import sys from itertools import repeat from textwrap import dedent -from typing import TYPE_CHECKING, Callable, Tuple +from typing import TYPE_CHECKING, Callable, Tuple, Union from xarray import DataArray, Dataset @@ -228,7 +228,7 @@ def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]: original_root_path = first_tree.path result_trees = [] for i in range(num_return_values): - out_tree_contents = {} + out_tree_contents: dict[str, Union[None, Dataset]] = {} for n in first_tree.subtree: p = n.path if p in out_data_objects.keys(): diff --git a/datatree/treenode.py b/datatree/treenode.py index 1689d261..a0b8acdd 100644 --- a/datatree/treenode.py +++ b/datatree/treenode.py @@ -17,8 +17,12 @@ from xarray.core.utils import Frozen, is_dict_like if TYPE_CHECKING: + from os import PathLike + from xarray.core.types import T_DataArray + T_PathLike = Union[str, PathLike] + class InvalidTreeError(Exception): """Raised when user attempts to create an invalid tree in some way.""" @@ -445,14 +449,13 @@ def get(self: Tree, key: str, default: Optional[Tree] = None) -> Optional[Tree]: # TODO `._walk` method to be called by both `_get_item` and `_set_item` - def _get_item(self: Tree, path: str | NodePath) -> Union[Tree, T_DataArray]: + def _get_item(self: Tree, path: T_PathLike) -> Union[Tree, T_DataArray]: """ Returns the object lying at the given path. Raises a KeyError if there is no object at the given path. """ - if isinstance(path, str): - path = NodePath(path) + path = NodePath(path) if path.root: current_node = self.root @@ -487,7 +490,7 @@ def _set(self: Tree, key: str, val: Tree) -> None: def _set_item( self: Tree, - path: str | NodePath, + path: T_PathLike, item: Union[Tree, T_DataArray], new_nodes_along_path: bool = False, allow_overwrite: bool = True, @@ -513,8 +516,7 @@ def _set_item( If node cannot be reached, and new_nodes_along_path=False. Or if a node already exists at the specified path, and allow_overwrite=False. """ - if isinstance(path, str): - path = NodePath(path) + path = NodePath(path) if not path.name: raise ValueError("Can't set an item under a path which has no name") diff --git a/docs/source/whats-new.rst b/docs/source/whats-new.rst index 2f6e4f88..6b81230a 100644 --- a/docs/source/whats-new.rst +++ b/docs/source/whats-new.rst @@ -23,6 +23,9 @@ v0.0.14 (unreleased) New Features ~~~~~~~~~~~~ +- Allow passing :py:class:`os.PathLike` objects as paths to nodes in addition to strings. (:pull:`282`) + By `Tom Nicholas `_. + Breaking changes ~~~~~~~~~~~~~~~~