From 1bfae67ab6533c9dddf8066a40919c56bb7c53b5 Mon Sep 17 00:00:00 2001
From: zjgemi
Date: Thu, 17 Oct 2024 15:22:49 +0800
Subject: [PATCH] fix: HDF5Datasets with grouped slices

fix: None in HDF5Datasets

Signed-off-by: zjgemi
---
 src/dflow/python/opio.py  |  2 ++
 src/dflow/python/utils.py | 47 ++++++++++++++++++++++++++++++--------
 2 files changed, 38 insertions(+), 11 deletions(-)

diff --git a/src/dflow/python/opio.py b/src/dflow/python/opio.py
index aceed954..2ee946c3 100644
--- a/src/dflow/python/opio.py
+++ b/src/dflow/python/opio.py
@@ -34,6 +34,8 @@ def __deepcopy__(self, memo=None):
         return self
 
     def get_data(self):
+        if self.dataset.attrs.get("type") == "null":
+            return None
         data = self.dataset[()]
         if self.dataset.attrs.get("dtype") == "utf-8":
             data = data.decode("utf-8")
diff --git a/src/dflow/python/utils.py b/src/dflow/python/utils.py
index b9c7a164..c4204752 100644
--- a/src/dflow/python/utils.py
+++ b/src/dflow/python/utils.py
@@ -98,19 +98,22 @@ def handle_input_artifact(name, sign, slices=None, data_root="/tmp",
         if os.path.isfile(art_path):
             path_object = [art_path]
         assert isinstance(path_object, list)
-        res = None
+        res = []
        for path in path_object:
             f = h5py.File(path, "r")
             datasets = {k: HDF5Dataset(f, k) for k in f.keys()}
-            datasets = expand(datasets)
-            if isinstance(datasets, list):
-                if res is None:
-                    res = []
-                res += datasets
-            elif isinstance(datasets, dict):
-                if res is None:
-                    res = {}
+            if set(datasets.keys()) == {str(i) for i in range(len(datasets))} \
+                    and isinstance(res, list):
+                # keys are consecutive indices, i.e. a list: concatenate
+                res += expand(datasets)
+            else:
+                # otherwise merge into a flat dict keyed by slice index
+                if isinstance(res, list):
+                    res = flatten(res)
                 res.update(datasets)
+
+        if isinstance(res, dict):
+            res = expand(res)
         res = get_slices(res, slices)
     else:
         path_object = get_slices(path_object, slices)
@@ -211,8 +214,15 @@ def handle_output_artifact(name, value, sign, slices=None, data_root="/tmp",
         os.makedirs(data_root + '/outputs/artifacts/' + name, exist_ok=True)
         h5_name = "%s.h5" % uuid.uuid4()
         h5_path = '%s/outputs/artifacts/%s/%s' % (data_root, name, h5_name)
+        if isinstance(slices, list):
+            # grouped slices: key each item by its global slice index
+            assert isinstance(value, list) and len(slices) == len(value)
+            items = [(str(s), v) for s, v in zip(slices, value)]
+            slices = 0
+        else:
+            items = flatten(value).items()
         with h5py.File(h5_path, "w") as f:
-            for s, v in flatten(value).items():
+            for s, v in items:
                 if isinstance(v, Path):
                     if v.is_file():
                         try:
@@ -240,6 +250,9 @@ def handle_output_artifact(name, value, sign, slices=None, data_root="/tmp",
                 elif isinstance(v, HDF5Dataset):
                     d = f.create_dataset(s, data=v.dataset[()])
                     d.attrs.update(v.dataset.attrs)
+                elif v is None:
+                    d = f.create_dataset(s, data="")
+                    d.attrs["type"] = "null"
                 else:
                     d = f.create_dataset(s, data=v)
                     d.attrs["type"] = "data"
@@ -406,6 +419,16 @@ def absolutize(path):
         return {k: absolutize(p) for k, p in path.items()}
 
 
+def absolutize_hdf5(obj):
+    if isinstance(obj, Path):
+        return obj.absolute()
+    if isinstance(obj, list):
+        return [absolutize_hdf5(p) for p in obj]
+    if isinstance(obj, dict):
+        return {k: absolutize_hdf5(p) for k, p in obj.items()}
+    return obj
+
+
 def sigalrm_handler(signum, frame):
     raise TimeoutError("Timeout")
 
@@ -421,8 +444,10 @@ def try_to_execute(input, slice_dir, op_obj, output_sign, cwd, timeout=None):
     try:
         output = op_obj.execute(input)
         for n, s in output_sign.items():
-            if isinstance(s, Artifact):
+            if isinstance(s, Artifact) and s.type != HDF5Datasets:
                 output[n] = absolutize(output[n])
+            elif isinstance(s, Artifact) and s.type == HDF5Datasets:
+                output[n] = absolutize_hdf5(output[n])
         os.chdir(cwd)
         return output, None
     except Exception as e:
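
Reviewer note, not part of the patch: a minimal sketch of the None round
trip that the opio.py hunk and the "null" branch in utils.py enable,
written against plain h5py. The file name and dataset key are arbitrary;
only the attrs convention comes from the patch.

    import h5py

    # write side: None is stored as an empty string tagged "null",
    # matching the new branch in handle_output_artifact()
    with h5py.File("example.h5", "w") as f:
        d = f.create_dataset("0", data="")
        d.attrs["type"] = "null"

    # read side: get_data() now checks the tag before decoding the payload
    with h5py.File("example.h5", "r") as f:
        ds = f["0"]
        value = None if ds.attrs.get("type") == "null" else ds[()]
    assert value is None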
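A similar sketch for the grouped-slices write path: when one sub-step
produces several slices at once, handle_output_artifact() now keys each
item by its global slice index instead of re-numbering the list from zero
with flatten(), so the index-to-item mapping survives when the per-group
files are merged on the input side. The values below are illustrative.

    value = ["conf_3", "conf_4"]   # outputs of one grouped sub-step
    slices = [3, 4]                # global slice indices of this group
    items = [(str(s), v) for s, v in zip(slices, value)]
    assert items == [("3", "conf_3"), ("4", "conf_4")]
    # flatten(value) would have yielded keys "0" and "1" in every group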