From b52ccb942ff38b8cf80b4a35875ed507aced1bf0 Mon Sep 17 00:00:00 2001
From: zhanghao
Date: Tue, 19 Dec 2023 21:01:31 +0800
Subject: [PATCH] aggregating new data class

---
 dptb/data/build.py     | 10 ++++++++++
 dptb/utils/argcheck.py | 20 +++++++++++---------
 2 files changed, 21 insertions(+), 9 deletions(-)

diff --git a/dptb/data/build.py b/dptb/data/build.py
index 62cdcee0..52bd907c 100644
--- a/dptb/data/build.py
+++ b/dptb/data/build.py
@@ -125,19 +125,29 @@ def build_dataset(set_options, common_options):
     if type == "ABACUSDataset":
         assert "pbc" in set_options, "PBC must be provided in `data_options` when loading ABACUS dataset."
         AtomicDataOptions["pbc"] = set_options["pbc"]
+        if "basis" in common_options:
+            idp = OrbitalMapper(common_options["basis"])
+        else:
+            idp = None
         dataset = ABACUSDataset(
             root=set_options["root"],
             preprocess_dir=set_options["preprocess_dir"],
             AtomicData_options=AtomicDataOptions,
+            type_mapper=idp,
         )
     elif type == "ABACUSInMemoryDataset":
         assert "pbc" in set_options, "PBC must be provided in `data_options` when loading ABACUS dataset."
         AtomicDataOptions["pbc"] = set_options["pbc"]
+        if "basis" in common_options:
+            idp = OrbitalMapper(common_options["basis"])
+        else:
+            idp = None
         dataset = ABACUSInMemoryDataset(
             root=set_options["root"],
             preprocess_dir=set_options["preprocess_dir"],
             include_frames=set_options.get("include_frames"),
             AtomicData_options=AtomicDataOptions,
+            type_mapper=idp,
         )
 
     # input in common_option for Default Dataset:
diff --git a/dptb/utils/argcheck.py b/dptb/utils/argcheck.py
index 4a33155a..69c3050f 100644
--- a/dptb/utils/argcheck.py
+++ b/dptb/utils/argcheck.py
@@ -190,11 +190,11 @@ def train_data_sub():
     doc_reduce_edge = ""
 
     args = [
+        Argument("type", str, optional=True, default="DefaultDataset", doc="The type of dataset"),
         Argument("root", str, optional=False, doc=doc_root),
-        Argument("preprocess_path", str, optional=False, doc=doc_preprocess_path),
-        Argument("file_names", list, optional=False, doc=doc_file_names),
+        Argument("preprocess_dir", str, optional=False, doc=doc_preprocess_path),
+        Argument("AtomicData_options", dict, optional=True, default={}, doc="The options for AtomicData class"),
         Argument("pbc", [bool, list], optional=True, default=True, doc=doc_pbc),
-        Argument("reduce_edge", bool, optional=True, default=True, doc=doc_reduce_edge)
     ]
 
     doc_train = ""
@@ -208,10 +208,11 @@ def validation_data_sub():
     doc_pbc = ""
 
     args = [
+        Argument("type", str, optional=True, default="DefaultDataset", doc="The type of dataset"),
         Argument("root", str, optional=False, doc=doc_root),
-        Argument("preprocess_path", str, optional=False, doc=doc_preprocess_path),
-        Argument("file_names", list, optional=False, doc=doc_file_names),
-        Argument("pbc", [bool, list], optional=True, default=True, doc=doc_pbc)
+        Argument("preprocess_dir", str, optional=False, doc=doc_preprocess_path),
+        Argument("AtomicData_options", dict, optional=True, default={}, doc="The options for AtomicData class"),
+        Argument("pbc", [bool, list], optional=True, default=True, doc=doc_pbc),
     ]
 
     doc_validation = ""
@@ -225,10 +226,11 @@ def reference_data_sub():
     doc_pbc = ""
 
     args = [
+        Argument("type", str, optional=True, default="DefaultDataset", doc="The type of dataset"),
         Argument("root", str, optional=False, doc=doc_root),
-        Argument("preprocess_path", str, optional=False, doc=doc_preprocess_path),
-        Argument("file_names", list, optional=False, doc=doc_file_names),
-        Argument("pbc", [bool, list], optional=True, default=True, doc=doc_pbc)
+        Argument("preprocess_dir", str, optional=False, doc=doc_preprocess_path),
+        Argument("AtomicData_options", dict, optional=True, default={}, doc="The options for AtomicData class"),
+        Argument("pbc", [bool, list], optional=True, default=True, doc=doc_pbc),
    ]
 
     doc_reference = ""
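
Illustrative usage (not part of the patch): a minimal sketch of how build_dataset might be called once the data options follow the renamed argcheck fields ("type", "root", "preprocess_dir", "AtomicData_options", "pbc") and common_options carries "basis", which is what lets an OrbitalMapper be attached as the dataset's type_mapper. All paths, the basis spec, and the AtomicData_options keys below are hypothetical placeholders, not values taken from the repository.

    # Hypothetical sketch under the assumptions above; placeholder values only.
    from dptb.data.build import build_dataset

    common_options = {
        # Assumed basis spec; presence of "basis" makes build_dataset construct
        # an OrbitalMapper and pass it to the dataset as type_mapper.
        "basis": {"Si": ["3s", "3p", "d*"]},
    }

    set_options = {
        "type": "ABACUSInMemoryDataset",                # new optional field, default "DefaultDataset"
        "root": "./data/train",                         # placeholder path
        "preprocess_dir": "./data/train/preprocessed",  # renamed from "preprocess_path"
        "AtomicData_options": {"r_max": 5.0},           # replaces "file_names"/"reduce_edge"; keys are assumed
        "pbc": True,                                    # required for ABACUS-type datasets
    }

    dataset = build_dataset(set_options=set_options, common_options=common_options)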