From 47a1c4e4d802100236232ae1aa3f53c9f2f752a9 Mon Sep 17 00:00:00 2001 From: Shaun Weatherly Date: Mon, 20 Jan 2025 11:51:44 -0500 Subject: [PATCH 1/3] Indexing fixes, optimization, and etc. --- src/quemb/molbe/autofrag.py | 100 ++++++++---- src/quemb/molbe/fragment.py | 3 +- tests/fragmentation_test.py | 312 +++++++++--------------------------- 3 files changed, 145 insertions(+), 270 deletions(-) diff --git a/src/quemb/molbe/autofrag.py b/src/quemb/molbe/autofrag.py index 4e65dfc..2228ad3 100644 --- a/src/quemb/molbe/autofrag.py +++ b/src/quemb/molbe/autofrag.py @@ -1,11 +1,12 @@ # Author: Oinam Romesh Meitei, Shaun Weatherly +from copy import deepcopy from typing import Any import networkx as nx import numpy as np from attrs import define -from networkx import single_source_all_shortest_paths # type: ignore[attr-defined] +from networkx import shortest_path # type: ignore[attr-defined] from numpy.linalg import norm from pyscf import gto @@ -56,8 +57,9 @@ class FragmentMap: ebe_weights: list[tuple] sites: list[tuple] dnames: list - center_atoms: list[tuple[str, ...]] - edge_atoms: list[tuple[str, ...]] + fragment_atoms: list[tuple[int, ...]] + center_atoms: list[tuple[int, ...]] + edge_atoms: list[tuple[int, ...]] adjacency_mat: np.ndarray adjacency_graph: nx.Graph @@ -75,16 +77,30 @@ def remove_nonnunique_frags(self, natm: int) -> None: such that all fragments are guaranteed to be distinct sets. 
""" for _ in range(0, natm): + subsets = set() for adx, basa in enumerate(self.fsites): for bdx, basb in enumerate(self.fsites): if adx == bdx: pass elif set(basb).issubset(set(basa)): - tmp = set(self.center[adx] + self.center[bdx]) - self.center[adx] = tuple(tmp) - del self.center[bdx] - del self.fsites[bdx] - del self.fs[bdx] + subsets.add(bdx) + self.center[adx] = tuple( + set(self.center[adx] + deepcopy(self.center[bdx])) + ) + self.center_atoms[adx] = tuple( + set( + self.center_atoms[adx] + + deepcopy(self.center_atoms[bdx]) + ) + ) + if subsets: + sorted_subsets = sorted(subsets, reverse=True) + for bdx in sorted_subsets: + del self.center[bdx] + del self.fsites[bdx] + del self.fs[bdx] + del self.center_atoms[bdx] + del self.fragment_atoms[bdx] return None @@ -104,6 +120,7 @@ def graphgen( frag_prefix: str = "f", connectivity: str = "euclidean", iao_valence_basis: str | None = None, + cutoff: float = 20.0, ) -> FragmentMap: """Generate fragments via adjacency graph. @@ -175,6 +192,7 @@ def graphgen( ebe_weights=list(tuple()), sites=list(tuple()), dnames=list(), + fragment_atoms=list(), center_atoms=list(), edge_atoms=list(), adjacency_mat=np.zeros((natm, natm), np.float64), @@ -211,39 +229,49 @@ def graphgen( ** 2 ) fragment_map.adjacency_mat[adx, bdx] = dr - fragment_map.adjacency_graph.add_edge(adx, bdx, weight=dr) + if dr <= cutoff: + fragment_map.adjacency_graph.add_edge(adx, bdx, weight=dr) # For a given center site (adx), find the set of shortest # paths to all other sites. The number of nodes visited # on that path gives the degree of separation of the # sites. 
for adx, map in adx_map.items(): - fragment_map.center_atoms.append(tuple()) - fsites_temp = fragment_map.sites[adx] + fragment_map.center_atoms.append((adx,)) + fragment_map.center.append(deepcopy(fragment_map.sites[adx])) + fsites_temp = deepcopy(fragment_map.sites[adx]) + fatoms_temp = [adx] fs_temp = [] - fs_temp.append(fragment_map.sites[adx]) - map["shortest_paths"] = dict( - single_source_all_shortest_paths( - fragment_map.adjacency_graph, - source=adx, - weight=lambda a, b, _: ( - fragment_map.adjacency_graph[a][b]["weight"] - ), - method="dijkstra", - ) - ) + fs_temp.append(deepcopy(fragment_map.sites[adx])) + + for bdx, _ in adx_map.items(): + if fragment_map.adjacency_graph.has_edge(adx, bdx): + map["shortest_paths"].update( + { + bdx: shortest_path( + fragment_map.adjacency_graph, + source=adx, + target=bdx, + weight=lambda a, b, _: ( + fragment_map.adjacency_graph[a][b]["weight"] + ), + method="dijkstra", + ) + } + ) # If the degree of separation is smaller than the *n* # in your fragment type, BE*n*, then that site is appended to # the set of fragment sites for adx. 
for bdx, path in map["shortest_paths"].items(): - if 0 < (len(path[0]) - 1) < fragment_type_order: - fsites_temp = tuple(fsites_temp + fragment_map.sites[bdx]) - fs_temp.append(tuple(fragment_map.sites[bdx])) + if 0 < (len(path) - 1) < fragment_type_order: + fsites_temp = fsites_temp + deepcopy(fragment_map.sites[bdx]) + fs_temp.append(deepcopy(fragment_map.sites[bdx])) + fatoms_temp.append(bdx) fragment_map.fsites.append(tuple(fsites_temp)) fragment_map.fs.append(tuple(fs_temp)) - fragment_map.center.append(tuple(fragment_map.sites[adx])) + fragment_map.fragment_atoms.append(tuple(fatoms_temp)) elif connectivity.lower() in ["resistance_distance", "resistance"]: raise NotImplementedError("Work in progress...") @@ -260,22 +288,26 @@ def graphgen( # Define the 'edges' for fragment A as the intersect of its sites # with the set of all center sites outside of A: for adx, fs in enumerate(fragment_map.fs): - edge: set[tuple] = set() + edge_temp: set[tuple] = set() + eatoms_temp: set = set() for bdx, center in enumerate(fragment_map.center): if adx == bdx: pass else: - overlap = set(fs).intersection(set((center,))) - if overlap: - edge = edge.union(overlap) - fragment_map.edge.append(tuple(edge)) + for f in fs: + overlap = set(f).intersection(set(center)) + if overlap: + f_temp = set(fragment_map.fragment_atoms[adx]) + c_temp = set(fragment_map.center_atoms[bdx]) + edge_temp.add(tuple(overlap)) + eatoms_temp.add(i for i in f_temp.intersection(c_temp)) + fragment_map.edge.append(tuple(edge_temp)) + fragment_map.edge_atoms.append(tuple(eatoms_temp)) # Update relative center site indices (centerf_idx) and weights # for center site contributions to the energy (ebe_weights): for adx, center in enumerate(fragment_map.center): - centerf_idx = tuple( - set([fragment_map.fsites[adx].index(cdx) for cdx in center]) - ) + centerf_idx = tuple([fragment_map.fsites[adx].index(cdx) for cdx in center]) ebe_weight = (1.0, tuple(centerf_idx)) fragment_map.centerf_idx.append(centerf_idx) 
fragment_map.ebe_weights.append(ebe_weight) diff --git a/src/quemb/molbe/fragment.py b/src/quemb/molbe/fragment.py index c6b05c9..7bccd8a 100644 --- a/src/quemb/molbe/fragment.py +++ b/src/quemb/molbe/fragment.py @@ -109,7 +109,8 @@ def __init__( self.fsites = fragment_map.fsites self.edge = fragment_map.edge self.center = fragment_map.center - # self.edge_idx = fragment_map["edge"] + self.Frag_atom = fragment_map.fragment_atoms + self.center_atom = fragment_map.center_atoms self.centerf_idx = fragment_map.centerf_idx self.ebe_weight = fragment_map.ebe_weights self.Nfrag = len(self.fsites) diff --git a/tests/fragmentation_test.py b/tests/fragmentation_test.py index a5f3196..2a1507c 100644 --- a/tests/fragmentation_test.py +++ b/tests/fragmentation_test.py @@ -673,11 +673,18 @@ def test_graphgen_h_linear_be2(self): (5, 4, 6), (6, 5, 7), ], - "edge": [((2,),), ((3,),), ((2,), (4,)), ((3,), (5,)), ((4,),), ((5,),)], + "edge": [ + ((2,),), + ((1,), (3,)), + ((2,), (4,)), + ((3,), (5,)), + ((6,), (4,)), + ((5,),), + ], "center": [(0, 1), (2,), (3,), (4,), (5,), (6, 7)], - "centerf_idx": [(0, 1), (0,), (0,), (0,), (0,), (0, 2)], + "centerf_idx": [(1, 0), (0,), (0,), (0,), (0,), (0, 2)], "ebe_weight": [ - (1.0, (0, 1)), + (1.0, (1, 0)), (1.0, (0,)), (1.0, (0,)), (1.0, (0,)), @@ -711,11 +718,16 @@ def test_graphgen_h_linear_be3(self): (4, 2, 3, 5, 6), (5, 3, 4, 6, 7), ], - "edge": [((3,), (4,)), ((4,),), ((3,),), ((3,), (4,))], + "edge": [ + ((3,), (4,)), + ((1,), (2,), (4,), (5,)), + ((6,), (2,), (3,), (5,)), + ((3,), (4,)), + ], "center": [(0, 1, 2), (3,), (4,), (5, 6, 7)], - "centerf_idx": [(0, 1, 2), (0,), (0,), (0, 3, 4)], + "centerf_idx": [(1, 2, 0), (0,), (0,), (0, 3, 4)], "ebe_weight": [ - (1.0, (0, 1, 2)), + (1.0, (1, 2, 0)), (1.0, (0,)), (1.0, (0,)), (1.0, (0, 3, 4)), @@ -905,132 +917,52 @@ def test_graphgen_octane_be2(self): "fsites": [ (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 13, 19, 20, 21, 22, 23), (5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 10, 12, 14, 15, 16, 17, 18), - (10, 
5, 6, 7, 8, 9, 27, 40), - (11, 0, 1, 2, 3, 4, 26, 41), - (12, 5, 6, 7, 8, 9, 25, 38), - (13, 0, 1, 2, 3, 4, 24, 39), (14, 15, 16, 17, 18, 5, 6, 7, 8, 9, 24, 26, 28, 29, 30, 31, 32), (19, 20, 21, 22, 23, 0, 1, 2, 3, 4, 25, 27, 33, 34, 35, 36, 37), - (24, 13, 14, 15, 16, 17, 18, 52), - (25, 12, 19, 20, 21, 22, 23, 53), - (26, 11, 14, 15, 16, 17, 18, 54), - (27, 10, 19, 20, 21, 22, 23, 55), (28, 29, 30, 31, 32, 14, 15, 16, 17, 18, 38, 40, 42, 43, 44, 45, 46), (33, 34, 35, 36, 37, 19, 20, 21, 22, 23, 39, 41, 47, 48, 49, 50, 51), - (38, 12, 28, 29, 30, 31, 32), - (39, 13, 33, 34, 35, 36, 37), - (40, 10, 28, 29, 30, 31, 32), - (41, 11, 33, 34, 35, 36, 37), (42, 43, 44, 45, 46, 28, 29, 30, 31, 32, 52, 54, 57), (47, 48, 49, 50, 51, 33, 34, 35, 36, 37, 53, 55, 56), - (52, 24, 42, 43, 44, 45, 46), - (53, 25, 47, 48, 49, 50, 51), - (54, 26, 42, 43, 44, 45, 46), - (55, 27, 47, 48, 49, 50, 51), ], "edge": [ - ((5, 6, 7, 8, 9), (13,), (19, 20, 21, 22, 23), (11,)), - ((12,), (0, 1, 2, 3, 4), (10,), (14, 15, 16, 17, 18)), - ((5, 6, 7, 8, 9), (40,), (27,)), - ((41,), (0, 1, 2, 3, 4), (26,)), - ((5, 6, 7, 8, 9), (25,), (38,)), - ((24,), (0, 1, 2, 3, 4), (39,)), - ((5, 6, 7, 8, 9), (28, 29, 30, 31, 32), (24,), (26,)), - ((0, 1, 2, 3, 4), (25,), (27,), (33, 34, 35, 36, 37)), - ((52,), (13,), (14, 15, 16, 17, 18)), - ((12,), (53,), (19, 20, 21, 22, 23)), - ((11,), (54,), (14, 15, 16, 17, 18)), - ((55,), (19, 20, 21, 22, 23), (10,)), - ((40,), (38,), (14, 15, 16, 17, 18)), - ((41,), (19, 20, 21, 22, 23), (39,)), - ((28, 29, 30, 31, 32), (12,)), - ((13,), (33, 34, 35, 36, 37)), - ((28, 29, 30, 31, 32), (10,)), - ((11,), (33, 34, 35, 36, 37)), - ((28, 29, 30, 31, 32), (52,), (54,)), - ((53,), (55,), (33, 34, 35, 36, 37)), - ((24,),), - ((25,),), - ((26,),), - ((27,),), + ((5, 6, 7, 8, 9), (19, 20, 21, 22, 23)), + ((0, 1, 2, 3, 4), (14, 15, 16, 17, 18)), + ((5, 6, 7, 8, 9), (32, 28, 29, 30, 31)), + ((0, 1, 2, 3, 4), (33, 34, 35, 36, 37)), + ((42, 43, 44, 45, 46), (14, 15, 16, 17, 18)), + 
((19, 20, 21, 22, 23), (47, 48, 49, 50, 51)), + ((32, 28, 29, 30, 31),), + ((33, 34, 35, 36, 37),), ], "center": [ - (0, 1, 2, 3, 4), - (5, 6, 7, 8, 9), - (10,), - (11,), - (12,), - (13,), - (14, 15, 16, 17, 18), - (19, 20, 21, 22, 23), - (24,), - (25,), - (26,), - (27,), - (28, 29, 30, 31, 32), - (33, 34, 35, 36, 37), - (38,), - (39,), - (40,), - (41,), - (42, 43, 44, 45, 46, 57), - (47, 48, 49, 50, 51, 56), - (52,), - (53,), - (54,), - (55,), + (0, 1, 2, 3, 4, 11, 13), + (5, 6, 7, 8, 9, 10, 12), + (14, 15, 16, 17, 18, 24, 26), + (19, 20, 21, 22, 23, 25, 27), + (32, 38, 40, 28, 29, 30, 31), + (33, 34, 35, 36, 37, 39, 41), + (42, 43, 44, 45, 46, 52, 54, 57), + (47, 48, 49, 50, 51, 53, 55, 56), ], "centerf_idx": [ - (0, 1, 2, 3, 4), - (0, 1, 2, 3, 4), - (0,), - (0,), - (0,), - (0,), - (0, 1, 2, 3, 4), - (0, 1, 2, 3, 4), - (0,), - (0,), - (0,), - (0,), - (0, 1, 2, 3, 4), - (0, 1, 2, 3, 4), - (0,), - (0,), - (0,), - (0,), - (0, 1, 2, 3, 4, 12), - (0, 1, 2, 3, 4, 12), - (0,), - (0,), - (0,), - (0,), + (0, 1, 2, 3, 4, 10, 11), + (0, 1, 2, 3, 4, 10, 11), + (0, 1, 2, 3, 4, 10, 11), + (0, 1, 2, 3, 4, 10, 11), + (4, 10, 11, 0, 1, 2, 3), + (0, 1, 2, 3, 4, 10, 11), + (0, 1, 2, 3, 4, 10, 11, 12), + (0, 1, 2, 3, 4, 10, 11, 12), ], "ebe_weight": [ - (1.0, (0, 1, 2, 3, 4)), - (1.0, (0, 1, 2, 3, 4)), - (1.0, (0,)), - (1.0, (0,)), - (1.0, (0,)), - (1.0, (0,)), - (1.0, (0, 1, 2, 3, 4)), - (1.0, (0, 1, 2, 3, 4)), - (1.0, (0,)), - (1.0, (0,)), - (1.0, (0,)), - (1.0, (0,)), - (1.0, (0, 1, 2, 3, 4)), - (1.0, (0, 1, 2, 3, 4)), - (1.0, (0,)), - (1.0, (0,)), - (1.0, (0,)), - (1.0, (0,)), - (1.0, (0, 1, 2, 3, 4, 12)), - (1.0, (0, 1, 2, 3, 4, 12)), - (1.0, (0,)), - (1.0, (0,)), - (1.0, (0,)), - (1.0, (0,)), + (1.0, (0, 1, 2, 3, 4, 10, 11)), + (1.0, (0, 1, 2, 3, 4, 10, 11)), + (1.0, (0, 1, 2, 3, 4, 10, 11)), + (1.0, (0, 1, 2, 3, 4, 10, 11)), + (1.0, (4, 10, 11, 0, 1, 2, 3)), + (1.0, (0, 1, 2, 3, 4, 10, 11)), + (1.0, (0, 1, 2, 3, 4, 10, 11, 12)), + (1.0, (0, 1, 2, 3, 4, 10, 11, 12)), ], } @@ 
-1069,11 +1001,6 @@ def test_graphgen_octane_be3(self): 11, 12, 13, - 14, - 15, - 16, - 17, - 18, 19, 20, 21, @@ -1081,11 +1008,6 @@ def test_graphgen_octane_be3(self): 23, 25, 27, - 33, - 34, - 35, - 36, - 37, ), ( 5, @@ -1107,34 +1029,15 @@ def test_graphgen_octane_be3(self): 16, 17, 18, - 19, - 20, - 21, - 22, - 23, 24, 26, - 28, - 29, - 30, - 31, - 32, ), - (10, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 12, 14, 15, 16, 17, 18, 27, 40), - (11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 13, 19, 20, 21, 22, 23, 26, 41), - (12, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 14, 15, 16, 17, 18, 25, 38), - (13, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 19, 20, 21, 22, 23, 24, 39), ( 14, 15, 16, 17, 18, - 0, - 1, - 2, - 3, - 4, 5, 6, 7, @@ -1151,11 +1054,6 @@ def test_graphgen_octane_be3(self): 32, 38, 40, - 42, - 43, - 44, - 45, - 46, ), ( 19, @@ -1168,11 +1066,6 @@ def test_graphgen_octane_be3(self): 2, 3, 4, - 5, - 6, - 7, - 8, - 9, 11, 13, 25, @@ -1184,27 +1077,13 @@ def test_graphgen_octane_be3(self): 37, 39, 41, - 47, - 48, - 49, - 50, - 51, ), - (24, 5, 6, 7, 8, 9, 13, 14, 15, 16, 17, 18, 26, 28, 29, 30, 31, 32, 52), - (25, 0, 1, 2, 3, 4, 12, 19, 20, 21, 22, 23, 27, 33, 34, 35, 36, 37, 53), - (26, 5, 6, 7, 8, 9, 11, 14, 15, 16, 17, 18, 24, 28, 29, 30, 31, 32, 54), - (27, 0, 1, 2, 3, 4, 10, 19, 20, 21, 22, 23, 25, 33, 34, 35, 36, 37, 55), ( 28, 29, 30, 31, 32, - 5, - 6, - 7, - 8, - 9, 14, 15, 16, @@ -1229,11 +1108,6 @@ def test_graphgen_octane_be3(self): 35, 36, 37, - 0, - 1, - 2, - 3, - 4, 19, 20, 21, @@ -1254,68 +1128,36 @@ def test_graphgen_octane_be3(self): ), ], "edge": [ - ((5, 6, 7, 8, 9), (12,), (10,), (11,), (13,), (25,), (27,)), - ((12,), (26,), (10,), (24,), (11,), (13,), (0, 1, 2, 3, 4)), - ((5, 6, 7, 8, 9), (12,), (0, 1, 2, 3, 4), (27,)), - ((5, 6, 7, 8, 9), (13,), (0, 1, 2, 3, 4), (26,)), - ((5, 6, 7, 8, 9), (0, 1, 2, 3, 4), (25,), (10,)), - ((5, 6, 7, 8, 9), (24,), (0, 1, 2, 3, 4), (11,)), - ((5, 6, 7, 8, 9), (12,), (26,), (10,), (24,), (0, 1, 2, 3, 4)), - ((5, 6, 7, 8, 9), (11,), (13,), (0, 
1, 2, 3, 4), (25,), (27,)), - ((5, 6, 7, 8, 9), (13,), (26,)), - ((12,), (0, 1, 2, 3, 4), (27,)), - ((5, 6, 7, 8, 9), (24,), (11,)), - ((0, 1, 2, 3, 4), (25,), (10,)), - ((5, 6, 7, 8, 9), (24,), (26,)), - ((0, 1, 2, 3, 4), (25,), (27,)), + ((12,), (27,), (5, 6, 7, 8, 9), (19, 20, 21, 22, 23), (10,), (25,)), + ((14, 15, 16, 17, 18), (11,), (24,), (26,), (13,), (0, 1, 2, 3, 4)), + ((12,), (32, 28, 29, 30, 31), (40,), (5, 6, 7, 8, 9), (10,), (38,)), + ((41,), (11,), (33, 34, 35, 36, 37), (39,), (13,), (0, 1, 2, 3, 4)), + ((24,), (26,), (14, 15, 16, 17, 18)), + ((25,), (19, 20, 21, 22, 23), (27,)), ], "center": [ - (0, 1, 2, 3, 4), - (5, 6, 7, 8, 9), - (10,), - (11,), - (12,), - (13,), - (38, 40, 14, 15, 16, 17, 18), - (39, 41, 19, 20, 21, 22, 23), - (24,), - (25,), - (26,), - (27,), - (32, 42, 43, 44, 45, 46, 52, 54, 57, 28, 29, 30, 31), - (33, 34, 35, 36, 37, 47, 48, 49, 50, 51, 53, 55, 56), + (0, 1, 2, 3, 4, 11, 13), + (5, 6, 7, 8, 9, 10, 12), + (14, 15, 16, 17, 18, 24, 26), + (19, 20, 21, 22, 23, 25, 27), + (32, 38, 40, 42, 43, 44, 45, 46, 52, 54, 57, 28, 29, 30, 31), + (33, 34, 35, 36, 37, 39, 41, 47, 48, 49, 50, 51, 53, 55, 56), ], "centerf_idx": [ - (0, 1, 2, 3, 4), - (0, 1, 2, 3, 4), - (0,), - (0,), - (0,), - (0,), - (0, 1, 2, 3, 4, 24, 25), - (0, 1, 2, 3, 4, 24, 25), - (0,), - (0,), - (0,), - (0,), - (0, 1, 2, 3, 4, 19, 20, 21, 22, 23, 24, 25, 26), - (0, 1, 2, 3, 4, 19, 20, 21, 22, 23, 24, 25, 26), + (0, 1, 2, 3, 4, 11, 13), + (0, 1, 2, 3, 4, 10, 12), + (0, 1, 2, 3, 4, 12, 13), + (0, 1, 2, 3, 4, 12, 13), + (4, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 0, 1, 2, 3), + (0, 1, 2, 3, 4, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21), ], "ebe_weight": [ - (1.0, (0, 1, 2, 3, 4)), - (1.0, (0, 1, 2, 3, 4)), - (1.0, (0,)), - (1.0, (0,)), - (1.0, (0,)), - (1.0, (0,)), - (1.0, (0, 1, 2, 3, 4, 24, 25)), - (1.0, (0, 1, 2, 3, 4, 24, 25)), - (1.0, (0,)), - (1.0, (0,)), - (1.0, (0,)), - (1.0, (0,)), - (1.0, (0, 1, 2, 3, 4, 19, 20, 21, 22, 23, 24, 25, 26)), - (1.0, (0, 1, 2, 3, 4, 19, 
20, 21, 22, 23, 24, 25, 26)), + (1.0, (0, 1, 2, 3, 4, 11, 13)), + (1.0, (0, 1, 2, 3, 4, 10, 12)), + (1.0, (0, 1, 2, 3, 4, 12, 13)), + (1.0, (0, 1, 2, 3, 4, 12, 13)), + (1.0, (4, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 0, 1, 2, 3)), + (1.0, (0, 1, 2, 3, 4, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)), ], } @@ -1389,9 +1231,9 @@ def run_energies_test( float(E_B), msg=f"{test_name}: BE Correlation Energy (oneshot) for " + frag_type_A - + " does not match" + + " does not match " + frag_type_B - + f" ({E_A} != {E_B}) ", + + f" ({E_A} != {E_B}) \n", delta=delta, ) @@ -1411,7 +1253,7 @@ def run_indices_test( assert fobj.centerf_idx == target["centerf_idx"] assert fobj.ebe_weight == target["ebe_weight"] except AssertionError as e: - print(f"Fragmentation test failed at {test_name}:") + print(f"Fragmentation test failed at {test_name} \n") raise e From 4af68ae5d381a99512a62c575601927a50e525c2 Mon Sep 17 00:00:00 2001 From: Shaun Weatherly Date: Mon, 20 Jan 2025 12:36:14 -0500 Subject: [PATCH 2/3] Modify docstrings. --- src/quemb/molbe/autofrag.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/quemb/molbe/autofrag.py b/src/quemb/molbe/autofrag.py index 2228ad3..36c6529 100644 --- a/src/quemb/molbe/autofrag.py +++ b/src/quemb/molbe/autofrag.py @@ -43,6 +43,12 @@ class FragmentMap: dnames: List of strings giving fragment data names. Useful for bookkeeping and for constructing fragment scratch directories. + fragment_atoms: + List whose entries are tuples containing all atom indices for a fragment. + center_atoms: + List whose entries are tuples giving the center atom indices per fragment. + edge_atoms: + List whose entries are tuples giving the edge atom indices per fragment. adjacency_mat: The adjacency matrix for all sites (atoms) in the system. adjacency_graph: @@ -159,6 +165,11 @@ def graphgen( weights in the fragment adjacency graph. 
Currently supports "euclidean" (which uses the square of the distance between atoms in real space to determine connectivity within a fragment.) + cutoff: + Atom pairs with an edge weight beyond `cutoff` will be excluded from the + `shortest_path` calculation. This is crucial when handling very large + systems, where computing the shortest paths between all pairs of atoms + becomes expensive. Defaults to 20.0. Returns ------- From f3042cd2e1ae5556a8887dda91b978bee19b7601 Mon Sep 17 00:00:00 2001 From: Shaun Weatherly Date: Mon, 20 Jan 2025 12:48:09 -0500 Subject: [PATCH 3/3] Expose other keywords and update docstrings --- src/quemb/molbe/fragment.py | 37 +++++++++++++++++++++++++++++++------ 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/src/quemb/molbe/fragment.py b/src/quemb/molbe/fragment.py index 7bccd8a..109f07b 100644 --- a/src/quemb/molbe/fragment.py +++ b/src/quemb/molbe/fragment.py @@ -44,6 +44,22 @@ class fragpart: write_geom: bool Whether to write 'fragment.xyz' file which contains all the fragments in cartesian coordinates. + remove_nonnunique_frags: + Whether to remove fragments which are subsets of another + fragment in the system. True by default. + frag_prefix: + Prefix to be prepended to the fragment datanames. Useful for managing + fragment scratch directories. + connectivity: + Keyword string specifying the distance metric to be used for edge + weights in the fragment adjacency graph. Currently supports "euclidean" + (which uses the square of the distance between atoms in real + space to determine connectivity within a fragment.) + cutoff: + Atom pairs with an edge weight beyond `cutoff` will be excluded from the + `shortest_path` calculation. This is crucial when handling very large + systems, where computing the shortest paths between all pairs of atoms + becomes expensive. Defaults to 20.0. 
""" def __init__( @@ -55,8 +71,12 @@ def __init__( print_frags=True, write_geom=False, be_type="be2", + frag_prefix="f", + connectivity="euclidean", mol=None, frozen_core=False, + cutoff=20, + remove_nonnunique_frags=True, ): # Initialize class attributes self.mol = mol @@ -70,9 +90,13 @@ def __init__( self.center_idx = [] self.centerf_idx = [] self.be_type = be_type + self.frag_prefix = frag_prefix + self.connectivity = connectivity self.frozen_core = frozen_core self.iao_valence_basis = iao_valence_basis self.valence_only = valence_only + self.cutoff = cutoff + self.remove_nonnunique_frags = remove_nonnunique_frags # Initialize class attributes necessary for mixed-basis BE self.Frag_atom = [] @@ -98,12 +122,13 @@ def __init__( elif frag_type == "graphgen": fragment_map = graphgen( mol=self.mol.copy(), - be_type=be_type, - frozen_core=frozen_core, - remove_nonunique_frags=True, - frag_prefix="f", - connectivity="euclidean", - iao_valence_basis=iao_valence_basis, + be_type=self.be_type, + frozen_core=self.frozen_core, + remove_nonunique_frags=self.remove_nonnunique_frags, + frag_prefix=self.frag_prefix, + connectivity=self.connectivity, + iao_valence_basis=self.iao_valence_basis, + cutoff=self.cutoff, ) self.fsites = fragment_map.fsites