Skip to content

Commit

Permalink
aggregating new data class
Browse files Browse the repository at this point in the history
  • Loading branch information
floatingCatty committed Dec 19, 2023
1 parent a9f4a75 commit b52ccb9
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 9 deletions.
10 changes: 10 additions & 0 deletions dptb/data/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,19 +125,29 @@ def build_dataset(set_options, common_options):
if type == "ABACUSDataset":
assert "pbc" in set_options, "PBC must be provided in `data_options` when loading ABACUS dataset."
AtomicDataOptions["pbc"] = set_options["pbc"]
if "basis" in common_options:
idp = OrbitalMapper(common_options["basis"])
else:
idp = None
dataset = ABACUSDataset(
root=set_options["root"],
preprocess_dir=set_options["preprocess_dir"],
AtomicData_options=AtomicDataOptions,
type_mapper=idp,
)
elif type == "ABACUSInMemoryDataset":
assert "pbc" in set_options, "PBC must be provided in `data_options` when loading ABACUS dataset."
AtomicDataOptions["pbc"] = set_options["pbc"]
if "basis" in common_options:
idp = OrbitalMapper(common_options["basis"])
else:
idp = None
dataset = ABACUSInMemoryDataset(
root=set_options["root"],
preprocess_dir=set_options["preprocess_dir"],
include_frames=set_options.get("include_frames"),
AtomicData_options=AtomicDataOptions,
type_mapper=idp,
)

# input in common_option for Default Dataset:
Expand Down
20 changes: 11 additions & 9 deletions dptb/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,11 @@ def train_data_sub():
doc_reduce_edge = ""

args = [
Argument("type", str, optional=True, default="DefaultDataset", doc="The type of dataset"),
Argument("root", str, optional=False, doc=doc_root),
Argument("preprocess_path", str, optional=False, doc=doc_preprocess_path),
Argument("file_names", list, optional=False, doc=doc_file_names),
Argument("preprocess_dir", str, optional=False, doc=doc_preprocess_path),
Argument("AtomicData_options", dict, optional=True, default={}, doc="The options for AtomicData class"),
Argument("pbc", [bool, list], optional=True, default=True, doc=doc_pbc),
Argument("reduce_edge", bool, optional=True, default=True, doc=doc_reduce_edge)
]

doc_train = ""
Expand All @@ -208,10 +208,11 @@ def validation_data_sub():
doc_pbc = ""

args = [
Argument("type", str, optional=True, default="DefaultDataset", doc="The type of dataset"),
Argument("root", str, optional=False, doc=doc_root),
Argument("preprocess_path", str, optional=False, doc=doc_preprocess_path),
Argument("file_names", list, optional=False, doc=doc_file_names),
Argument("pbc", [bool, list], optional=True, default=True, doc=doc_pbc)
Argument("preprocess_dir", str, optional=False, doc=doc_preprocess_path),
Argument("AtomicData_options", dict, optional=True, default={}, doc="The options for AtomicData class"),
Argument("pbc", [bool, list], optional=True, default=True, doc=doc_pbc),
]

doc_validation = ""
Expand All @@ -225,10 +226,11 @@ def reference_data_sub():
doc_pbc = ""

args = [
Argument("type", str, optional=True, default="DefaultDataset", doc="The type of dataset"),
Argument("root", str, optional=False, doc=doc_root),
Argument("preprocess_path", str, optional=False, doc=doc_preprocess_path),
Argument("file_names", list, optional=False, doc=doc_file_names),
Argument("pbc", [bool, list], optional=True, default=True, doc=doc_pbc)
Argument("preprocess_dir", str, optional=False, doc=doc_preprocess_path),
Argument("AtomicData_options", dict, optional=True, default={}, doc="The options for AtomicData class"),
Argument("pbc", [bool, list], optional=True, default=True, doc=doc_pbc),
]

doc_reference = ""
Expand Down

0 comments on commit b52ccb9

Please sign in to comment.