Skip to content

Commit

Permalink
Indexing fixes, optimization, and etc. for graphgen (#91)
Browse files Browse the repository at this point in the history
* Indexing fixes, optimization, and etc.

* Modify docstrings.

* Expose other keywords and update docstrings

* A truly inconsequential change to `euclidean_norm`.

* Update typing to `Sequence`

* Update `euclidean_distance`

* Formatting

* Fixed typing for `Sequence` in `FragmentMap`

* Indexing fixes, optimization, and etc.

* Modify docstrings.

* Expose other keywords and update docstrings

* A truly inconsequential change to `euclidean_norm`.

* Update typing to `Sequence`

* Update `euclidean_distance`

* Formatting

* Fixed typing for `Sequence` in `FragmentMap`

* Remove `# type ignore`

* Fix `eatoms_temp`

* Edit `centerf_idx`

* Remove commented code

* Formatting

* Fix set reference

* Fix `edge_atoms`

* Fix typing for `FragmentMap.edge_atoms`
  • Loading branch information
ShaunWeatherly authored Jan 22, 2025
1 parent b2d0432 commit b747e0c
Show file tree
Hide file tree
Showing 3 changed files with 215 additions and 300 deletions.
158 changes: 103 additions & 55 deletions src/quemb/molbe/autofrag.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
# Author: Oinam Romesh Meitei, Shaun Weatherly

from typing import Any
from copy import deepcopy
from typing import Sequence

import networkx as nx
import numpy as np
from attrs import define
from networkx import single_source_all_shortest_paths # type: ignore[attr-defined]
from networkx import shortest_path
from numpy.linalg import norm
from pyscf import gto

from quemb.molbe.helper import get_core
from quemb.shared.helper import unused
from quemb.shared.typing import Vector


@define
Expand All @@ -20,44 +22,52 @@ class FragmentMap:
Parameters
----------
fsites:
List whose entries are tuples containing all AO indices for a fragment.
List whose entries are sequences (tuple or list) containing
all AO indices for a fragment.
fs:
List whose entries are tuples of tuples, containing AO indices per atom
List whose entries are sequences of sequences, containing AO indices per atom
per fragment.
edge:
List whose entries are tuples of tuples, containing edge AO
List whose entries are sequences of sequences, containing edge AO
indices per atom (inner sequence) per fragment (outer sequence).
center:
List whose entries are tuples of tuples, containing all fragment AO
List whose entries are sequences containing the AO indices of the
center site(s) of each fragment.
centerf_idx:
List whose entries are tuples containing the relative index of all
List whose entries are sequences containing the relative index of all
center sites within a fragment (ie, with respect to fsites).
ebe_weights:
Weights determining the energy contributions from each center site
(ie, with respect to centerf_idx).
sites:
List whose entries are tuples containing all AO indices per atom
List whose entries are sequences containing all AO indices per atom
(excluding frozen core indices, if applicable).
dnames:
List of strings giving fragment data names. Useful for bookkeeping and
for constructing fragment scratch directories.
fragment_atoms:
List whose entries are sequences containing all atom indices for a fragment.
center_atoms:
List whose entries are sequences giving the center atom indices per fragment.
edge_atoms:
List whose entries are sequences giving the edge atom indices per fragment.
adjacency_mat:
The adjacency matrix for all sites (atoms) in the system.
adjacency_graph:
The adjacency graph corresponding to `adjacency_mat`.
"""

fsites: list[tuple[int, ...]]
fs: list[tuple[tuple[int, ...], ...]]
edge: list[tuple[tuple[int, ...], ...]]
center: list[tuple[int, ...]]
centerf_idx: list[tuple[int, ...]]
ebe_weights: list[tuple]
sites: list[tuple]
dnames: list
center_atoms: list[tuple[str, ...]]
edge_atoms: list[tuple[str, ...]]
fsites: list[Sequence[int]]
fs: list[Sequence[Sequence[int]]]
edge: list[Sequence[Sequence[int]]]
center: list[Sequence[int]]
centerf_idx: list[Sequence[int]]
ebe_weights: list[Sequence]
sites: list[Sequence]
dnames: list[str]
fragment_atoms: list[Sequence[int]]
center_atoms: list[Sequence[int]]
edge_atoms: list[Sequence[int]]
adjacency_mat: np.ndarray
adjacency_graph: nx.Graph

Expand All @@ -75,25 +85,42 @@ def remove_nonnunique_frags(self, natm: int) -> None:
such that all fragments are guaranteed to be distinct sets.
"""
for _ in range(0, natm):
subsets = set()
for adx, basa in enumerate(self.fsites):
for bdx, basb in enumerate(self.fsites):
if adx == bdx:
pass
elif set(basb).issubset(set(basa)):
tmp = set(self.center[adx] + self.center[bdx])
self.center[adx] = tuple(tmp)
del self.center[bdx]
del self.fsites[bdx]
del self.fs[bdx]
subsets.add(bdx)
self.center[adx] = tuple(
set(
list(self.center[adx])
+ list(deepcopy(self.center[bdx]))
)
)
self.center_atoms[adx] = tuple(
set(
list(self.center_atoms[adx])
+ list(deepcopy(self.center_atoms[bdx]))
)
)
if subsets:
sorted_subsets = sorted(subsets, reverse=True)
for bdx in sorted_subsets:
del self.center[bdx]
del self.fsites[bdx]
del self.fs[bdx]
del self.center_atoms[bdx]
del self.fragment_atoms[bdx]

return None


def euclidean_norm(
i_coord: np.ndarray,
j_coord: np.ndarray,
) -> np.floating[Any]:
return norm(np.asarray(i_coord - j_coord))
def euclidean_distance(
i_coord: Vector,
j_coord: Vector,
) -> np.floating:
return norm(i_coord - j_coord)


def graphgen(
Expand All @@ -104,6 +131,7 @@ def graphgen(
frag_prefix: str = "f",
connectivity: str = "euclidean",
iao_valence_basis: str | None = None,
cutoff: float = 20.0,
) -> FragmentMap:
"""Generate fragments via adjacency graph.
Expand Down Expand Up @@ -142,6 +170,11 @@ def graphgen(
weights in the fragment adjacency graph. Currently supports "euclidean"
(which uses the square of the distance between atoms in real
space to determine connectivity within a fragment.)
cutoff:
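Pairs of atoms whose edge weight (the squared interatomic distance for
"euclidean" connectivity) exceeds `cutoff` are not connected in the
adjacency graph and are thus excluded from the `shortest_path` calculation.
This is crucial when handling very large systems, where computing all-to-all
shortest paths becomes prohibitively expensive. Defaults to 20.0.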
Returns
-------
Expand Down Expand Up @@ -175,6 +208,7 @@ def graphgen(
ebe_weights=list(tuple()),
sites=list(tuple()),
dnames=list(),
fragment_atoms=list(),
center_atoms=list(),
edge_atoms=list(),
adjacency_mat=np.zeros((natm, natm), np.float64),
Expand Down Expand Up @@ -204,46 +238,56 @@ def graphgen(
for adx in range(natm):
for bdx in range(adx + 1, natm):
dr = (
euclidean_norm(
euclidean_distance(
adx_map[adx]["coord"],
adx_map[bdx]["coord"],
)
** 2
)
fragment_map.adjacency_mat[adx, bdx] = dr
fragment_map.adjacency_graph.add_edge(adx, bdx, weight=dr)
if dr <= cutoff:
fragment_map.adjacency_graph.add_edge(adx, bdx, weight=dr)

# For a given center site (adx), find the set of shortest
# paths to all other sites. The number of nodes visited
# on that path gives the degree of separation of the
# sites.
for adx, map in adx_map.items():
fragment_map.center_atoms.append(tuple())
fsites_temp = fragment_map.sites[adx]
fragment_map.center_atoms.append((adx,))
fragment_map.center.append(deepcopy(fragment_map.sites[adx]))
fsites_temp = deepcopy(list(fragment_map.sites[adx]))
fatoms_temp = [adx]
fs_temp = []
fs_temp.append(fragment_map.sites[adx])
map["shortest_paths"] = dict(
single_source_all_shortest_paths(
fragment_map.adjacency_graph,
source=adx,
weight=lambda a, b, _: (
fragment_map.adjacency_graph[a][b]["weight"]
),
method="dijkstra",
)
)
fs_temp.append(deepcopy(fragment_map.sites[adx]))

for bdx, _ in adx_map.items():
if fragment_map.adjacency_graph.has_edge(adx, bdx):
map["shortest_paths"].update(
{
bdx: shortest_path(
fragment_map.adjacency_graph,
source=adx,
target=bdx,
weight=lambda a, b, _: (
fragment_map.adjacency_graph[a][b]["weight"]
),
method="dijkstra",
)
}
)

# If the degree of separation is smaller than the *n*
# in your fragment type, BE*n*, then that site is appended to
# the set of fragment sites for adx.
for bdx, path in map["shortest_paths"].items():
if 0 < (len(path[0]) - 1) < fragment_type_order:
fsites_temp = tuple(fsites_temp + fragment_map.sites[bdx])
fs_temp.append(tuple(fragment_map.sites[bdx]))
if 0 < (len(path) - 1) < fragment_type_order:
fsites_temp = fsites_temp + deepcopy(list(fragment_map.sites[bdx]))
fs_temp.append(deepcopy(fragment_map.sites[bdx]))
fatoms_temp.append(bdx)

fragment_map.fsites.append(tuple(fsites_temp))
fragment_map.fs.append(tuple(fs_temp))
fragment_map.center.append(tuple(fragment_map.sites[adx]))
fragment_map.fragment_atoms.append(tuple(fatoms_temp))

elif connectivity.lower() in ["resistance_distance", "resistance"]:
raise NotImplementedError("Work in progress...")
Expand All @@ -260,22 +304,26 @@ def graphgen(
# Define the 'edges' for fragment A as the intersection of its sites
# with the set of all center sites outside of A:
for adx, fs in enumerate(fragment_map.fs):
edge: set[tuple] = set()
edge_temp: set[tuple] = set()
eatoms_temp: set[tuple[int, ...]] = set()
for bdx, center in enumerate(fragment_map.center):
if adx == bdx:
pass
else:
overlap = set(fs).intersection(set((center,)))
if overlap:
edge = edge.union(overlap)
fragment_map.edge.append(tuple(edge))
for f in fs:
overlap = set(f).intersection(set(center))
if overlap:
f_temp = set(fragment_map.fragment_atoms[adx])
c_temp = set(fragment_map.center_atoms[bdx])
edge_temp.add(tuple(overlap))
eatoms_temp.add(tuple(i for i in f_temp.intersection(c_temp)))
fragment_map.edge.append(tuple(edge_temp))
fragment_map.edge_atoms.extend(tuple(eatoms_temp))

# Update relative center site indices (centerf_idx) and weights
# for center site contributions to the energy (ebe_weights):
for adx, center in enumerate(fragment_map.center):
centerf_idx = tuple(
set([fragment_map.fsites[adx].index(cdx) for cdx in center])
)
centerf_idx = tuple(fragment_map.fsites[adx].index(cdx) for cdx in center)
ebe_weight = (1.0, tuple(centerf_idx))
fragment_map.centerf_idx.append(centerf_idx)
fragment_map.ebe_weights.append(ebe_weight)
Expand Down
45 changes: 35 additions & 10 deletions src/quemb/molbe/fragment.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@ class fragpart:
Parameters
----------
frag_type : str
Name of fragmentation function. 'autogen', 'hchain_simple', and 'chain'
are supported. Defaults to 'autogen' For systems with only hydrogen,
use 'chain'; everything else should use 'autogen'
Name of fragmentation function. 'autogen', 'graphgen', 'hchain_simple',
and 'chain' are supported. Defaults to 'autogen'.
be_type : str
Specifies order of bootstrap calculation in the atom-based fragmentation.
'be1', 'be2', 'be3', & 'be4' are supported.
Expand All @@ -44,6 +43,22 @@ class fragpart:
write_geom: bool
Whether to write 'fragment.xyz' file which contains all the fragments
in cartesian coordinates.
remove_nonnunique_frags:
Whether to remove fragments which are subsets of (or identical to) another
fragment in the system. True by default.
frag_prefix:
Prefix for the fragment datanames. Useful for managing
fragment scratch directories.
connectivity:
Keyword string specifying the distance metric to be used for edge
weights in the fragment adjacency graph. Currently supports "euclidean"
(which uses the square of the distance between atoms in real
space to determine connectivity within a fragment.)
cutoff:
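Pairs of atoms whose edge weight (the squared interatomic distance for
"euclidean" connectivity) exceeds `cutoff` are not connected in the
adjacency graph and are thus excluded from the `shortest_path` calculation.
This is crucial when handling very large systems, where computing all-to-all
shortest paths becomes prohibitively expensive. Defaults to 20.0.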
"""

def __init__(
Expand All @@ -55,8 +70,12 @@ def __init__(
print_frags=True,
write_geom=False,
be_type="be2",
frag_prefix="f",
connectivity="euclidean",
mol=None,
frozen_core=False,
cutoff=20,
remove_nonnunique_frags=True,
):
# Initialize class attributes
self.mol = mol
Expand All @@ -70,9 +89,13 @@ def __init__(
self.center_idx = []
self.centerf_idx = []
self.be_type = be_type
self.frag_prefix = frag_prefix
self.connectivity = connectivity
self.frozen_core = frozen_core
self.iao_valence_basis = iao_valence_basis
self.valence_only = valence_only
self.cutoff = cutoff
self.remove_nonnunique_frags = remove_nonnunique_frags

# Initialize class attributes necessary for mixed-basis BE
self.Frag_atom = []
Expand All @@ -98,18 +121,20 @@ def __init__(
elif frag_type == "graphgen":
fragment_map = graphgen(
mol=self.mol.copy(),
be_type=be_type,
frozen_core=frozen_core,
remove_nonunique_frags=True,
frag_prefix="f",
connectivity="euclidean",
iao_valence_basis=iao_valence_basis,
be_type=self.be_type,
frozen_core=self.frozen_core,
remove_nonunique_frags=self.remove_nonnunique_frags,
frag_prefix=self.frag_prefix,
connectivity=self.connectivity,
iao_valence_basis=self.iao_valence_basis,
cutoff=self.cutoff,
)

self.fsites = fragment_map.fsites
self.edge = fragment_map.edge
self.center = fragment_map.center
# self.edge_idx = fragment_map["edge"]
self.Frag_atom = fragment_map.fragment_atoms
self.center_atom = fragment_map.center_atoms
self.centerf_idx = fragment_map.centerf_idx
self.ebe_weight = fragment_map.ebe_weights
self.Nfrag = len(self.fsites)
Expand Down
Loading

0 comments on commit b747e0c

Please sign in to comment.