Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Indexing fixes, optimization, and etc. for graphgen #91

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 77 additions & 34 deletions src/quemb/molbe/autofrag.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# Author: Oinam Romesh Meitei, Shaun Weatherly

from copy import deepcopy
from typing import Any

import networkx as nx
import numpy as np
from attrs import define
from networkx import single_source_all_shortest_paths # type: ignore[attr-defined]
from networkx import shortest_path # type: ignore[attr-defined]
from numpy.linalg import norm
from pyscf import gto

Expand Down Expand Up @@ -42,6 +43,12 @@ class FragmentMap:
dnames:
List of strings giving fragment data names. Useful for bookkeeping and
for constructing fragment scratch directories.
fragment_atoms:
List whose entries are tuples containing all atom indices for a fragment.
center_atoms:
List whose entries are tuples giving the center atom indices per fragment.
edge_atoms:
List whose entries are tuples giving the edge atom indices per fragment.
adjacency_mat:
The adjacency matrix for all sites (atoms) in the system.
adjacency_graph:
Expand All @@ -56,8 +63,9 @@ class FragmentMap:
ebe_weights: list[tuple]
sites: list[tuple]
dnames: list
center_atoms: list[tuple[str, ...]]
edge_atoms: list[tuple[str, ...]]
fragment_atoms: list[tuple[int, ...]]
center_atoms: list[tuple[int, ...]]
edge_atoms: list[tuple[int, ...]]
adjacency_mat: np.ndarray
adjacency_graph: nx.Graph

Expand All @@ -75,16 +83,30 @@ def remove_nonnunique_frags(self, natm: int) -> None:
such that all fragments are guaranteed to be distinct sets.
"""
for _ in range(0, natm):
subsets = set()
for adx, basa in enumerate(self.fsites):
for bdx, basb in enumerate(self.fsites):
if adx == bdx:
pass
elif set(basb).issubset(set(basa)):
tmp = set(self.center[adx] + self.center[bdx])
self.center[adx] = tuple(tmp)
del self.center[bdx]
del self.fsites[bdx]
del self.fs[bdx]
subsets.add(bdx)
self.center[adx] = tuple(
set(self.center[adx] + deepcopy(self.center[bdx]))
)
self.center_atoms[adx] = tuple(
set(
self.center_atoms[adx]
+ deepcopy(self.center_atoms[bdx])
)
)
if subsets:
sorted_subsets = sorted(subsets, reverse=True)
for bdx in sorted_subsets:
del self.center[bdx]
del self.fsites[bdx]
del self.fs[bdx]
del self.center_atoms[bdx]
del self.fragment_atoms[bdx]

return None

Expand All @@ -104,6 +126,7 @@ def graphgen(
frag_prefix: str = "f",
connectivity: str = "euclidean",
iao_valence_basis: str | None = None,
cutoff: float = 20.0,
) -> FragmentMap:
"""Generate fragments via adjacency graph.

Expand Down Expand Up @@ -142,6 +165,11 @@ def graphgen(
weights in the fragment adjacency graph. Currently supports "euclidean"
(which uses the square of the distance between atoms in real
space to determine connectivity within a fragment.)
cutoff:
Atoms with an edge weight beyond `cutoff` will be excluded from the
`shortest_path` calculation. This is crucial when handling very large
systems, where computing the shortest paths from all to all becomes
non-trivial. Defaults to 20.0.

Returns
-------
Expand Down Expand Up @@ -175,6 +203,7 @@ def graphgen(
ebe_weights=list(tuple()),
sites=list(tuple()),
dnames=list(),
fragment_atoms=list(),
center_atoms=list(),
edge_atoms=list(),
adjacency_mat=np.zeros((natm, natm), np.float64),
Expand Down Expand Up @@ -211,39 +240,49 @@ def graphgen(
** 2
)
fragment_map.adjacency_mat[adx, bdx] = dr
fragment_map.adjacency_graph.add_edge(adx, bdx, weight=dr)
if dr <= cutoff:
fragment_map.adjacency_graph.add_edge(adx, bdx, weight=dr)

# For a given center site (adx), find the set of shortest
# paths to all other sites. The number of nodes visited
# on that path gives the degree of separation of the
# sites.
for adx, map in adx_map.items():
fragment_map.center_atoms.append(tuple())
fsites_temp = fragment_map.sites[adx]
fragment_map.center_atoms.append((adx,))
fragment_map.center.append(deepcopy(fragment_map.sites[adx]))
fsites_temp = deepcopy(fragment_map.sites[adx])
fatoms_temp = [adx]
fs_temp = []
fs_temp.append(fragment_map.sites[adx])
map["shortest_paths"] = dict(
single_source_all_shortest_paths(
fragment_map.adjacency_graph,
source=adx,
weight=lambda a, b, _: (
fragment_map.adjacency_graph[a][b]["weight"]
),
method="dijkstra",
)
)
fs_temp.append(deepcopy(fragment_map.sites[adx]))

for bdx, _ in adx_map.items():
if fragment_map.adjacency_graph.has_edge(adx, bdx):
map["shortest_paths"].update(
{
bdx: shortest_path(
fragment_map.adjacency_graph,
source=adx,
target=bdx,
weight=lambda a, b, _: (
fragment_map.adjacency_graph[a][b]["weight"]
),
method="dijkstra",
)
}
)

# If the degree of separation is smaller than the *n*
# in your fragment type, BE*n*, then that site is appended to
# the set of fragment sites for adx.
for bdx, path in map["shortest_paths"].items():
if 0 < (len(path[0]) - 1) < fragment_type_order:
fsites_temp = tuple(fsites_temp + fragment_map.sites[bdx])
fs_temp.append(tuple(fragment_map.sites[bdx]))
if 0 < (len(path) - 1) < fragment_type_order:
fsites_temp = fsites_temp + deepcopy(fragment_map.sites[bdx])
fs_temp.append(deepcopy(fragment_map.sites[bdx]))
fatoms_temp.append(bdx)

fragment_map.fsites.append(tuple(fsites_temp))
fragment_map.fs.append(tuple(fs_temp))
fragment_map.center.append(tuple(fragment_map.sites[adx]))
fragment_map.fragment_atoms.append(tuple(fatoms_temp))

elif connectivity.lower() in ["resistance_distance", "resistance"]:
raise NotImplementedError("Work in progress...")
Expand All @@ -260,22 +299,26 @@ def graphgen(
# Define the 'edges' for fragment A as the intersect of its sites
# with the set of all center sites outside of A:
for adx, fs in enumerate(fragment_map.fs):
edge: set[tuple] = set()
edge_temp: set[tuple] = set()
eatoms_temp: set = set()
for bdx, center in enumerate(fragment_map.center):
if adx == bdx:
pass
else:
overlap = set(fs).intersection(set((center,)))
if overlap:
edge = edge.union(overlap)
fragment_map.edge.append(tuple(edge))
for f in fs:
overlap = set(f).intersection(set(center))
if overlap:
f_temp = set(fragment_map.fragment_atoms[adx])
c_temp = set(fragment_map.center_atoms[bdx])
edge_temp.add(tuple(overlap))
eatoms_temp.add(i for i in f_temp.intersection(c_temp))
fragment_map.edge.append(tuple(edge_temp))
fragment_map.edge_atoms.append(tuple(eatoms_temp))

# Update relative center site indices (centerf_idx) and weights
# for center site contributions to the energy (ebe_weights):
for adx, center in enumerate(fragment_map.center):
centerf_idx = tuple(
set([fragment_map.fsites[adx].index(cdx) for cdx in center])
)
centerf_idx = tuple([fragment_map.fsites[adx].index(cdx) for cdx in center])
ebe_weight = (1.0, tuple(centerf_idx))
fragment_map.centerf_idx.append(centerf_idx)
fragment_map.ebe_weights.append(ebe_weight)
Expand Down
40 changes: 33 additions & 7 deletions src/quemb/molbe/fragment.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,22 @@ class fragpart:
write_geom: bool
Whether to write 'fragment.xyz' file which contains all the fragments
in cartesian coordinates.
remove_nonunique_frags:
Whether to remove fragments which are strict subsets of another
fragment in the system. True by default.
frag_prefix:
Prefix to be appended to the fragment datanames. Useful for managing
fragment scratch directories.
connectivity:
Keyword string specifying the distance metric to be used for edge
weights in the fragment adjacency graph. Currently supports "euclidean"
(which uses the square of the distance between atoms in real
space to determine connectivity within a fragment.)
cutoff:
Atoms with an edge weight beyond `cutoff` will be excluded from the
`shortest_path` calculation. This is crucial when handling very large
systems, where computing the shortest paths from all to all becomes
non-trivial. Defaults to 20.0.
"""

def __init__(
Expand All @@ -55,8 +71,12 @@ def __init__(
print_frags=True,
write_geom=False,
be_type="be2",
frag_prefix="f",
connectivity="euclidean",
mol=None,
frozen_core=False,
cutoff=20,
remove_nonnunique_frags=True,
):
# Initialize class attributes
self.mol = mol
Expand All @@ -70,9 +90,13 @@ def __init__(
self.center_idx = []
self.centerf_idx = []
self.be_type = be_type
self.frag_prefix = frag_prefix
self.connectivity = connectivity
self.frozen_core = frozen_core
self.iao_valence_basis = iao_valence_basis
self.valence_only = valence_only
self.cutoff = cutoff
self.remove_nonnunique_frags = remove_nonnunique_frags

# Initialize class attributes necessary for mixed-basis BE
self.Frag_atom = []
Expand All @@ -98,18 +122,20 @@ def __init__(
elif frag_type == "graphgen":
fragment_map = graphgen(
mol=self.mol.copy(),
be_type=be_type,
frozen_core=frozen_core,
remove_nonunique_frags=True,
frag_prefix="f",
connectivity="euclidean",
iao_valence_basis=iao_valence_basis,
be_type=self.be_type,
frozen_core=self.frozen_core,
remove_nonunique_frags=self.remove_nonnunique_frags,
frag_prefix=self.frag_prefix,
connectivity=self.connectivity,
iao_valence_basis=self.iao_valence_basis,
cutoff=self.cutoff,
)

self.fsites = fragment_map.fsites
self.edge = fragment_map.edge
self.center = fragment_map.center
# self.edge_idx = fragment_map["edge"]
self.Frag_atom = fragment_map.fragment_atoms
self.center_atom = fragment_map.center_atoms
self.centerf_idx = fragment_map.centerf_idx
self.ebe_weight = fragment_map.ebe_weights
self.Nfrag = len(self.fsites)
Expand Down
Loading
Loading