From 6c38085d15e38fcf97277de2a6d59942e001dfaa Mon Sep 17 00:00:00 2001 From: Jijie Zou <73353910+AsymmetryChou@users.noreply.github.com> Date: Thu, 19 Sep 2024 17:34:57 +0800 Subject: [PATCH] add overlap output in feature_to_block (#208) * add overlap output in feature_to_block * update feature_to_block generally for H, S and D --- dptb/data/interfaces/ham_to_feature.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/dptb/data/interfaces/ham_to_feature.py b/dptb/data/interfaces/ham_to_feature.py index cc39b982..cfd61a7e 100644 --- a/dptb/data/interfaces/ham_to_feature.py +++ b/dptb/data/interfaces/ham_to_feature.py @@ -319,19 +319,25 @@ def block_to_feature(data, idp, blocks=False, overlap_blocks=False, orthogonal=F # if overlap_blocks: # data[_keys.EDGE_OVERLAP_KEY] = torch.as_tensor(np.array(edge_overlap), dtype=torch.get_default_dtype()) -def feature_to_block(data, idp): +def feature_to_block(data, idp, overlap: bool = False): idp.get_orbital_maps() idp.get_orbpair_maps() has_block = False - if data.get(_keys.NODE_FEATURES_KEY, None) is not None: - node_features = data[_keys.NODE_FEATURES_KEY] - edge_features = data[_keys.EDGE_FEATURES_KEY] - has_block = True - blocks = {} - - idp.get_orbital_maps() - idp.get_orbpair_maps() + if not overlap: + if data.get(_keys.NODE_FEATURES_KEY, None) is not None: + node_features = data[_keys.NODE_FEATURES_KEY] + edge_features = data[_keys.EDGE_FEATURES_KEY] + has_block = True + blocks = {} + else: + if data.get(_keys.NODE_OVERLAP_KEY, None) is not None: + node_features = data[_keys.NODE_OVERLAP_KEY] + edge_features = data[_keys.EDGE_OVERLAP_KEY] + has_block = True + blocks = {} + else: + raise KeyError("Overlap features not found in data.") if has_block: # get node blocks from node_features