Skip to content

Commit

Permalink
add overlap output in feature_to_block (#208)
Browse files Browse the repository at this point in the history
* add overlap output in feature_to_block

* update feature_to_block generally for H, S and D
  • Loading branch information
AsymmetryChou authored Sep 19, 2024
1 parent ac6faa9 commit 6c38085
Showing 1 changed file with 15 additions and 9 deletions.
24 changes: 15 additions & 9 deletions dptb/data/interfaces/ham_to_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,19 +319,25 @@ def block_to_feature(data, idp, blocks=False, overlap_blocks=False, orthogonal=F
# if overlap_blocks:
# data[_keys.EDGE_OVERLAP_KEY] = torch.as_tensor(np.array(edge_overlap), dtype=torch.get_default_dtype())

def feature_to_block(data, idp):
def feature_to_block(data, idp, overlap: bool = False):
idp.get_orbital_maps()
idp.get_orbpair_maps()

has_block = False
if data.get(_keys.NODE_FEATURES_KEY, None) is not None:
node_features = data[_keys.NODE_FEATURES_KEY]
edge_features = data[_keys.EDGE_FEATURES_KEY]
has_block = True
blocks = {}

idp.get_orbital_maps()
idp.get_orbpair_maps()
if not overlap:
if data.get(_keys.NODE_FEATURES_KEY, None) is not None:
node_features = data[_keys.NODE_FEATURES_KEY]
edge_features = data[_keys.EDGE_FEATURES_KEY]
has_block = True
blocks = {}
else:
if data.get(_keys.NODE_OVERLAP_KEY, None) is not None:
node_features = data[_keys.NODE_OVERLAP_KEY]
edge_features = data[_keys.EDGE_OVERLAP_KEY]
has_block = True
blocks = {}
else:
raise KeyError("Overlap features not found in data.")

if has_block:
# get node blocks from node_features
Expand Down

0 comments on commit 6c38085

Please sign in to comment.