Merge pull request #866 from deepmodeling/zjgemi
fix: HDF5Datasets with grouped slices
zjgemi authored Oct 17, 2024
2 parents 7881a34 + 1bfae67 commit 4e26955
Showing 2 changed files with 38 additions and 11 deletions.
2 changes: 2 additions & 0 deletions src/dflow/python/opio.py
@@ -34,6 +34,8 @@ def __deepcopy__(self, memo=None):
         return self
 
     def get_data(self):
+        if self.dataset.attrs.get("type") == "null":
+            return None
         data = self.dataset[()]
         if self.dataset.attrs.get("dtype") == "utf-8":
             data = data.decode("utf-8")
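
For context, a small round-trip sketch of the new "null" sentinel, using plain h5py (the file path is illustrative; the read side mirrors get_data above, the write side mirrors the handle_output_artifact change below):

    import h5py

    # Write side: a None value is stored as an empty string tagged type="null".
    with h5py.File("/tmp/null_demo.h5", "w") as f:
        d = f.create_dataset("x", data="")
        d.attrs["type"] = "null"

    # Read side: mirrors HDF5Dataset.get_data() above.
    with h5py.File("/tmp/null_demo.h5", "r") as f:
        ds = f["x"]
        data = None if ds.attrs.get("type") == "null" else ds[()]
        assert data is None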
47 changes: 36 additions & 11 deletions src/dflow/python/utils.py
@@ -98,19 +98,22 @@ def handle_input_artifact(name, sign, slices=None, data_root="/tmp",
         if os.path.isfile(art_path):
             path_object = [art_path]
         assert isinstance(path_object, list)
-        res = None
+        res = []
         for path in path_object:
             f = h5py.File(path, "r")
             datasets = {k: HDF5Dataset(f, k) for k in f.keys()}
-            datasets = expand(datasets)
-            if isinstance(datasets, list):
-                if res is None:
-                    res = []
-                res += datasets
-            elif isinstance(datasets, dict):
-                if res is None:
-                    res = {}
+            if set(datasets.keys()) == {str(i) for i in range(len(datasets))} \
+                    and isinstance(res, list):
+                # concat when all datasets are lists
+                res += expand(datasets)
+            else:
+                # merge otherwise
+                if isinstance(res, list):
+                    res = flatten(res)
                 res.update(datasets)
+
+        if isinstance(res, dict):
+            res = expand(res)
         res = get_slices(res, slices)
     else:
         path_object = get_slices(path_object, slices)
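
The key-set test above decides between concatenation and merging. A self-contained sketch of that test, with plain strings standing in for HDF5Dataset handles (dflow's flatten/expand convert between nested structures and flat string-keyed dicts):

    # Sketch: files whose top-level keys are exactly "0".."n-1" look like a
    # list and are concatenated; any other key set goes through the merge path.
    def looks_like_list(datasets):
        return set(datasets.keys()) == {str(i) for i in range(len(datasets))}

    assert looks_like_list({"0": "a", "1": "b"})      # first slice group
    assert not looks_like_list({"2": "c", "3": "d"})  # later group: keys start at 2
    assert not looks_like_list({"name": "e"})         # named keys: plain dict

A later slice group keys its datasets "2", "3", ..., so it fails the list test and takes the merge path, where res is flattened and updated before the final expand rebuilds the full list; that is the grouped-slices layout this commit addresses.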
@@ -211,8 +214,15 @@ def handle_output_artifact(name, value, sign, slices=None, data_root="/tmp",
         os.makedirs(data_root + '/outputs/artifacts/' + name, exist_ok=True)
         h5_name = "%s.h5" % uuid.uuid4()
         h5_path = '%s/outputs/artifacts/%s/%s' % (data_root, name, h5_name)
+        if isinstance(slices, list):
+            # merge lists
+            assert isinstance(value, list) and len(slices) == len(value)
+            items = [(str(s), v) for s, v in zip(slices, value)]
+            slices = 0
+        else:
+            items = flatten(value).items()
         with h5py.File(h5_path, "w") as f:
-            for s, v in flatten(value).items():
+            for s, v in items:
                 if isinstance(v, Path):
                     if v.is_file():
                         try:
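
With grouped slices, slices arrives as a list parallel to value, and each item is keyed by its global slice index rather than by its position within the group. An illustrative pairing (the values are made up):

    # Sketch: a group that produced slices 3..5 stores its outputs under
    # the global keys "3", "4", "5", so downstream reads can reassemble them.
    slices = [3, 4, 5]
    value = ["out3", "out4", "out5"]
    assert isinstance(value, list) and len(slices) == len(value)
    items = [(str(s), v) for s, v in zip(slices, value)]
    assert items == [("3", "out3"), ("4", "out4"), ("5", "out5")]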
@@ -240,6 +250,9 @@ def handle_output_artifact(name, value, sign, slices=None, data_root="/tmp",
                 elif isinstance(v, HDF5Dataset):
                     d = f.create_dataset(s, data=v.dataset[()])
                     d.attrs.update(v.dataset.attrs)
+                elif v is None:
+                    d = f.create_dataset(s, data="")
+                    d.attrs["type"] = "null"
                 else:
                     d = f.create_dataset(s, data=v)
                     d.attrs["type"] = "data"
@@ -406,6 +419,16 @@ def absolutize(path):
         return {k: absolutize(p) for k, p in path.items()}
 
 
+def absolutize_hdf5(obj):
+    if isinstance(obj, Path):
+        return obj.absolute()
+    if isinstance(obj, list):
+        return [absolutize(p) for p in obj]
+    if isinstance(obj, dict):
+        return {k: absolutize(p) for k, p in obj.items()}
+    return obj
+
+
 def sigalrm_handler(signum, frame):
     raise TimeoutError("Timeout")
 
@@ -421,8 +444,10 @@ def try_to_execute(input, slice_dir, op_obj, output_sign, cwd, timeout=None):
     try:
         output = op_obj.execute(input)
         for n, s in output_sign.items():
-            if isinstance(s, Artifact):
+            if isinstance(s, Artifact) and s.type != HDF5Datasets:
                 output[n] = absolutize(output[n])
+            elif isinstance(s, Artifact) and s.type == HDF5Datasets:
+                output[n] = absolutize_hdf5(output[n])
         os.chdir(cwd)
         return output, None
     except Exception as e:
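
A standalone rendition of the new helper's intent, assuming only the standard library (the diff's version delegates list and dict elements to absolutize; this demo keeps just the Path and pass-through cases, and the name is hypothetical):

    from pathlib import Path

    # Sketch: unlike absolutize, which targets path-like outputs, the HDF5
    # variant must leave non-path payloads (scalars, None, dataset handles)
    # untouched while still absolutizing real Paths.
    def absolutize_hdf5_demo(obj):
        if isinstance(obj, Path):
            return obj.absolute()
        return obj  # anything else passes through unchanged

    assert absolutize_hdf5_demo(Path("out.txt")).is_absolute()
    assert absolutize_hdf5_demo(42) == 42
    assert absolutize_hdf5_demo(None) is None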
